# Copyright 2019-2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast

import attr

from synapse.api.constants import EventContentFields, RelationTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
    make_tuple_comparison_clause,
)
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.types import Cursor
from synapse.types import JsonDict, StrCollection

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


_REPLACE_STREAM_ORDERING_SQL_COMMANDS = (
    # there should be no leftover rows without a stream_ordering2, but just in case...
    "UPDATE events SET stream_ordering2 = stream_ordering WHERE stream_ordering2 IS NULL",
    # now we can drop the rule and switch the columns
    "DROP RULE populate_stream_ordering2 ON events",
    "ALTER TABLE events DROP COLUMN stream_ordering",
    "ALTER TABLE events RENAME COLUMN stream_ordering2 TO stream_ordering",
    # ... and finally, rename the indexes into place for consistency with sqlite
    "ALTER INDEX event_contains_url_index2 RENAME TO event_contains_url_index",
    "ALTER INDEX events_order_room2 RENAME TO events_order_room",
    "ALTER INDEX events_room_stream2 RENAME TO events_room_stream",
    "ALTER INDEX events_ts2 RENAME TO events_ts",
)
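
# Note for readers: the statements above are not run one-per-background-update.
# They are replayed in order, inside a single transaction, by
# `_background_replace_stream_ordering_column` further down in this module, once
# `stream_ordering2` has been fully populated and indexed.  A rough sketch of
# that final step (assuming a PostgreSQL connection, since these commands only
# apply there):
#
#     def _replace(txn: LoggingTransaction) -> None:
#         for sql in _REPLACE_STREAM_ORDERING_SQL_COMMANDS:
#             txn.execute(sql)
#
#     await self.db_pool.runInteraction("replace_stream_ordering", _replace)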


class _BackgroundUpdates:
    EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
    EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
    DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
    POPULATE_STREAM_ORDERING2 = "populate_stream_ordering2"
    INDEX_STREAM_ORDERING2 = "index_stream_ordering2"
    INDEX_STREAM_ORDERING2_CONTAINS_URL = "index_stream_ordering2_contains_url"
    INDEX_STREAM_ORDERING2_ROOM_ORDER = "index_stream_ordering2_room_order"
    INDEX_STREAM_ORDERING2_ROOM_STREAM = "index_stream_ordering2_room_stream"
    INDEX_STREAM_ORDERING2_TS = "index_stream_ordering2_ts"
    REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"

    EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows"
    EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index"

    EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections"

    EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index"


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _CalculateChainCover:
    """Return value for _calculate_chain_cover_txn."""

    # The last room_id/depth/stream processed.
    room_id: str
    depth: int
    stream: int

    # Number of rows processed
    processed_count: int

    # Map from room_id to last depth/stream processed for each room that we have
    # processed all events for (i.e. the rooms we can flip the
    # `has_auth_chain_index` for)
    finished_room_map: Dict[str, Tuple[int, int]]


class EventsBackgroundUpdatesStore(SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_update_handler(
            _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME,
            self._background_reindex_origin_server_ts,
        )
        self.db_pool.updates.register_background_update_handler(
            _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
            self._background_reindex_fields_sender,
        )

        self.db_pool.updates.register_background_index_update(
            "event_contains_url_index",
            index_name="event_contains_url_index",
            table="events",
            columns=["room_id", "topological_ordering", "stream_ordering"],
            where_clause="contains_url = true AND outlier = false",
        )

        # an event_id index on event_search is useful for the purge_history
        # api. Plus it means we get to enforce some integrity with a UNIQUE
        # clause
        self.db_pool.updates.register_background_index_update(
            "event_search_event_id_idx",
            index_name="event_search_event_id_idx",
            table="event_search",
            columns=["event_id"],
            unique=True,
            psql_only=True,
        )

        self.db_pool.updates.register_background_update_handler(
            _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES,
            self._cleanup_extremities_bg_update,
        )

        self.db_pool.updates.register_background_update_handler(
            "redactions_received_ts", self._redactions_received_ts
        )

        # This index gets deleted in the `event_fix_redactions_bytes` update
        self.db_pool.updates.register_background_index_update(
            "event_fix_redactions_bytes_create_index",
            index_name="redactions_censored_redacts",
            table="redactions",
            columns=["redacts"],
            where_clause="have_censored",
        )

        self.db_pool.updates.register_background_update_handler(
            "event_fix_redactions_bytes", self._event_fix_redactions_bytes
        )

        self.db_pool.updates.register_background_update_handler(
            "event_store_labels", self._event_store_labels
        )

        self.db_pool.updates.register_background_index_update(
            "redactions_have_censored_ts_idx",
            index_name="redactions_have_censored_ts",
            table="redactions",
            columns=["received_ts"],
            where_clause="NOT have_censored",
        )

        self.db_pool.updates.register_background_index_update(
            "users_have_local_media",
            index_name="users_have_local_media",
            table="local_media_repository",
            columns=["user_id", "created_ts"],
        )

        self.db_pool.updates.register_background_update_handler(
            "rejected_events_metadata",
            self._rejected_events_metadata,
        )

        self.db_pool.updates.register_background_update_handler(
            "chain_cover",
            self._chain_cover_index,
        )

        self.db_pool.updates.register_background_update_handler(
            "purged_chain_cover",
            self._purged_chain_cover_index,
        )

        self.db_pool.updates.register_background_update_handler(
            "event_arbitrary_relations",
            self._event_arbitrary_relations,
        )

        ################################################################################
        # bg updates for replacing stream_ordering with a BIGINT
        # (these only run on postgres.)

        self.db_pool.updates.register_background_update_handler(
            _BackgroundUpdates.POPULATE_STREAM_ORDERING2,
            self._background_populate_stream_ordering2,
        )
        # CREATE UNIQUE INDEX events_stream_ordering ON events(stream_ordering2);
        self.db_pool.updates.register_background_index_update(
            _BackgroundUpdates.INDEX_STREAM_ORDERING2,
            index_name="events_stream_ordering",
            table="events",
            columns=["stream_ordering2"],
            unique=True,
        )
        # CREATE INDEX event_contains_url_index ON events(room_id, topological_ordering, stream_ordering) WHERE contains_url = true AND outlier = false;
        self.db_pool.updates.register_background_index_update(
            _BackgroundUpdates.INDEX_STREAM_ORDERING2_CONTAINS_URL,
            index_name="event_contains_url_index2",
            table="events",
            columns=["room_id", "topological_ordering", "stream_ordering2"],
            where_clause="contains_url = true AND outlier = false",
        )
        # CREATE INDEX events_order_room ON events(room_id, topological_ordering, stream_ordering);
        self.db_pool.updates.register_background_index_update(
            _BackgroundUpdates.INDEX_STREAM_ORDERING2_ROOM_ORDER,
            index_name="events_order_room2",
            table="events",
            columns=["room_id", "topological_ordering", "stream_ordering2"],
        )
        # CREATE INDEX events_room_stream ON events(room_id, stream_ordering);
        self.db_pool.updates.register_background_index_update(
            _BackgroundUpdates.INDEX_STREAM_ORDERING2_ROOM_STREAM,
            index_name="events_room_stream2",
            table="events",
            columns=["room_id", "stream_ordering2"],
        )
        # CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering);
        self.db_pool.updates.register_background_index_update(
            _BackgroundUpdates.INDEX_STREAM_ORDERING2_TS,
            index_name="events_ts2",
            table="events",
            columns=["origin_server_ts", "stream_ordering2"],
        )
        self.db_pool.updates.register_background_update_handler(
            _BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN,
            self._background_replace_stream_ordering_column,
        )
        ################################################################################

        self.db_pool.updates.register_background_update_handler(
            _BackgroundUpdates.EVENT_EDGES_DROP_INVALID_ROWS,
            self._background_drop_invalid_event_edges_rows,
        )

        self.db_pool.updates.register_background_index_update(
            _BackgroundUpdates.EVENT_EDGES_REPLACE_INDEX,
            index_name="event_edges_event_id_prev_event_id_idx",
            table="event_edges",
            columns=["event_id", "prev_event_id"],
            unique=True,
            # the old index which just covered event_id is now redundant.
            replaces_index="ev_edges_id",
        )

        self.db_pool.updates.register_background_update_handler(
            _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
            self._background_events_populate_state_key_rejections,
        )

        # Add an index that would be useful for jumping to date using
        # get_event_id_for_timestamp.
        self.db_pool.updates.register_background_index_update(
            _BackgroundUpdates.EVENTS_JUMP_TO_DATE_INDEX,
            index_name="events_jump_to_date_idx",
            table="events",
            columns=["room_id", "origin_server_ts"],
            where_clause="NOT outlier",
        )
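
    # How the handlers below are driven (a rough summary, inferred from their use
    # in this module rather than a formal API description): each handler is an
    # async callable taking a `progress` dict and a `batch_size`, does one batch
    # of work, records its position with `_background_update_progress_txn`, and
    # returns the number of items it processed.  When there is nothing left to
    # do, it calls `_end_background_update(<update name>)` and returns 0 (or the
    # batch size, for the "am I done?" style handlers further down).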

    async def _background_reindex_fields_sender(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        target_min_stream_id = progress["target_min_stream_id_inclusive"]
        max_stream_id = progress["max_stream_id_exclusive"]
        rows_inserted = progress.get("rows_inserted", 0)

        def reindex_txn(txn: LoggingTransaction) -> int:
            sql = (
                "SELECT stream_ordering, event_id, json FROM events"
                " INNER JOIN event_json USING (event_id)"
                " WHERE ? <= stream_ordering AND stream_ordering < ?"
                " ORDER BY stream_ordering DESC"
                " LIMIT ?"
            )

            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

            rows = txn.fetchall()
            if not rows:
                return 0

            min_stream_id = rows[-1][0]

            update_rows = []
            for row in rows:
                try:
                    event_id = row[1]
                    event_json = db_to_json(row[2])
                    sender = event_json["sender"]
                    content = event_json["content"]

                    contains_url = "url" in content
                    if contains_url:
                        contains_url &= isinstance(content["url"], str)
                except (KeyError, AttributeError):
                    # If the event is missing a necessary field then
                    # skip over it.
                    continue

                update_rows.append((sender, contains_url, event_id))

            sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"

            txn.execute_batch(sql, update_rows)

            progress = {
                "target_min_stream_id_inclusive": target_min_stream_id,
                "max_stream_id_exclusive": min_stream_id,
                "rows_inserted": rows_inserted + len(rows),
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
            )

            return len(rows)

        result = await self.db_pool.runInteraction(
            _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
        )

        if not result:
            await self.db_pool.updates._end_background_update(
                _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
            )

        return result

    async def _background_reindex_origin_server_ts(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        target_min_stream_id = progress["target_min_stream_id_inclusive"]
        max_stream_id = progress["max_stream_id_exclusive"]
        rows_inserted = progress.get("rows_inserted", 0)

        def reindex_search_txn(txn: LoggingTransaction) -> int:
            sql = (
                "SELECT stream_ordering, event_id FROM events"
                " WHERE ? <= stream_ordering AND stream_ordering < ?"
                " ORDER BY stream_ordering DESC"
                " LIMIT ?"
            )

            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

            rows = txn.fetchall()
            if not rows:
                return 0

            min_stream_id = rows[-1][0]
            event_ids = [row[1] for row in rows]

            rows_to_update = []

            chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
            for chunk in chunks:
                ev_rows = cast(
                    List[Tuple[str, str]],
                    self.db_pool.simple_select_many_txn(
                        txn,
                        table="event_json",
                        column="event_id",
                        iterable=chunk,
                        retcols=["event_id", "json"],
                        keyvalues={},
                    ),
                )

                for event_id, json in ev_rows:
                    event_json = db_to_json(json)
                    try:
                        origin_server_ts = event_json["origin_server_ts"]
                    except (KeyError, AttributeError):
                        # If the event is missing a necessary field then
                        # skip over it.
                        continue

                    rows_to_update.append((origin_server_ts, event_id))

            sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"

            txn.execute_batch(sql, rows_to_update)

            progress = {
                "target_min_stream_id_inclusive": target_min_stream_id,
                "max_stream_id_exclusive": min_stream_id,
                "rows_inserted": rows_inserted + len(rows_to_update),
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, progress
            )

            return len(rows_to_update)

        result = await self.db_pool.runInteraction(
            _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
        )

        if not result:
            await self.db_pool.updates._end_background_update(
                _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME
            )

        return result

    async def _cleanup_extremities_bg_update(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Background update to clean out extremities that should have been
        deleted previously.

        Mainly used to deal with the aftermath of https://github.com/matrix-org/synapse/issues/5269.
        """

        # This works by first copying all existing forward extremities into the
        # `_extremities_to_check` table at start up, and then checking, for each
        # event in that table, whether we have any descendants that are not
        # soft-failed/rejected. If that is the case then we delete that event
        # from the forward extremities table.
        #
        # For efficiency, we do this in batches by recursively pulling out all
        # descendants of a batch until we find the non soft-failed/rejected
        # events, i.e. the set of descendants whose chain of prev events back
        # to the batch of extremities are all soft-failed or rejected.
        # Typically, we won't find any such events as extremities will rarely
        # have any descendants, but if they do then we should delete those
        # extremities.

        def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
            # The set of extremity event IDs that we're checking this round
            original_set = set()

            # A dict[str, Set[str]] of event ID to their prev events.
            graph: Dict[str, Set[str]] = {}

            # The set of descendants of the original set that are not rejected
            # nor soft-failed. Ancestors of these events should be removed
            # from the forward extremities table.
            non_rejected_leaves = set()

            # Set of event IDs that have been soft failed, and for which we
            # should check if they have descendants which haven't been soft
            # failed.
            soft_failed_events_to_lookup = set()

            # First, we get `batch_size` events from the table, pulling out
            # their successor events, if any, and the successor events'
            # rejection status.
            txn.execute(
                """SELECT prev_event_id, event_id, internal_metadata,
                    rejections.event_id IS NOT NULL, events.outlier
                FROM (
                    SELECT event_id AS prev_event_id
                    FROM _extremities_to_check
                    LIMIT ?
                ) AS f
                LEFT JOIN event_edges USING (prev_event_id)
                LEFT JOIN events USING (event_id)
                LEFT JOIN event_json USING (event_id)
                LEFT JOIN rejections USING (event_id)
                """,
                (batch_size,),
            )

            for prev_event_id, event_id, metadata, rejected, outlier in txn:
                original_set.add(prev_event_id)

                if not event_id or outlier:
                    # Common case where the forward extremity doesn't have any
                    # descendants.
                    continue

                graph.setdefault(event_id, set()).add(prev_event_id)

                soft_failed = False
                if metadata:
                    soft_failed = db_to_json(metadata).get("soft_failed")

                if soft_failed or rejected:
                    soft_failed_events_to_lookup.add(event_id)
                else:
                    non_rejected_leaves.add(event_id)

            # Now we recursively check all the soft-failed descendants we
            # found above in the same way, until we have nothing left to
            # check.
            while soft_failed_events_to_lookup:
                # We only want to do 100 at a time, so we split the given list
                # into two.
                batch = list(soft_failed_events_to_lookup)
                to_check, to_defer = batch[:100], batch[100:]
                soft_failed_events_to_lookup = set(to_defer)

                sql = """SELECT prev_event_id, event_id, internal_metadata,
                    rejections.event_id IS NOT NULL
                    FROM event_edges
                    INNER JOIN events USING (event_id)
                    INNER JOIN event_json USING (event_id)
                    LEFT JOIN rejections USING (event_id)
                    WHERE
                        NOT events.outlier
                        AND
                """
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "prev_event_id", to_check
                )
                txn.execute(sql + clause, list(args))

                for prev_event_id, event_id, metadata, rejected in txn:
                    if event_id in graph:
                        # Already handled this event previously, but we still
                        # want to record the edge.
                        graph[event_id].add(prev_event_id)
                        continue

                    graph[event_id] = {prev_event_id}

                    soft_failed = db_to_json(metadata).get("soft_failed")
                    if soft_failed or rejected:
                        soft_failed_events_to_lookup.add(event_id)
                    else:
                        non_rejected_leaves.add(event_id)

            # We have a set of non-soft-failed descendants, so we recurse up
            # the graph to find all ancestors and add them to the set of event
            # IDs that we can delete from the forward extremities table.
            to_delete = set()
            while non_rejected_leaves:
                event_id = non_rejected_leaves.pop()
                prev_event_ids = graph.get(event_id, set())
                non_rejected_leaves.update(prev_event_ids)
                to_delete.update(prev_event_ids)

            to_delete.intersection_update(original_set)

            deleted = self.db_pool.simple_delete_many_txn(
                txn=txn,
                table="event_forward_extremities",
                column="event_id",
                values=to_delete,
                keyvalues={},
            )

            logger.info(
                "Deleted %d forward extremities of %d checked, to clean up matrix-org/synapse#5269",
                deleted,
                len(original_set),
            )

            if deleted:
                # We now need to invalidate the caches of these rooms
                rows = cast(
                    List[Tuple[str]],
                    self.db_pool.simple_select_many_txn(
                        txn,
                        table="events",
                        column="event_id",
                        iterable=to_delete,
                        keyvalues={},
                        retcols=("room_id",),
                    ),
                )
                room_ids = {row[0] for row in rows}
                for room_id in room_ids:
                    txn.call_after(
                        self.get_latest_event_ids_in_room.invalidate, (room_id,)  # type: ignore[attr-defined]
                    )

            self.db_pool.simple_delete_many_txn(
                txn=txn,
                table="_extremities_to_check",
                column="event_id",
                values=original_set,
                keyvalues={},
            )

            return len(original_set)

        num_handled = await self.db_pool.runInteraction(
            "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
        )

        if not num_handled:
            await self.db_pool.updates._end_background_update(
                _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
            )

            def _drop_table_txn(txn: LoggingTransaction) -> None:
                txn.execute("DROP TABLE _extremities_to_check")

            await self.db_pool.runInteraction(
                "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
            )

        return num_handled
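
    # A tiny worked example of the walk above (illustrative only, with made-up
    # event IDs): suppose extremity E1 has a descendant D1 that was soft-failed,
    # and D1 in turn has a descendant D2 that was accepted.  The first query puts
    # D1 into `soft_failed_events_to_lookup`; the recursive loop then finds D2
    # and adds it to `non_rejected_leaves`; walking back up the `graph` dict from
    # D2 reaches D1 and then E1, so E1 ends up in `to_delete` and is removed from
    # `event_forward_extremities`.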

    async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int:
        """Handles filling out the `received_ts` column in redactions."""
        last_event_id = progress.get("last_event_id", "")

        def _redactions_received_ts_txn(txn: LoggingTransaction) -> int:
            # Fetch the set of event IDs that we want to update
            sql = """
                SELECT event_id FROM redactions
                WHERE event_id > ?
                ORDER BY event_id ASC
                LIMIT ?
            """

            txn.execute(sql, (last_event_id, batch_size))

            rows = txn.fetchall()
            if not rows:
                return 0

            (upper_event_id,) = rows[-1]

            # Update the redactions with the received_ts.
            #
            # Note: Not all events have an associated received_ts, so we
            # fall back to using origin_server_ts. If we for some reason don't
            # have an origin_server_ts, let's just use the current timestamp.
            #
            # We don't want to leave it null, as then we'll never try and
            # censor those redactions.
            sql = """
                UPDATE redactions
                SET received_ts = (
                    SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events
                    WHERE events.event_id = redactions.event_id
                )
                WHERE ? <= event_id AND event_id <= ?
            """

            txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))

            self.db_pool.updates._background_update_progress_txn(
                txn, "redactions_received_ts", {"last_event_id": upper_event_id}
            )

            return len(rows)

        count = await self.db_pool.runInteraction(
            "_redactions_received_ts", _redactions_received_ts_txn
        )

        if not count:
            await self.db_pool.updates._end_background_update("redactions_received_ts")

        return count

    async def _event_fix_redactions_bytes(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Undoes hex-encoded censored redacted event JSON."""

        def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None:
            # This update is quite fast due to the new index.
            txn.execute(
                """
                UPDATE event_json
                SET
                    json = convert_from(json::bytea, 'utf8')
                FROM redactions
                WHERE
                    redactions.have_censored
                    AND event_json.event_id = redactions.redacts
                    AND json NOT LIKE '{%';
                """
            )

            txn.execute("DROP INDEX redactions_censored_redacts")

        await self.db_pool.runInteraction(
            "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
        )

        await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")

        return 1

    async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int:
        """Background update handler which will store labels for existing events."""
        last_event_id = progress.get("last_event_id", "")

        def _event_store_labels_txn(txn: LoggingTransaction) -> int:
            txn.execute(
                """
                SELECT event_id, json FROM event_json
                LEFT JOIN event_labels USING (event_id)
                WHERE event_id > ? AND label IS NULL
                ORDER BY event_id LIMIT ?
                """,
                (last_event_id, batch_size),
            )

            results = list(txn)

            nbrows = 0
            last_row_event_id = ""
            for event_id, event_json_raw in results:
                try:
                    event_json = db_to_json(event_json_raw)

                    self.db_pool.simple_insert_many_txn(
                        txn=txn,
                        table="event_labels",
                        keys=("event_id", "label", "room_id", "topological_ordering"),
                        values=[
                            (
                                event_id,
                                label,
                                event_json["room_id"],
                                event_json["depth"],
                            )
                            for label in event_json["content"].get(
                                EventContentFields.LABELS, []
                            )
                            if isinstance(label, str)
                        ],
                    )
                except Exception as e:
                    logger.warning(
                        "Unable to load event %s (no labels will be imported): %s",
                        event_id,
                        e,
                    )

                nbrows += 1
                last_row_event_id = event_id

            self.db_pool.updates._background_update_progress_txn(
                txn, "event_store_labels", {"last_event_id": last_row_event_id}
            )

            return nbrows

        num_rows = await self.db_pool.runInteraction(
            desc="event_store_labels", func=_event_store_labels_txn
        )

        if not num_rows:
            await self.db_pool.updates._end_background_update("event_store_labels")

        return num_rows

    async def _rejected_events_metadata(self, progress: dict, batch_size: int) -> int:
        """Adds rejected events to the `state_events` and `event_auth` metadata
        tables.
        """

        last_event_id = progress.get("last_event_id", "")

        def get_rejected_events(
            txn: Cursor,
        ) -> List[Tuple[str, str, JsonDict, bool, bool]]:
            # Fetch rejected event json, their room version and whether we have
            # inserted them into the state_events or auth_events tables.
            #
            # Note we can assume that events that don't have a corresponding
            # room version are V1 rooms.
            sql = """
                SELECT DISTINCT
                    event_id,
                    COALESCE(room_version, '1'),
                    json,
                    state_events.event_id IS NOT NULL,
                    event_auth.event_id IS NOT NULL
                FROM rejections
                INNER JOIN event_json USING (event_id)
                LEFT JOIN rooms USING (room_id)
                LEFT JOIN state_events USING (event_id)
                LEFT JOIN event_auth USING (event_id)
                WHERE event_id > ?
                ORDER BY event_id
                LIMIT ?
            """

            txn.execute(
                sql,
                (
                    last_event_id,
                    batch_size,
                ),
            )

            return cast(
                List[Tuple[str, str, JsonDict, bool, bool]],
                [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn],
            )

        results = await self.db_pool.runInteraction(
            desc="_rejected_events_metadata_get", func=get_rejected_events
        )

        if not results:
            await self.db_pool.updates._end_background_update(
                "rejected_events_metadata"
            )
            return 0

        state_events = []
        auth_events = []
        for event_id, room_version, event_json, has_state, has_event_auth in results:
            last_event_id = event_id

            if has_state and has_event_auth:
                continue

            room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version)
            if not room_version_obj:
                # We no longer support this room version, so we just ignore the
                # events entirely.
                logger.info(
                    "Ignoring event with unknown room version %r: %r",
                    room_version,
                    event_id,
                )
                continue

            event = make_event_from_dict(event_json, room_version_obj)

            if not event.is_state():
                continue

            if not has_state:
                state_events.append(
                    (event.event_id, event.room_id, event.type, event.state_key)
                )

            if not has_event_auth:
                # Old, dodgy, events may have duplicate auth events, which we
                # need to deduplicate as we have a unique constraint.
                for auth_id in set(event.auth_event_ids()):
                    auth_events.append((event.event_id, event.room_id, auth_id))

        if state_events:
            await self.db_pool.simple_insert_many(
                table="state_events",
                keys=("event_id", "room_id", "type", "state_key"),
                values=state_events,
                desc="_rejected_events_metadata_state_events",
            )

        if auth_events:
            await self.db_pool.simple_insert_many(
                table="event_auth",
                keys=("event_id", "room_id", "auth_id"),
                values=auth_events,
                desc="_rejected_events_metadata_event_auth",
            )

        await self.db_pool.updates._background_update_progress(
            "rejected_events_metadata", {"last_event_id": last_event_id}
        )

        if len(results) < batch_size:
            await self.db_pool.updates._end_background_update(
                "rejected_events_metadata"
            )

        return len(results)

    async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
        """A background update that iterates over all rooms and generates the
        chain cover index for them.
        """

        current_room_id = progress.get("current_room_id", "")

        # Where we've processed up to in the room, defaults to the start of the
        # room.
        last_depth = progress.get("last_depth", -1)
        last_stream = progress.get("last_stream", -1)

        result = await self.db_pool.runInteraction(
            "_chain_cover_index",
            self._calculate_chain_cover_txn,
            current_room_id,
            last_depth,
            last_stream,
            batch_size,
            single_room=False,
        )

        finished = result.processed_count == 0

        total_rows_processed = result.processed_count
        current_room_id = result.room_id
        last_depth = result.depth
        last_stream = result.stream

        for room_id, (depth, stream) in result.finished_room_map.items():
            # If we've done all the events in the room we flip the
            # `has_auth_chain_index` in the DB. Note that it's possible for
            # further events to be persisted between the above and setting the
            # flag without having the chain cover calculated for them. This is
            # fine as a) the code gracefully handles these cases and b) we'll
            # calculate them below.

            await self.db_pool.simple_update(
                table="rooms",
                keyvalues={"room_id": room_id},
                updatevalues={"has_auth_chain_index": True},
                desc="_chain_cover_index",
            )

            # Handle any events that might have raced with us flipping the
            # bit above.
            result = await self.db_pool.runInteraction(
                "_chain_cover_index",
                self._calculate_chain_cover_txn,
                room_id,
                depth,
                stream,
                batch_size=None,
                single_room=True,
            )

            total_rows_processed += result.processed_count

        if finished:
            await self.db_pool.updates._end_background_update("chain_cover")
            return total_rows_processed

        await self.db_pool.updates._background_update_progress(
            "chain_cover",
            {
                "current_room_id": current_room_id,
                "last_depth": last_depth,
                "last_stream": last_stream,
            },
        )

        return total_rows_processed

    def _calculate_chain_cover_txn(
        self,
        txn: LoggingTransaction,
        last_room_id: str,
        last_depth: int,
        last_stream: int,
        batch_size: Optional[int],
        single_room: bool,
    ) -> _CalculateChainCover:
        """Calculate the chain cover for `batch_size` events, ordered by
        `(room_id, depth, stream)`.

        Args:
            txn
            last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
                tuple to fetch results after.
            batch_size: The maximum number of events to process. If None then
                no limit.
            single_room: Whether to calculate the index for just the given
                room.
        """

        # Get the next set of events in the room (that we haven't already
        # computed chain cover for). We do this in topological order.

        # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
        # comparison, but that is not supported on older SQLite versions
        tuple_clause, tuple_args = make_tuple_comparison_clause(
            [
                ("events.room_id", last_room_id),
                ("topological_ordering", last_depth),
                ("stream_ordering", last_stream),
            ],
        )

        extra_clause = ""
        if single_room:
            extra_clause = "AND events.room_id = ?"
            tuple_args.append(last_room_id)

        sql = """
            SELECT
                event_id, state_events.type, state_events.state_key,
                topological_ordering, stream_ordering,
                events.room_id
            FROM events
            INNER JOIN state_events USING (event_id)
            LEFT JOIN event_auth_chains USING (event_id)
            LEFT JOIN event_auth_chain_to_calculate USING (event_id)
            WHERE event_auth_chains.event_id IS NULL
                AND event_auth_chain_to_calculate.event_id IS NULL
                AND %(tuple_cmp)s
                %(extra)s
            ORDER BY events.room_id, topological_ordering, stream_ordering
            %(limit)s
        """ % {
            "tuple_cmp": tuple_clause,
            "limit": "LIMIT ?" if batch_size is not None else "",
            "extra": extra_clause,
        }

        if batch_size is not None:
            tuple_args.append(batch_size)

        txn.execute(sql, tuple_args)

        rows = txn.fetchall()

        # Put the results in the necessary format for
        # `_add_chain_cover_index`
        event_to_room_id = {row[0]: row[5] for row in rows}
        event_to_types = {row[0]: (row[1], row[2]) for row in rows}

        # Calculate the new last position we've processed up to.
        new_last_depth: int = rows[-1][3] if rows else last_depth
        new_last_stream: int = rows[-1][4] if rows else last_stream
        new_last_room_id: str = rows[-1][5] if rows else ""

        # Map from room_id to last depth/stream_ordering processed for the room,
        # excluding the last room (which we're likely still processing). We also
        # need to include the room passed in if it's not included in the result
        # set (as we then know we've processed all events in said room).
        #
        # This is the set of rooms that we can now safely flip the
        # `has_auth_chain_index` bit for.
        finished_rooms = {
            row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
        }
        if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
            finished_rooms[last_room_id] = (last_depth, last_stream)

        count = len(rows)

        # We also need to fetch the auth events for them.
        auth_events = cast(
            List[Tuple[str, str]],
            self.db_pool.simple_select_many_txn(
                txn,
                table="event_auth",
                column="event_id",
                iterable=event_to_room_id,
                keyvalues={},
                retcols=("event_id", "auth_id"),
            ),
        )

        event_to_auth_chain: Dict[str, List[str]] = {}
        for event_id, auth_id in auth_events:
            event_to_auth_chain.setdefault(event_id, []).append(auth_id)

        # Calculate and persist the chain cover index for this set of events.
        #
        # Annoyingly we need to gut-wrench into the persist event store so that
        # we can reuse the function to calculate the chain cover for rooms.
        PersistEventsStore._add_chain_cover_index(
            txn,
            self.db_pool,
            self.event_chain_id_gen,  # type: ignore[attr-defined]
            event_to_room_id,
            event_to_types,
            cast(Dict[str, StrCollection], event_to_auth_chain),
        )

        return _CalculateChainCover(
            room_id=new_last_room_id,
            depth=new_last_depth,
            stream=new_last_stream,
            processed_count=count,
            finished_room_map=finished_rooms,
        )
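
    # The query above is doing keyset pagination over the composite key
    # `(room_id, topological_ordering, stream_ordering)`: rather than OFFSET-ing
    # ever deeper into the table, each batch resumes strictly after the last
    # tuple seen.  As a rough illustration (not necessarily the exact SQL that
    # `make_tuple_comparison_clause` emits, which can depend on the database
    # engine), the generated clause behaves like:
    #
    #     (events.room_id, topological_ordering, stream_ordering) > (?, ?, ?)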

    async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int:
        """
        A background update that iterates over the chain cover and deletes the
        chain cover for events that have been purged.

        This may be due to fully purging a room or via setting a retention policy.
        """
        current_event_id = progress.get("current_event_id", "")

        def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
            # The event ID from events will be null if the chain ID / sequence
            # number points to a purged event.
            sql = """
                SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL
                FROM event_auth_chains
                LEFT JOIN events AS e USING (event_id)
                WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ?
            """
            txn.execute(sql, (current_event_id, batch_size))

            rows = txn.fetchall()
            if not rows:
                return 0

            # The event IDs and chain IDs / sequence numbers where the event has
            # been purged.
            unreferenced_event_ids = []
            unreferenced_chain_id_tuples = []
            event_id = ""
            for event_id, chain_id, sequence_number, has_event in rows:
                if not has_event:
                    unreferenced_event_ids.append((event_id,))
                    unreferenced_chain_id_tuples.append((chain_id, sequence_number))

            # Delete the unreferenced auth chains from event_auth_chain_links and
            # event_auth_chains.
            txn.executemany(
                """
                DELETE FROM event_auth_chains WHERE event_id = ?
                """,
                unreferenced_event_ids,
            )
            # We should also delete matching target_*, but there is no index on
            # target_chain_id. Hopefully any purged events are due to a room
            # being fully purged and they will be removed from the origin_*
            # searches.
            txn.executemany(
                """
                DELETE FROM event_auth_chain_links WHERE
                origin_chain_id = ? AND origin_sequence_number = ?
                """,
                unreferenced_chain_id_tuples,
            )

            progress = {
                "current_event_id": event_id,
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, "purged_chain_cover", progress
            )

            return len(rows)

        result = await self.db_pool.runInteraction(
            "_purged_chain_cover_index",
            purged_chain_cover_txn,
        )

        if not result:
            await self.db_pool.updates._end_background_update("purged_chain_cover")

        return result

    async def _event_arbitrary_relations(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Background update handler which will store previously unknown relations for existing events."""
        last_event_id = progress.get("last_event_id", "")

        def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
            # Fetch events and then filter based on whether the event has a
            # relation or not.
            txn.execute(
                """
                SELECT event_id, json FROM event_json
                WHERE event_id > ?
                ORDER BY event_id LIMIT ?
                """,
                (last_event_id, batch_size),
            )

            results = list(txn)

            # (event_id, parent_id, rel_type) for each relation
            relations_to_insert: List[Tuple[str, str, str]] = []
            for event_id, event_json_raw in results:
                try:
                    event_json = db_to_json(event_json_raw)
                except Exception as e:
                    logger.warning(
                        "Unable to load event %s (no relations will be updated): %s",
                        event_id,
                        e,
                    )
                    continue

                # If there's no relation, skip!
                relates_to = event_json["content"].get("m.relates_to")
                if not relates_to or not isinstance(relates_to, dict):
                    continue

                # If the relation type or parent event ID is not a string, skip it.
                #
                # Do not consider relation types that have existed for a long time,
                # since they will already be listed in the `event_relations` table.
                rel_type = relates_to.get("rel_type")
                if not isinstance(rel_type, str) or rel_type in (
                    RelationTypes.ANNOTATION,
                    RelationTypes.REFERENCE,
                    RelationTypes.REPLACE,
                ):
                    continue

                parent_id = relates_to.get("event_id")
                if not isinstance(parent_id, str):
                    continue

                relations_to_insert.append((event_id, parent_id, rel_type))

            # Insert the missing data. Note that we upsert here in case the event
            # has already been processed.
            if relations_to_insert:
                self.db_pool.simple_upsert_many_txn(
                    txn=txn,
                    table="event_relations",
                    key_names=("event_id",),
                    key_values=[(r[0],) for r in relations_to_insert],
                    value_names=("relates_to_id", "relation_type"),
                    value_values=[r[1:] for r in relations_to_insert],
                )

                # Iterate the parent IDs and invalidate caches.
                cache_tuples = {(r[1],) for r in relations_to_insert}
                self._invalidate_cache_and_stream_bulk(  # type: ignore[attr-defined]
                    txn, self.get_relations_for_event, cache_tuples  # type: ignore[attr-defined]
                )
                self._invalidate_cache_and_stream_bulk(  # type: ignore[attr-defined]
                    txn, self.get_thread_summary, cache_tuples  # type: ignore[attr-defined]
                )

            if results:
                latest_event_id = results[-1][0]
                self.db_pool.updates._background_update_progress_txn(
                    txn, "event_arbitrary_relations", {"last_event_id": latest_event_id}
                )

            return len(results)

        num_rows = await self.db_pool.runInteraction(
            desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn
        )

        if not num_rows:
            await self.db_pool.updates._end_background_update(
                "event_arbitrary_relations"
            )

        return num_rows

    async def _background_populate_stream_ordering2(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Populate events.stream_ordering2, then replace stream_ordering

        This is to deal with the fact that stream_ordering was initially created as a
        32-bit integer field.
        """
        batch_size = max(batch_size, 1)

        def process(txn: LoggingTransaction) -> int:
            last_stream = progress.get("last_stream", -(1 << 31))
            txn.execute(
                """
                UPDATE events SET stream_ordering2=stream_ordering
                WHERE stream_ordering IN (
                   SELECT stream_ordering FROM events WHERE stream_ordering > ?
                   ORDER BY stream_ordering LIMIT ?
                )
                RETURNING stream_ordering;
                """,
                (last_stream, batch_size),
            )
            row_count = txn.rowcount
            if row_count == 0:
                return 0
            last_stream = max(row[0] for row in txn)
            logger.info("populated stream_ordering2 up to %i", last_stream)

            self.db_pool.updates._background_update_progress_txn(
                txn,
                _BackgroundUpdates.POPULATE_STREAM_ORDERING2,
                {"last_stream": last_stream},
            )
            return row_count

        result = await self.db_pool.runInteraction(
            "_background_populate_stream_ordering2", process
        )

        if result != 0:
            return result

        await self.db_pool.updates._end_background_update(
            _BackgroundUpdates.POPULATE_STREAM_ORDERING2
        )
        return 0
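
    # The update above leans on `UPDATE ... RETURNING` so that a single statement
    # both copies a batch of values and reports which rows it touched; this
    # handler is only registered in the postgres-only section of __init__ above,
    # so SQLite support for RETURNING is not a concern here.  The
    # `max(row[0] for row in txn)` line is what turns the returned rows into the
    # resume point for the next batch.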

    async def _background_replace_stream_ordering_column(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Drop the old 'stream_ordering' column and rename 'stream_ordering2' into its place."""

        def process(txn: Cursor) -> None:
            for sql in _REPLACE_STREAM_ORDERING_SQL_COMMANDS:
                logger.info("completing stream_ordering migration: %s", sql)
                txn.execute(sql)

        # ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
        # indexes on it.
        await self.db_pool.runInteraction(
            "background_analyze_new_stream_ordering_column",
            lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
        )

        await self.db_pool.runInteraction(
            "_background_replace_stream_ordering_column", process
        )

        await self.db_pool.updates._end_background_update(
            _BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN
        )

        return 0

    async def _background_drop_invalid_event_edges_rows(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Drop invalid rows from event_edges

        This only runs for postgres. For SQLite, it all happens synchronously.

        Firstly, drop any rows with is_state=True. These may have been added a long time
        ago, but they are no longer used.

        We also drop rows that do not correspond to entries in `events`, and add a
        foreign key.
        """

        last_event_id = progress.get("last_event_id", "")

        def drop_invalid_event_edges_txn(txn: LoggingTransaction) -> bool:
            """Returns True if we're done."""

            # first we need to find an endpoint.
            txn.execute(
                """
                SELECT event_id FROM event_edges
                WHERE event_id > ?
                ORDER BY event_id
                LIMIT 1 OFFSET ?
                """,
                (last_event_id, batch_size),
            )

            endpoint = None
            row = txn.fetchone()

            if row:
                endpoint = row[0]

            where_clause = "ee.event_id > ?"
            args = [last_event_id]
            if endpoint:
                where_clause += " AND ee.event_id <= ?"
                args.append(endpoint)

            # now delete any that:
            #   - have is_state=TRUE, or
            #   - do not correspond to a row in `events`
            txn.execute(
                f"""
                DELETE FROM event_edges
                WHERE event_id IN (
                   SELECT ee.event_id
                   FROM event_edges ee
                     LEFT JOIN events ev USING (event_id)
                   WHERE ({where_clause}) AND
                     (is_state OR ev.event_id IS NULL)
                )""",
                args,
            )

            logger.info(
                "cleaned up event_edges up to %s: removed %i/%i rows",
                endpoint,
                txn.rowcount,
                batch_size,
            )

            if endpoint is not None:
                self.db_pool.updates._background_update_progress_txn(
                    txn,
                    _BackgroundUpdates.EVENT_EDGES_DROP_INVALID_ROWS,
                    {"last_event_id": endpoint},
                )
                return False

            # if that was the final batch, we validate the foreign key.
            #
            # The constraint should have been in place and enforced for new rows since
            # before we started deleting invalid rows, so there's no chance for any
            # invalid rows to have snuck in the meantime. In other words, this really
            # ought to succeed.
            logger.info("cleaned up event_edges; enabling foreign key")
            txn.execute(
                "ALTER TABLE event_edges VALIDATE CONSTRAINT event_edges_event_id_fkey"
            )
            return True

        done = await self.db_pool.runInteraction(
            desc="drop_invalid_event_edges", func=drop_invalid_event_edges_txn
        )

        if done:
            await self.db_pool.updates._end_background_update(
                _BackgroundUpdates.EVENT_EDGES_DROP_INVALID_ROWS
            )

        return batch_size
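
    # Both this handler and the one below share the same batching trick: rather
    # than fetching `batch_size` rows, they probe for a single row at
    # `LIMIT 1 OFFSET batch_size` (or `batch_size - 1` below) to find the key of
    # the last row in the batch, then run one bulk statement bounded by that
    # endpoint.  If the probe returns nothing, less than a full batch remains and
    # the handler can finish up.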

    async def _background_events_populate_state_key_rejections(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Back-populate `events.state_key` and `events.rejection_reason`."""

        min_stream_ordering_exclusive = progress["min_stream_ordering_exclusive"]
        max_stream_ordering_inclusive = progress["max_stream_ordering_inclusive"]

        def _populate_txn(txn: LoggingTransaction) -> bool:
            """Returns True if we're done."""

            # first we need to find an endpoint.
            # we need to find the final row in the batch of batch_size, which means
            # we need to skip over (batch_size-1) rows and get the next row.
            txn.execute(
                """
                SELECT stream_ordering FROM events
                WHERE stream_ordering > ? AND stream_ordering <= ?
                ORDER BY stream_ordering
                LIMIT 1 OFFSET ?
                """,
                (
                    min_stream_ordering_exclusive,
                    max_stream_ordering_inclusive,
                    batch_size - 1,
                ),
            )

            row = txn.fetchone()
            if row:
                endpoint = row[0]
            else:
                # if the query didn't return a row, we must be almost done. We just
                # need to go up to the recorded max_stream_ordering.
                endpoint = max_stream_ordering_inclusive

            where_clause = "stream_ordering > ? AND stream_ordering <= ?"
            args = [min_stream_ordering_exclusive, endpoint]

            # now do the updates.
            txn.execute(
                f"""
                UPDATE events
                SET state_key = (SELECT state_key FROM state_events se WHERE se.event_id = events.event_id),
                    rejection_reason = (SELECT reason FROM rejections rej WHERE rej.event_id = events.event_id)
                WHERE ({where_clause})
                """,
                args,
            )

            logger.info(
                "populated new `events` columns up to %i/%i: updated %i rows",
                endpoint,
                max_stream_ordering_inclusive,
                txn.rowcount,
            )

            if endpoint >= max_stream_ordering_inclusive:
                # we're done
                return True

            progress["min_stream_ordering_exclusive"] = endpoint
            self.db_pool.updates._background_update_progress_txn(
                txn,
                _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
                progress,
            )
            return False

        done = await self.db_pool.runInteraction(
            desc="events_populate_state_key_rejections", func=_populate_txn
        )

        if done:
            await self.db_pool.updates._end_background_update(
                _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS
            )

        return batch_size