# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union

from synapse.logging.opentracing import tag_args, trace
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import MutableStateMap, StateMap
from synapse.types.state import StateFilter
from synapse.util.caches import intern_string

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)
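# State groups are typically stored as deltas against a previous state group
# (linked via `state_group_edges`). This caps how long such a delta chain may
# grow, since resolving a group's full state means walking the entire chain.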
MAX_STATE_DELTA_HOPS = 100


class StateGroupBackgroundUpdateStore(SQLBaseStore):
    """Defines functions related to state groups needed to run the state background
    updates.
    """

    @trace
    @tag_args
    def _count_state_group_hops_txn(
        self, txn: LoggingTransaction, state_group: int
    ) -> int:
        """Given a state group, count how many hops there are in the tree.

        This is used to ensure the delta chains don't get too long.
        """
        if isinstance(self.database_engine, PostgresEngine):
            sql = """
                WITH RECURSIVE state(state_group) AS (
                    VALUES(?::bigint)
                    UNION ALL
                    SELECT prev_state_group FROM state_group_edges e, state s
                    WHERE s.state_group = e.state_group
                )
                SELECT count(*) FROM state;
            """

            txn.execute(sql, (state_group,))
            row = txn.fetchone()
            if row and row[0]:
                return row[0]
            else:
                return 0
        else:
            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
            next_group: Optional[int] = state_group
            count = 0

            while next_group:
                next_group = self.db_pool.simple_select_one_onecol_txn(
                    txn,
                    table="state_group_edges",
                    keyvalues={"state_group": next_group},
                    retcol="prev_state_group",
                    allow_none=True,
                )
                if next_group:
                    count += 1

            return count
    @trace
    @tag_args
    def _get_state_groups_from_groups_txn(
        self,
        txn: LoggingTransaction,
        groups: List[int],
        state_filter: Optional[StateFilter] = None,
    ) -> Mapping[int, StateMap[str]]:
        """
        Given a number of state groups, fetch the latest state for each group.

        Args:
            txn: The transaction object.
            groups: The given state groups that you want to fetch the latest state for.
            state_filter: The state filter to apply when fetching state from the
                database.

        Returns:
            Map from state_group to a StateMap at that point.
        """
        state_filter = state_filter or StateFilter.all()

        results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}

        if isinstance(self.database_engine, PostgresEngine):
            # Temporarily disable sequential scans in this transaction. This is
            # a temporary hack until we can add the right indices.
            txn.execute("SET LOCAL enable_seqscan=off")

            # The below query walks the state_group tree so that the "state"
            # table includes all state_groups in the tree. It then joins
            # against `state_groups_state` to fetch the latest state.
            # It assumes that previous state groups are always numerically
            # lesser.
            # This may return multiple rows per (type, state_key), but last_value
            # should be the same.
            sql = """
                WITH RECURSIVE sgs(state_group) AS (
                    VALUES(?::bigint)
                    UNION ALL
                    SELECT prev_state_group FROM state_group_edges e, sgs s
                    WHERE s.state_group = e.state_group
                )
                %s
            """

            overall_select_query_args: List[Union[int, str]] = []

            # This is an optimization to create a select clause per-condition. This
            # makes the query planner a lot smarter about which rows it should pull
            # out in the first place, and we end up with something that takes 10x
            # less time to get a result.
            use_condition_optimization = (
                not state_filter.include_others and not state_filter.is_full()
            )
            state_filter_condition_combos: List[Tuple[str, Optional[str]]] = []
            # We don't need to calculate this list if we're not using the condition
            # optimization.
            if use_condition_optimization:
                for etype, state_keys in state_filter.types.items():
                    if state_keys is None:
                        state_filter_condition_combos.append((etype, None))
                    else:
                        for state_key in state_keys:
                            state_filter_condition_combos.append((etype, state_key))
            # And here is the optimization itself. We don't want to do the optimization
            # if there are too many individual conditions. 10 is an arbitrary number
            # with no testing behind it but we do know that we specifically made this
            # optimization for when we grab the necessary state out for
            # `filter_events_for_client` which just uses 2 conditions
            # (`EventTypes.RoomHistoryVisibility` and `EventTypes.Member`).
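            # For example, a filter covering just
            # (EventTypes.RoomHistoryVisibility, "") and (EventTypes.Member,
            # "@user:example.com") would expand into two sub-selects joined
            # by UNION below (the user ID here is purely illustrative).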
            if use_condition_optimization and len(state_filter_condition_combos) < 10:
                select_clause_list: List[str] = []
                for etype, skey in state_filter_condition_combos:
                    if skey is None:
                        where_clause = "(type = ?)"
                        overall_select_query_args.extend([etype])
                    else:
                        where_clause = "(type = ? AND state_key = ?)"
                        overall_select_query_args.extend([etype, skey])

                    select_clause_list.append(
                        f"""
                        (
                            SELECT DISTINCT ON (type, state_key)
                                type, state_key, event_id
                            FROM state_groups_state
                            INNER JOIN sgs USING (state_group)
                            WHERE {where_clause}
                            ORDER BY type, state_key, state_group DESC
                        )
                        """
                    )

                overall_select_clause = " UNION ".join(select_clause_list)
            else:
                where_clause, where_args = state_filter.make_sql_filter_clause()
                # Unless the filter clause is empty, we're going to append it after an
                # existing where clause
                if where_clause:
                    where_clause = " AND (%s)" % (where_clause,)

                overall_select_query_args.extend(where_args)

                overall_select_clause = f"""
                    SELECT DISTINCT ON (type, state_key)
                        type, state_key, event_id
                    FROM state_groups_state
                    WHERE state_group IN (
                        SELECT state_group FROM sgs
                    ) {where_clause}
                    ORDER BY type, state_key, state_group DESC
                """
            for group in groups:
                args: List[Union[int, str]] = [group]
                args.extend(overall_select_query_args)

                txn.execute(sql % (overall_select_clause,), args)
                for row in txn:
                    typ, state_key, event_id = row
                    key = (intern_string(typ), intern_string(state_key))
                    results[group][key] = event_id
        else:
            max_entries_returned = state_filter.max_entries_returned()

            where_clause, where_args = state_filter.make_sql_filter_clause()
            # Unless the filter clause is empty, we're going to append it after an
            # existing where clause
            if where_clause:
                where_clause = " AND (%s)" % (where_clause,)

            # XXX: We could `WITH RECURSIVE` here since it's supported on SQLite 3.8.3
            # or higher and our minimum supported version is greater than that.
            #
            # We just haven't put in the time to refactor this.
            for group in groups:
                next_group: Optional[int] = group

                while next_group:
                    # We did this before by getting the list of group ids, and
                    # then passing that list to sqlite to get latest event for
                    # each (type, state_key). However, that was terribly slow
                    # without the right indices (which we can't add until
                    # after we finish deduping state, which requires this func)
                    args = [next_group]
                    args.extend(where_args)

                    txn.execute(
                        "SELECT type, state_key, event_id FROM state_groups_state"
                        " WHERE state_group = ? " + where_clause,
                        args,
                    )
                    results[group].update(
                        ((typ, state_key), event_id)
                        for typ, state_key, event_id in txn
                        if (typ, state_key) not in results[group]
                    )

                    # If the number of entries in the (type,state_key)->event_id dict
                    # matches the number of (type,state_keys) types we were searching
                    # for, then we must have found them all, so no need to go walk
                    # further down the tree... UNLESS our types filter contained
                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
                    # search
                    if (
                        max_entries_returned is not None
                        and len(results[group]) == max_entries_returned
                    ):
                        break

                    next_group = self.db_pool.simple_select_one_onecol_txn(
                        txn,
                        table="state_group_edges",
                        keyvalues={"state_group": next_group},
                        retcol="prev_state_group",
                        allow_none=True,
                    )

        # The results shouldn't be considered mutable.
        return results
class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
    STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
    STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME = "state_group_edges_unique_idx"
    CURRENT_STATE_EVENTS_STREAM_ORDERING_INDEX_UPDATE_NAME = (
        "current_state_events_stream_ordering_idx"
    )
    ROOM_MEMBERSHIPS_STREAM_ORDERING_INDEX_UPDATE_NAME = (
        "room_memberships_stream_ordering_idx"
    )
    LOCAL_CURRENT_MEMBERSHIP_STREAM_ORDERING_INDEX_UPDATE_NAME = (
        "local_current_membership_stream_ordering_idx"
    )

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)
        self.db_pool.updates.register_background_update_handler(
            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
            self._background_deduplicate_state,
        )
        self.db_pool.updates.register_background_update_handler(
            self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
        )
        self.db_pool.updates.register_background_index_update(
            self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME,
            index_name="state_groups_room_id_idx",
            table="state_groups",
            columns=["room_id"],
        )

        # `state_group_edges` can cause severe performance issues if duplicate
        # rows are introduced, which can accidentally be done by well-meaning
        # server admins when trying to restore a database dump, etc.
        # See https://github.com/matrix-org/synapse/issues/11779.
        # Introduce a unique index to guard against that.
        self.db_pool.updates.register_background_index_update(
            self.STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME,
            index_name="state_group_edges_unique_idx",
            table="state_group_edges",
            columns=["state_group", "prev_state_group"],
            unique=True,
            # The old index was on (state_group) and was not unique.
            replaces_index="state_group_edges_idx",
        )

        # These indices are needed to validate the foreign key constraint
        # when events are deleted.
        self.db_pool.updates.register_background_index_update(
            self.CURRENT_STATE_EVENTS_STREAM_ORDERING_INDEX_UPDATE_NAME,
            index_name="current_state_events_stream_ordering_idx",
            table="current_state_events",
            columns=["event_stream_ordering"],
        )
        self.db_pool.updates.register_background_index_update(
            self.ROOM_MEMBERSHIPS_STREAM_ORDERING_INDEX_UPDATE_NAME,
            index_name="room_memberships_stream_ordering_idx",
            table="room_memberships",
            columns=["event_stream_ordering"],
        )
        self.db_pool.updates.register_background_index_update(
            self.LOCAL_CURRENT_MEMBERSHIP_STREAM_ORDERING_INDEX_UPDATE_NAME,
            index_name="local_current_membership_stream_ordering_idx",
            table="local_current_membership",
            columns=["event_stream_ordering"],
        )
    async def _background_deduplicate_state(
        self, progress: dict, batch_size: int
    ) -> int:
        """This background update will slowly deduplicate state by re-encoding
        it as deltas.
        """
        last_state_group = progress.get("last_state_group", 0)
        rows_inserted = progress.get("rows_inserted", 0)
        max_group = progress.get("max_group", None)

        BATCH_SIZE_SCALE_FACTOR = 100

        batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
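        # Each state group handled here rewrites many rows, so we scale the
        # requested batch size down here and scale the reported count back up
        # on return.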
        if max_group is None:
            rows = await self.db_pool.execute(
                "_background_deduplicate_state",
                "SELECT coalesce(max(id), 0) FROM state_groups",
            )
            max_group = rows[0][0]
        def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
            new_last_state_group = last_state_group
            for count in range(batch_size):
                txn.execute(
                    "SELECT id, room_id FROM state_groups"
                    " WHERE ? < id AND id <= ?"
                    " ORDER BY id ASC"
                    " LIMIT 1",
                    (new_last_state_group, max_group),
                )
                row = txn.fetchone()
                if row:
                    state_group, room_id = row

                if not row or not state_group:
                    return True, count

                txn.execute(
                    "SELECT state_group FROM state_group_edges"
                    " WHERE state_group = ?",
                    (state_group,),
                )

                # If we reach a point where we've already started inserting
                # edges we should stop.
                if txn.fetchall():
                    return True, count

                txn.execute(
                    "SELECT coalesce(max(id), 0) FROM state_groups"
                    " WHERE id < ? AND room_id = ?",
                    (state_group, room_id),
                )
                # There will be a result due to the coalesce.
                (prev_group,) = txn.fetchone()  # type: ignore
                new_last_state_group = state_group

                if prev_group:
                    potential_hops = self._count_state_group_hops_txn(txn, prev_group)
                    if potential_hops >= MAX_STATE_DELTA_HOPS:
                        # We want to ensure chains are at most this long,
                        # otherwise read performance degrades.
                        continue

                    prev_state_by_group = self._get_state_groups_from_groups_txn(
                        txn, [prev_group]
                    )
                    prev_state = prev_state_by_group[prev_group]

                    curr_state_by_group = self._get_state_groups_from_groups_txn(
                        txn, [state_group]
                    )
                    curr_state = curr_state_by_group[state_group]

                    if not set(prev_state.keys()) - set(curr_state.keys()):
                        # We can only do a delta if the current state's keys are
                        # a superset of the previous state's keys.
                        delta_state = {
                            key: value
                            for key, value in curr_state.items()
                            if prev_state.get(key, None) != value
                        }
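                        # e.g. prev_state {(m.room.name, ""): $A} and curr_state
                        # {(m.room.name, ""): $B, (m.room.topic, ""): $C} yield
                        # delta_state {(m.room.name, ""): $B, (m.room.topic, ""): $C}
                        # (the event IDs here are illustrative).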
                        self.db_pool.simple_delete_txn(
                            txn,
                            table="state_group_edges",
                            keyvalues={"state_group": state_group},
                        )

                        self.db_pool.simple_insert_txn(
                            txn,
                            table="state_group_edges",
                            values={
                                "state_group": state_group,
                                "prev_state_group": prev_group,
                            },
                        )

                        self.db_pool.simple_delete_txn(
                            txn,
                            table="state_groups_state",
                            keyvalues={"state_group": state_group},
                        )

                        self.db_pool.simple_insert_many_txn(
                            txn,
                            table="state_groups_state",
                            keys=(
                                "state_group",
                                "room_id",
                                "type",
                                "state_key",
                                "event_id",
                            ),
                            values=[
                                (state_group, room_id, key[0], key[1], state_id)
                                for key, state_id in delta_state.items()
                            ],
                        )

            progress = {
                "last_state_group": state_group,
                "rows_inserted": rows_inserted + batch_size,
                "max_group": max_group,
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
            )

            return False, batch_size

        finished, result = await self.db_pool.runInteraction(
            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
        )

        if finished:
            await self.db_pool.updates._end_background_update(
                self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
            )

        return result * BATCH_SIZE_SCALE_FACTOR
    async def _background_index_state(self, progress: dict, batch_size: int) -> int:
        def reindex_txn(conn: LoggingDatabaseConnection) -> None:
            conn.rollback()
            if isinstance(self.database_engine, PostgresEngine):
                # postgres insists on autocommit for the index
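                # (CREATE INDEX CONCURRENTLY cannot run inside a transaction
                # block, hence the rollback above and the temporary switch to
                # autocommit.)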
                conn.engine.attempt_to_set_autocommit(conn.conn, True)
                try:
                    txn = conn.cursor()
                    txn.execute(
                        "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
                        " ON state_groups_state(state_group, type, state_key)"
                    )
                    txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
                finally:
                    conn.engine.attempt_to_set_autocommit(conn.conn, False)
            else:
                txn = conn.cursor()
                txn.execute(
                    "CREATE INDEX state_groups_state_type_idx"
                    " ON state_groups_state(state_group, type, state_key)"
                )
                txn.execute("DROP INDEX IF EXISTS state_groups_state_id")

        await self.db_pool.runWithConnection(reindex_txn)

        await self.db_pool.updates._end_background_update(
            self.STATE_GROUP_INDEX_UPDATE_NAME
        )

        return 1