# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import (
    TYPE_CHECKING,
    Collection,
    Dict,
    Iterable,
    Optional,
    Sequence,
    Set,
    Tuple,
)

import attr

from twisted.internet import defer

from synapse.api.constants import EventTypes
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import (
    AbstractObservableDeferred,
    ObservableDeferred,
    yieldable_gather_results,
)
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

MAX_STATE_DELTA_HOPS = 100

MAX_INFLIGHT_REQUESTS_PER_GROUP = 5
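
# MAX_STATE_DELTA_HOPS bounds how long a chain of state-group deltas may grow
# before `store_state_group` falls back to persisting a full state snapshot.
# MAX_INFLIGHT_REQUESTS_PER_GROUP bounds how many concurrent requests we track
# per state group before a catch-all `StateFilter.all()` request is fired
# instead (see `_get_state_for_group_gather_inflight_requests` below).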


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _GetStateGroupDelta:
    """Return type of get_state_group_delta that implements __len__, which lets
    us use the iterable flag when caching
    """

    prev_group: Optional[int]
    delta_ids: Optional[StateMap[str]]

    def __len__(self) -> int:
        return len(self.delta_ids) if self.delta_ids else 0


class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
    """A data store for fetching/storing state groups."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Originally the state store used a single DictionaryCache to cache the
        # event IDs for the state types in a given state group to avoid hammering
        # on the state_group* tables.
        #
        # The point of using a DictionaryCache is that it can cache a subset
        # of the state events for a given state group (i.e. a subset of the keys for a
        # given dict which is an entry in the cache for a given state group ID).
        #
        # However, this poses problems when performing complicated queries
        # on the store - for instance: "give me all the state for this group, but
        # limit members to this subset of users", as DictionaryCache's API isn't
        # rich enough to say "please cache any of these fields, apart from this subset".
        # This is problematic when lazy loading members, which requires this behaviour,
        # as without it the cache has no choice but to speculatively load all
        # state events for the group, which negates the efficiency being sought.
        #
        # Rather than overcomplicating DictionaryCache's API, we instead split the
        # state_group_cache into two halves - one for tracking non-member events,
        # and the other for tracking member events. This means that lazy loading
        # queries can be made in a cache-friendly manner by querying both caches
        # separately and then merging the result. So for the example above, you
        # would query the members cache for the specific subset of state keys
        # (which DictionaryCache will handle efficiently) and the non-members
        # cache for all state (which DictionaryCache will similarly handle fine)
        # and then just merge the results together.
        #
        # We size the non-members cache to be smaller than the members cache as the
        # vast majority of state in Matrix (today) is member events.
        self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache(
            "*stateGroupCache*",
            # TODO: this hasn't been tuned yet
            50000,
        )
        self._state_group_members_cache: DictionaryCache[
            int, StateKey, str
        ] = DictionaryCache(
            "*stateGroupMembersCache*",
            500000,
        )
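
        # For example (hypothetical users): a lazy-loading query for "all state,
        # but limit members to @alice and @bob" can be served by asking
        # self._state_group_members_cache for just the keys
        # ("m.room.member", "@alice:example.org") and
        # ("m.room.member", "@bob:example.org"), asking self._state_group_cache
        # for all of its state for the group, and then merging the two results.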

        # Current ongoing get_state_for_groups in-flight requests
        # {group ID -> {StateFilter -> ObservableDeferred}}
        self._state_group_inflight_requests: Dict[
            int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
        ] = {}

        def get_max_state_group_txn(txn: Cursor) -> int:
            txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
            return txn.fetchone()[0]  # type: ignore

        self._state_group_seq_gen = build_sequence_generator(
            db_conn,
            self.database_engine,
            get_max_state_group_txn,
            "state_group_id_seq",
            table="state_groups",
            id_column="id",
        )

    @cached(max_entries=10000, iterable=True)
    async def get_state_group_delta(self, state_group: int) -> _GetStateGroupDelta:
        """Given a state group try to return a previous group and a delta between
        the old and the new.

        Returns:
            _GetStateGroupDelta containing prev_group and delta_ids, where both may be None.
        """

        def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
            prev_group = self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="state_group_edges",
                keyvalues={"state_group": state_group},
                retcol="prev_state_group",
                allow_none=True,
            )

            if not prev_group:
                return _GetStateGroupDelta(None, None)

            delta_ids = self.db_pool.simple_select_list_txn(
                txn,
                table="state_groups_state",
                keyvalues={"state_group": state_group},
                retcols=("type", "state_key", "event_id"),
            )

            return _GetStateGroupDelta(
                prev_group,
                {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
            )

        return await self.db_pool.runInteraction(
            "get_state_group_delta", _get_state_group_delta_txn
        )
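
    # Illustrative example (hypothetical IDs): if state group 105 was stored as
    # a delta against group 100 that only changed the room topic, then
    # `await self.get_state_group_delta(105)` returns roughly
    # `_GetStateGroupDelta(prev_group=100,
    #                      delta_ids={("m.room.topic", ""): "$topic_event_id"})`,
    # while a state group stored without a prev_group yields
    # `_GetStateGroupDelta(None, None)`.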

    async def _get_state_groups_from_groups(
        self, groups: Sequence[int], state_filter: StateFilter
    ) -> Dict[int, StateMap[str]]:
        """Returns the state groups for a given set of groups from the
        database, filtering on types of state events.

        Args:
            groups: list of state group IDs to query
            state_filter: The state filter used to fetch state
                from the database.
        Returns:
            Dict of state group to state map.
        """
        results: Dict[int, StateMap[str]] = {}

        # Query the database in batches of 100 state groups at a time.
        chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
        for chunk in chunks:
            res = await self.db_pool.runInteraction(
                "_get_state_groups_from_groups",
                self._get_state_groups_from_groups_txn,
                chunk,
                state_filter,
            )
            results.update(res)

        return results

    def _get_state_for_group_using_cache(
        self,
        cache: DictionaryCache[int, StateKey, str],
        group: int,
        state_filter: StateFilter,
    ) -> Tuple[MutableStateMap[str], bool]:
        """Checks if group is in cache. See `_get_state_for_groups`

        Args:
            cache: the state group cache to use
            group: The state group to lookup
            state_filter: The state filter used to fetch state from the database.

        Returns:
            2-tuple (`state_dict`, `got_all`).
            `got_all` is a bool indicating if we successfully retrieved all
            requested state from the cache; if False we need to query the DB for the
            missing state.
        """
        cache_entry = cache.get(group)
        state_dict_ids = cache_entry.value

        if cache_entry.full or state_filter.is_full():
            # Either we have everything or want everything, either way
            # `is_all` tells us whether we've gotten everything.
            return state_filter.filter_state(state_dict_ids), cache_entry.full

        # tracks whether any of our requested types are missing from the cache
        missing_types = False

        if state_filter.has_wildcards():
            # We don't know if we fetched all the state keys for the types in
            # the filter that are wildcards, so we have to assume that we may
            # have missed some.
            missing_types = True
        else:
            # There aren't any wild cards, so `concrete_types()` returns the
            # complete list of event types we're wanting.
            for key in state_filter.concrete_types():
                if key not in state_dict_ids and key not in cache_entry.known_absent:
                    missing_types = True
                    break

        return state_filter.filter_state(state_dict_ids), not missing_types

    def _get_state_for_group_gather_inflight_requests(
        self, group: int, state_filter_left_over: StateFilter
    ) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
        """
        Attempts to gather in-flight requests and re-use them to retrieve state
        for the given state group, filtered with the given state filter.

        If there are MAX_INFLIGHT_REQUESTS_PER_GROUP or more in-flight requests,
        and there *still* isn't enough information to complete the request by solely
        reusing others, a full state filter will be requested to ensure that subsequent
        requests can reuse this request.

        Used as part of _get_state_for_group_using_inflight_cache.

        Returns:
            Tuple of two values:
                A sequence of ObservableDeferreds to observe
                A StateFilter representing what else needs to be requested to fulfill the request
        """
        inflight_requests = self._state_group_inflight_requests.get(group)
        if inflight_requests is None:
            # no requests for this group, need to retrieve it all ourselves
            return (), state_filter_left_over

        # The list of ongoing requests which will help narrow the current request.
        reusable_requests = []

        for (request_state_filter, request_deferred) in inflight_requests.items():
            new_state_filter_left_over = state_filter_left_over.approx_difference(
                request_state_filter
            )
            if new_state_filter_left_over == state_filter_left_over:
                # Reusing this request would not gain us anything, so don't bother.
                continue

            reusable_requests.append(request_deferred)
            state_filter_left_over = new_state_filter_left_over
            if state_filter_left_over == StateFilter.none():
                # we have managed to collect enough of the in-flight requests
                # to cover our StateFilter and give us the state we need.
                break

        if (
            state_filter_left_over != StateFilter.none()
            and len(inflight_requests) >= MAX_INFLIGHT_REQUESTS_PER_GROUP
        ):
            # There are too many requests for this group.
            # To prevent even more from building up, we request the whole
            # state filter to guarantee that we can be reused by any subsequent
            # requests for this state group.
            return (), StateFilter.all()

        return reusable_requests, state_filter_left_over
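
    # Illustrative example (hypothetical IDs): if a request for
    # StateFilter.all() is already in flight for group 42, a new request for
    # just the members of group 42 finds that
    # members_filter.approx_difference(StateFilter.all()) == StateFilter.none(),
    # so it observes the existing deferred rather than hitting the database again.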

    async def _get_state_for_group_fire_request(
        self, group: int, state_filter: StateFilter
    ) -> StateMap[str]:
        """
        Fires off a request to get the state at a state group,
        potentially filtering by type and/or state key.

        This request will be tracked in the in-flight request cache and automatically
        removed when it is finished.

        Used as part of _get_state_for_group_using_inflight_cache.

        Args:
            group: ID of the state group for which we want to get state
            state_filter: the state filter used to fetch state from the database
        """
        cache_sequence_nm = self._state_group_cache.sequence
        cache_sequence_m = self._state_group_members_cache.sequence

        # Help the cache hit ratio by expanding the filter a bit
        db_state_filter = state_filter.return_expanded()

        async def _the_request() -> StateMap[str]:
            group_to_state_dict = await self._get_state_groups_from_groups(
                (group,), state_filter=db_state_filter
            )

            # Now let's update the caches
            self._insert_into_cache(
                group_to_state_dict,
                db_state_filter,
                cache_seq_num_members=cache_sequence_m,
                cache_seq_num_non_members=cache_sequence_nm,
            )

            # Remove ourselves from the in-flight cache
            group_request_dict = self._state_group_inflight_requests[group]
            del group_request_dict[db_state_filter]

            if not group_request_dict:
                # If there are no more requests in-flight for this group,
                # clean up the cache by removing the empty dictionary
                del self._state_group_inflight_requests[group]

            return group_to_state_dict[group]

        # We don't immediately await the result, so we must use run_in_background.
        # But we DO await the result before the current log context (request)
        # finishes, so we don't need to run it as a background process.
        request_deferred = run_in_background(_the_request)
        observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)

        # Insert the ObservableDeferred into the cache
        # (keyed by the expanded state filter we actually requested)
        group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
        group_request_dict[db_state_filter] = observable_deferred

        return await make_deferred_yieldable(observable_deferred.observe())

    async def _get_state_for_group_using_inflight_cache(
        self, group: int, state_filter: StateFilter
    ) -> MutableStateMap[str]:
        """
        Gets the state at a state group, potentially filtering by type and/or
        state key.

        1. Calls _get_state_for_group_gather_inflight_requests to gather any
           ongoing requests which might overlap with the current request.
        2. Fires a new request, using _get_state_for_group_fire_request,
           for any state which cannot be gathered from ongoing requests.

        Args:
            group: ID of the state group for which we want to get state
            state_filter: the state filter used to fetch state from the database
        Returns:
            state map
        """

        # first, figure out whether we can re-use any in-flight requests
        # (and if so, what would be left over)
        (
            reusable_requests,
            state_filter_left_over,
        ) = self._get_state_for_group_gather_inflight_requests(group, state_filter)

        if state_filter_left_over != StateFilter.none():
            # Fetch remaining state
            remaining = await self._get_state_for_group_fire_request(
                group, state_filter_left_over
            )
            assembled_state: MutableStateMap[str] = dict(remaining)
        else:
            assembled_state = {}

        gathered = await make_deferred_yieldable(
            defer.gatherResults(
                (r.observe() for r in reusable_requests), consumeErrors=True
            )
        ).addErrback(unwrapFirstError)

        # assemble our result.
        for result_piece in gathered:
            assembled_state.update(result_piece)

        # Filter out any state that may be more than what we asked for.
        return state_filter.filter_state(assembled_state)

    async def _get_state_for_groups(
        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
    ) -> Dict[int, MutableStateMap[str]]:
        """Gets the state at each of a list of state groups, optionally
        filtering by type/state_key

        Args:
            groups: list of state groups for which we want
                to get the state.
            state_filter: The state filter used to fetch state
                from the database.
        Returns:
            Dict of state group to state map.
        """
        state_filter = state_filter or StateFilter.all()

        member_filter, non_member_filter = state_filter.get_member_split()

        # Now we look them up in the member and non-member caches
        (
            non_member_state,
            incomplete_groups_nm,
        ) = self._get_state_for_groups_using_cache(
            groups, self._state_group_cache, state_filter=non_member_filter
        )

        (
            member_state,
            incomplete_groups_m,
        ) = self._get_state_for_groups_using_cache(
            groups, self._state_group_members_cache, state_filter=member_filter
        )

        state = dict(non_member_state)
        for group in groups:
            state[group].update(member_state[group])

        # Now fetch any missing groups from the database
        incomplete_groups = incomplete_groups_m | incomplete_groups_nm
        if not incomplete_groups:
            return state

        async def get_from_cache(group: int, state_filter: StateFilter) -> None:
            state[group] = await self._get_state_for_group_using_inflight_cache(
                group, state_filter
            )

        await yieldable_gather_results(
            get_from_cache,
            incomplete_groups,
            state_filter,
        )

        return state
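
    # Illustrative use (hypothetical group ID and user): a lazy-loading caller
    # might run
    #
    #     state = await self._get_state_for_groups(
    #         [1234],
    #         StateFilter.from_types([(EventTypes.Member, "@alice:example.org")]),
    #     )
    #
    # and get back {1234: {("m.room.member", "@alice:example.org"): "$event_id"}},
    # assuming @alice has a membership event in that group's state.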

    def _get_state_for_groups_using_cache(
        self,
        groups: Iterable[int],
        cache: DictionaryCache[int, StateKey, str],
        state_filter: StateFilter,
    ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
        """Gets the state at each of a list of state groups, optionally
        filtering by type/state_key, querying from a specific cache.

        Args:
            groups: list of state groups for which we want to get the state.
            cache: the cache of group ids to state dicts which
                we will pass through - either the normal state cache or the
                specific members state cache.
            state_filter: The state filter used to fetch state from the
                database.

        Returns:
            Tuple of dict of state_group_id to state map of entries in the
            cache, and the state group ids either missing from the cache or
            incomplete.
        """
        results = {}
        incomplete_groups = set()
        for group in set(groups):
            state_dict_ids, got_all = self._get_state_for_group_using_cache(
                cache, group, state_filter
            )
            results[group] = state_dict_ids

            if not got_all:
                incomplete_groups.add(group)

        return results, incomplete_groups

    def _insert_into_cache(
        self,
        group_to_state_dict: Dict[int, StateMap[str]],
        state_filter: StateFilter,
        cache_seq_num_members: int,
        cache_seq_num_non_members: int,
    ) -> None:
        """Inserts results from querying the database into the relevant cache.

        Args:
            group_to_state_dict: The new entries pulled from database.
                Map from state group to state dict
            state_filter: The state filter used to fetch state
                from the database.
            cache_seq_num_members: Sequence number of the member cache at the
                time of the last cache lookup
            cache_seq_num_non_members: Sequence number of the non-member cache
                at the time of the last cache lookup
        """

        # We need to work out which types we've fetched from the DB for the
        # member vs non-member caches. This should be as accurate as possible,
        # but can be an underestimate (e.g. when we have wild cards)
        member_filter, non_member_filter = state_filter.get_member_split()
        if member_filter.is_full():
            # We fetched all member events
            member_types = None
        else:
            # `concrete_types()` will only return a subset when there are wild
            # cards in the filter, but that's fine.
            member_types = member_filter.concrete_types()

        if non_member_filter.is_full():
            # We fetched all non member events
            non_member_types = None
        else:
            non_member_types = non_member_filter.concrete_types()

        for group, group_state_dict in group_to_state_dict.items():
            state_dict_members = {}
            state_dict_non_members = {}
            for k, v in group_state_dict.items():
                if k[0] == EventTypes.Member:
                    state_dict_members[k] = v
                else:
                    state_dict_non_members[k] = v

            self._state_group_members_cache.update(
                cache_seq_num_members,
                key=group,
                value=state_dict_members,
                fetched_keys=member_types,
            )

            self._state_group_cache.update(
                cache_seq_num_non_members,
                key=group,
                value=state_dict_non_members,
                fetched_keys=non_member_types,
            )

    async def store_state_group(
        self,
        event_id: str,
        room_id: str,
        prev_group: Optional[int],
        delta_ids: Optional[StateMap[str]],
        current_state_ids: StateMap[str],
    ) -> int:
        """Store a new set of state, returning a newly assigned state group.

        Args:
            event_id: The event ID for which the state was calculated
            room_id: The room the state group belongs to
            prev_group: A previous state group for the room, optional.
            delta_ids: The delta between state at `prev_group` and
                `current_state_ids`, if `prev_group` was given. Same format as
                `current_state_ids`.
            current_state_ids: The state to store. Map of (type, state_key)
                to event_id.

        Returns:
            The state group ID
        """

        def _store_state_group_txn(txn: LoggingTransaction) -> int:
            if current_state_ids is None:
                # AFAIK, this can never happen
                raise Exception("current_state_ids cannot be None")

            state_group = self._state_group_seq_gen.get_next_id_txn(txn)

            self.db_pool.simple_insert_txn(
                txn,
                table="state_groups",
                values={"id": state_group, "room_id": room_id, "event_id": event_id},
            )

            # We persist as a delta if we can, while also ensuring the chain
            # of deltas isn't too long, as otherwise read performance degrades.
            if prev_group:
                is_in_db = self.db_pool.simple_select_one_onecol_txn(
                    txn,
                    table="state_groups",
                    keyvalues={"id": prev_group},
                    retcol="id",
                    allow_none=True,
                )
                if not is_in_db:
                    raise Exception(
                        "Trying to persist state with unpersisted prev_group: %r"
                        % (prev_group,)
                    )

                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
                assert delta_ids is not None

                self.db_pool.simple_insert_txn(
                    txn,
                    table="state_group_edges",
                    values={"state_group": state_group, "prev_state_group": prev_group},
                )

                self.db_pool.simple_insert_many_txn(
                    txn,
                    table="state_groups_state",
                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
                    values=[
                        (state_group, room_id, key[0], key[1], state_id)
                        for key, state_id in delta_ids.items()
                    ],
                )
            else:
                self.db_pool.simple_insert_many_txn(
                    txn,
                    table="state_groups_state",
                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
                    values=[
                        (state_group, room_id, key[0], key[1], state_id)
                        for key, state_id in current_state_ids.items()
                    ],
                )

            # Prefill the state group caches with this group.
            # It's fine to use the sequence like this as the state group map
            # is immutable. (If the map wasn't immutable then this prefill could
            # race with another update)
            current_member_state_ids = {
                s: ev
                for (s, ev) in current_state_ids.items()
                if s[0] == EventTypes.Member
            }
            txn.call_after(
                self._state_group_members_cache.update,
                self._state_group_members_cache.sequence,
                key=state_group,
                value=dict(current_member_state_ids),
            )

            current_non_member_state_ids = {
                s: ev
                for (s, ev) in current_state_ids.items()
                if s[0] != EventTypes.Member
            }
            txn.call_after(
                self._state_group_cache.update,
                self._state_group_cache.sequence,
                key=state_group,
                value=dict(current_non_member_state_ids),
            )

            return state_group

        return await self.db_pool.runInteraction(
            "store_state_group", _store_state_group_txn
        )
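
    # Illustrative call (hypothetical IDs): persisting the state after a topic
    # change might look like
    #
    #     new_group = await self.store_state_group(
    #         event_id="$topic_event_id",
    #         room_id="!room:example.org",
    #         prev_group=100,
    #         delta_ids={("m.room.topic", ""): "$topic_event_id"},
    #         current_state_ids=full_state_map,
    #     )
    #
    # which stores only the one-entry delta against group 100, unless the delta
    # chain has already reached MAX_STATE_DELTA_HOPS, in which case the full
    # state map is written out instead.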

    async def purge_unreferenced_state_groups(
        self, room_id: str, state_groups_to_delete: Collection[int]
    ) -> None:
        """Deletes no longer referenced state groups and de-deltas any state
        groups that reference them.

        Args:
            room_id: The room the state groups belong to (must all be in the
                same room).
            state_groups_to_delete: Set of all state groups to delete.
        """

        await self.db_pool.runInteraction(
            "purge_unreferenced_state_groups",
            self._purge_unreferenced_state_groups,
            room_id,
            state_groups_to_delete,
        )

    def _purge_unreferenced_state_groups(
        self,
        txn: LoggingTransaction,
        room_id: str,
        state_groups_to_delete: Collection[int],
    ) -> None:
        logger.info(
            "[purge] found %i state groups to delete", len(state_groups_to_delete)
        )

        rows = self.db_pool.simple_select_many_txn(
            txn,
            table="state_group_edges",
            column="prev_state_group",
            iterable=state_groups_to_delete,
            keyvalues={},
            retcols=("state_group",),
        )

        remaining_state_groups = {
            row["state_group"]
            for row in rows
            if row["state_group"] not in state_groups_to_delete
        }

        logger.info(
            "[purge] de-delta-ing %i remaining state groups",
            len(remaining_state_groups),
        )

        # Now we turn the state groups that reference to-be-deleted state
        # groups into non-delta versions.
        for sg in remaining_state_groups:
            logger.info("[purge] de-delta-ing remaining state group %s", sg)
            curr_state_by_group = self._get_state_groups_from_groups_txn(txn, [sg])
            curr_state = curr_state_by_group[sg]

            self.db_pool.simple_delete_txn(
                txn, table="state_groups_state", keyvalues={"state_group": sg}
            )

            self.db_pool.simple_delete_txn(
                txn, table="state_group_edges", keyvalues={"state_group": sg}
            )

            self.db_pool.simple_insert_many_txn(
                txn,
                table="state_groups_state",
                keys=("state_group", "room_id", "type", "state_key", "event_id"),
                values=[
                    (sg, room_id, key[0], key[1], state_id)
                    for key, state_id in curr_state.items()
                ],
            )

        logger.info("[purge] removing redundant state groups")
        txn.execute_batch(
            "DELETE FROM state_groups_state WHERE state_group = ?",
            ((sg,) for sg in state_groups_to_delete),
        )
        txn.execute_batch(
            "DELETE FROM state_groups WHERE id = ?",
            ((sg,) for sg in state_groups_to_delete),
        )

    async def get_previous_state_groups(
        self, state_groups: Iterable[int]
    ) -> Dict[int, int]:
        """Fetch the previous groups of the given state groups.

        Args:
            state_groups: state groups to look up the previous groups of.

        Returns:
            A mapping from state group to previous state group.
        """

        rows = await self.db_pool.simple_select_many_batch(
            table="state_group_edges",
            column="prev_state_group",
            iterable=state_groups,
            keyvalues={},
            retcols=("prev_state_group", "state_group"),
            desc="get_previous_state_groups",
        )

        return {row["state_group"]: row["prev_state_group"] for row in rows}

    async def purge_room_state(
        self, room_id: str, state_groups_to_delete: Collection[int]
    ) -> None:
        """Deletes all records of a room from state tables

        Args:
            room_id: the room to purge state for
            state_groups_to_delete: State groups to delete
        """

        await self.db_pool.runInteraction(
            "purge_room_state",
            self._purge_room_state_txn,
            room_id,
            state_groups_to_delete,
        )

    def _purge_room_state_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        state_groups_to_delete: Collection[int],
    ) -> None:
        # first we have to delete the state groups states
        logger.info("[purge] removing %s from state_groups_state", room_id)

        self.db_pool.simple_delete_many_txn(
            txn,
            table="state_groups_state",
            column="state_group",
            values=state_groups_to_delete,
            keyvalues={},
        )

        # ... and the state group edges
        logger.info("[purge] removing %s from state_group_edges", room_id)

        self.db_pool.simple_delete_many_txn(
            txn,
            table="state_group_edges",
            column="state_group",
            values=state_groups_to_delete,
            keyvalues={},
        )

        # ... and the state groups
        logger.info("[purge] removing %s from state_groups", room_id)

        self.db_pool.simple_delete_many_txn(
            txn,
            table="state_groups",
            column="id",
            values=state_groups_to_delete,
            keyvalues={},
        )