# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import MutableStateMap, StateMap
from synapse.types.state import StateFilter
from synapse.util.caches import intern_string

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


MAX_STATE_DELTA_HOPS = 100


class StateGroupBackgroundUpdateStore(SQLBaseStore):
    """Defines functions related to state groups needed to run the state background
    updates.
    """

    def _count_state_group_hops_txn(
        self, txn: LoggingTransaction, state_group: int
    ) -> int:
        """Given a state group, count how many hops there are in the tree.

        This is used to ensure the delta chains don't get too long.
        """
        if isinstance(self.database_engine, PostgresEngine):
            sql = """
                WITH RECURSIVE state(state_group) AS (
                    VALUES(?::bigint)
                    UNION ALL
                    SELECT prev_state_group FROM state_group_edges e, state s
                    WHERE s.state_group = e.state_group
                )
                SELECT count(*) FROM state;
            """

            txn.execute(sql, (state_group,))
            row = txn.fetchone()
            if row and row[0]:
                return row[0]
            else:
                return 0
        else:
            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
            next_group: Optional[int] = state_group
            count = 0

            while next_group:
                next_group = self.db_pool.simple_select_one_onecol_txn(
                    txn,
                    table="state_group_edges",
                    keyvalues={"state_group": next_group},
                    retcol="prev_state_group",
                    allow_none=True,
                )
                if next_group:
                    count += 1

            return count
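
    # A minimal sketch (not part of the original module) of the walk the
    # sqlite3 branch above performs, using a plain dict in place of the
    # `state_group_edges` table (the edge values here are made up):
    #
    #     edges = {103: 102, 102: 101}  # state_group -> prev_state_group
    #     count, group = 0, 103
    #     while group in edges:
    #         group = edges[group]
    #         count += 1
    #     assert count == 2  # two hops: 103 -> 102 -> 101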

    def _get_state_groups_from_groups_txn(
        self,
        txn: LoggingTransaction,
        groups: List[int],
        state_filter: Optional[StateFilter] = None,
    ) -> Mapping[int, StateMap[str]]:
        state_filter = state_filter or StateFilter.all()

        results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}

        if isinstance(self.database_engine, PostgresEngine):
            # Temporarily disable sequential scans in this transaction. This is
            # a temporary hack until we can add the right indices.
            txn.execute("SET LOCAL enable_seqscan=off")

            # The below query walks the state_group tree so that the "state"
            # table includes all state_groups in the tree. It then joins
            # against `state_groups_state` to fetch the latest state.
            # It assumes that previous state groups are always numerically
            # lesser.
            # This may return multiple rows per (type, state_key), but last_value
            # should be the same.
            sql = """
                WITH RECURSIVE sgs(state_group) AS (
                    VALUES(?::bigint)
                    UNION ALL
                    SELECT prev_state_group FROM state_group_edges e, sgs s
                    WHERE s.state_group = e.state_group
                )
                %s
            """

            overall_select_query_args: List[Union[int, str]] = []

            # This is an optimization to create a select clause per-condition. This
            # makes the query planner a lot smarter about which rows it should pull
            # out in the first place, and we end up with something that takes 10x
            # less time to get a result.
            use_condition_optimization = (
                not state_filter.include_others and not state_filter.is_full()
            )
            state_filter_condition_combos: List[Tuple[str, Optional[str]]] = []
            # We don't need to calculate this list if we're not using the condition
            # optimization
            if use_condition_optimization:
                for etype, state_keys in state_filter.types.items():
                    if state_keys is None:
                        state_filter_condition_combos.append((etype, None))
                    else:
                        for state_key in state_keys:
                            state_filter_condition_combos.append((etype, state_key))

            # And here is the optimization itself. We don't want to do the optimization
            # if there are too many individual conditions. 10 is an arbitrary number
            # with no testing behind it, but we do know that we specifically made this
            # optimization for when we grab the necessary state out for
            # `filter_events_for_client`, which just uses 2 conditions
            # (`EventTypes.RoomHistoryVisibility` and `EventTypes.Member`).
            if use_condition_optimization and len(state_filter_condition_combos) < 10:
                select_clause_list: List[str] = []
                for etype, skey in state_filter_condition_combos:
                    if skey is None:
                        where_clause = "(type = ?)"
                        overall_select_query_args.extend([etype])
                    else:
                        where_clause = "(type = ? AND state_key = ?)"
                        overall_select_query_args.extend([etype, skey])

                    select_clause_list.append(
                        f"""
                        (
                            SELECT DISTINCT ON (type, state_key)
                                type, state_key, event_id
                            FROM state_groups_state
                            INNER JOIN sgs USING (state_group)
                            WHERE {where_clause}
                            ORDER BY type, state_key, state_group DESC
                        )
                        """
                    )

                overall_select_clause = " UNION ".join(select_clause_list)
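
                # For illustration (hypothetical filter, not part of the
                # original module): a StateFilter covering
                # `m.room.history_visibility` with state_key "" plus all of
                # `m.room.member` yields two parenthesised SELECTs, one per
                # condition, joined by " UNION ":
                #
                #     (SELECT DISTINCT ON (type, state_key) ...
                #      WHERE (type = ? AND state_key = ?) ...)
                #     UNION
                #     (SELECT DISTINCT ON (type, state_key) ...
                #      WHERE (type = ?) ...)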
            else:
                where_clause, where_args = state_filter.make_sql_filter_clause()
                # Unless the filter clause is empty, we're going to append it after an
                # existing where clause
                if where_clause:
                    where_clause = " AND (%s)" % (where_clause,)
                overall_select_query_args.extend(where_args)

                overall_select_clause = f"""
                    SELECT DISTINCT ON (type, state_key)
                        type, state_key, event_id
                    FROM state_groups_state
                    WHERE state_group IN (
                        SELECT state_group FROM sgs
                    ) {where_clause}
                    ORDER BY type, state_key, state_group DESC
                """

            for group in groups:
                args: List[Union[int, str]] = [group]
                args.extend(overall_select_query_args)

                txn.execute(sql % (overall_select_clause,), args)
                for row in txn:
                    typ, state_key, event_id = row
                    key = (intern_string(typ), intern_string(state_key))
                    results[group][key] = event_id
        else:
            max_entries_returned = state_filter.max_entries_returned()

            where_clause, where_args = state_filter.make_sql_filter_clause()
            # Unless the filter clause is empty, we're going to append it after an
            # existing where clause
            if where_clause:
                where_clause = " AND (%s)" % (where_clause,)

            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
            for group in groups:
                next_group: Optional[int] = group

                while next_group:
                    # We did this before by getting the list of group ids, and
                    # then passing that list to sqlite to get latest event for
                    # each (type, state_key). However, that was terribly slow
                    # without the right indices (which we can't add until
                    # after we finish deduping state, which requires this func)
                    args = [next_group]
                    args.extend(where_args)

                    txn.execute(
                        "SELECT type, state_key, event_id FROM state_groups_state"
                        " WHERE state_group = ? " + where_clause,
                        args,
                    )
                    results[group].update(
                        ((typ, state_key), event_id)
                        for typ, state_key, event_id in txn
                        if (typ, state_key) not in results[group]
                    )

                    # If the number of entries in the (type,state_key)->event_id dict
                    # matches the number of (type,state_keys) types we were searching
                    # for, then we must have found them all, so no need to go walk
                    # further down the tree... UNLESS our types filter contained
                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
                    # search
                    if (
                        max_entries_returned is not None
                        and len(results[group]) == max_entries_returned
                    ):
                        break

                    next_group = self.db_pool.simple_select_one_onecol_txn(
                        txn,
                        table="state_group_edges",
                        keyvalues={"state_group": next_group},
                        retcol="prev_state_group",
                        allow_none=True,
                    )

        # The results shouldn't be considered mutable.
        return results
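
    # The mapping returned above has the shape
    # {state_group: {(event_type, state_key): event_id}}, e.g. (made-up ids):
    #
    #     {103: {("m.room.member", "@alice:example.com"): "$event_a",
    #            ("m.room.history_visibility", ""): "$event_b"}}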


class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
    STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
    STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME = "state_group_edges_unique_idx"

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)
        self.db_pool.updates.register_background_update_handler(
            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
            self._background_deduplicate_state,
        )
        self.db_pool.updates.register_background_update_handler(
            self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
        )
        self.db_pool.updates.register_background_index_update(
            self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME,
            index_name="state_groups_room_id_idx",
            table="state_groups",
            columns=["room_id"],
        )

        # `state_group_edges` can cause severe performance issues if duplicate
        # rows are introduced, which can accidentally be done by well-meaning
        # server admins when trying to restore a database dump, etc.
        # See https://github.com/matrix-org/synapse/issues/11779.
        # Introduce a unique index to guard against that.
        self.db_pool.updates.register_background_index_update(
            self.STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME,
            index_name="state_group_edges_unique_idx",
            table="state_group_edges",
            columns=["state_group", "prev_state_group"],
            unique=True,
            # The old index was on (state_group) and was not unique.
            replaces_index="state_group_edges_idx",
        )
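
    # For illustration (made-up rows): a duplicated edge such as
    #
    #     state_group | prev_state_group
    #     ------------+-----------------
    #             103 |              102
    #             103 |              102
    #
    # can multiply the rows the recursive walks above return at every hop of
    # the delta chain; the unique index registered above rejects the second
    # row outright.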

    async def _background_deduplicate_state(
        self, progress: dict, batch_size: int
    ) -> int:
        """This background update will slowly deduplicate state by re-encoding
        state groups as deltas.
        """
        last_state_group = progress.get("last_state_group", 0)
        rows_inserted = progress.get("rows_inserted", 0)
        max_group = progress.get("max_group", None)

        # Scale the requested batch size down by this factor; the value
        # returned at the end is scaled back up by the same factor.
        BATCH_SIZE_SCALE_FACTOR = 100

        batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))

        if max_group is None:
            rows = await self.db_pool.execute(
                "_background_deduplicate_state",
                None,
                "SELECT coalesce(max(id), 0) FROM state_groups",
            )
            max_group = rows[0][0]

        def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
            new_last_state_group = last_state_group
            for count in range(batch_size):
                txn.execute(
                    "SELECT id, room_id FROM state_groups"
                    " WHERE ? < id AND id <= ?"
                    " ORDER BY id ASC"
                    " LIMIT 1",
                    (new_last_state_group, max_group),
                )
                row = txn.fetchone()
                if row:
                    state_group, room_id = row

                if not row or not state_group:
                    return True, count

                txn.execute(
                    "SELECT state_group FROM state_group_edges"
                    " WHERE state_group = ?",
                    (state_group,),
                )

                # If we reach a point where we've already started inserting
                # edges we should stop.
                if txn.fetchall():
                    return True, count

                txn.execute(
                    "SELECT coalesce(max(id), 0) FROM state_groups"
                    " WHERE id < ? AND room_id = ?",
                    (state_group, room_id),
                )
                # There will be a result due to the coalesce.
                (prev_group,) = txn.fetchone()  # type: ignore
                new_last_state_group = state_group

                if prev_group:
                    potential_hops = self._count_state_group_hops_txn(txn, prev_group)
                    if potential_hops >= MAX_STATE_DELTA_HOPS:
                        # We want to ensure chains are at most this long,
                        # otherwise read performance degrades.
                        continue

                    prev_state_by_group = self._get_state_groups_from_groups_txn(
                        txn, [prev_group]
                    )
                    prev_state = prev_state_by_group[prev_group]

                    curr_state_by_group = self._get_state_groups_from_groups_txn(
                        txn, [state_group]
                    )
                    curr_state = curr_state_by_group[state_group]

                    if not set(prev_state.keys()) - set(curr_state.keys()):
                        # We can only do a delta if the keys in the current state
                        # are a superset of those in the previous state.
                        delta_state = {
                            key: value
                            for key, value in curr_state.items()
                            if prev_state.get(key, None) != value
                        }
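
                        # For illustration (made-up events): if prev_state is
                        # {("m.room.member", "@a:x"): "$m1"} and curr_state is
                        # {("m.room.member", "@a:x"): "$m1",
                        #  ("m.room.topic", ""): "$t1"},
                        # then delta_state is just {("m.room.topic", ""): "$t1"},
                        # and only that one row needs storing for this group.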

                        self.db_pool.simple_delete_txn(
                            txn,
                            table="state_group_edges",
                            keyvalues={"state_group": state_group},
                        )

                        self.db_pool.simple_insert_txn(
                            txn,
                            table="state_group_edges",
                            values={
                                "state_group": state_group,
                                "prev_state_group": prev_group,
                            },
                        )

                        self.db_pool.simple_delete_txn(
                            txn,
                            table="state_groups_state",
                            keyvalues={"state_group": state_group},
                        )

                        self.db_pool.simple_insert_many_txn(
                            txn,
                            table="state_groups_state",
                            keys=(
                                "state_group",
                                "room_id",
                                "type",
                                "state_key",
                                "event_id",
                            ),
                            values=[
                                (state_group, room_id, key[0], key[1], state_id)
                                for key, state_id in delta_state.items()
                            ],
                        )

            progress = {
                "last_state_group": state_group,
                "rows_inserted": rows_inserted + batch_size,
                "max_group": max_group,
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
            )

            return False, batch_size

        finished, result = await self.db_pool.runInteraction(
            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
        )

        if finished:
            await self.db_pool.updates._end_background_update(
                self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
            )

        return result * BATCH_SIZE_SCALE_FACTOR

    async def _background_index_state(self, progress: dict, batch_size: int) -> int:
        def reindex_txn(conn: LoggingDatabaseConnection) -> None:
            conn.rollback()
            if isinstance(self.database_engine, PostgresEngine):
                # postgres insists on autocommit for the index
                conn.set_session(autocommit=True)
                try:
                    txn = conn.cursor()
                    txn.execute(
                        "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
                        " ON state_groups_state(state_group, type, state_key)"
                    )
                    txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
                finally:
                    conn.set_session(autocommit=False)
            else:
                txn = conn.cursor()
                txn.execute(
                    "CREATE INDEX state_groups_state_type_idx"
                    " ON state_groups_state(state_group, type, state_key)"
                )
                txn.execute("DROP INDEX IF EXISTS state_groups_state_id")

        await self.db_pool.runWithConnection(reindex_txn)

        await self.db_pool.updates._end_background_update(
            self.STATE_GROUP_INDEX_UPDATE_NAME
        )

        return 1
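
    # Note on the autocommit dance above: Postgres refuses to run
    # CREATE INDEX CONCURRENTLY inside a transaction block (it fails with
    # "CREATE INDEX CONCURRENTLY cannot run inside a transaction block"), so
    # the connection is switched to autocommit for the duration. A minimal
    # standalone sketch of the same pattern with psycopg2 (the connection
    # details here are made up):
    #
    #     import psycopg2
    #
    #     conn = psycopg2.connect("dbname=synapse")
    #     conn.autocommit = True  # required for CONCURRENTLY
    #     with conn.cursor() as cur:
    #         cur.execute(
    #             "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
    #             " ON state_groups_state(state_group, type, state_key)"
    #         )
    #     conn.autocommit = False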