# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import heapq
import itertools
import logging
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    overload,
)

from typing_extensions import Literal, Protocol

from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap, StrCollection

logger = logging.getLogger(__name__)


class Clock(Protocol):
    # This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
    # We only ever sleep(0) though, so that other async functions can make forward
    # progress without waiting for stateres to complete.
    def sleep(self, duration_ms: float) -> Awaitable[None]:
        ...


class StateResolutionStore(Protocol):
    # This is usually synapse.state.StateResolutionStore, but it's replaced with a
    # TestStateResolutionStore in tests.
    def get_events(
        self, event_ids: StrCollection, allow_rejected: bool = False
    ) -> Awaitable[Dict[str, EventBase]]:
        ...

    def get_auth_chain_difference(
        self, room_id: str, state_sets: List[Set[str]]
    ) -> Awaitable[Set[str]]:
        ...


# We want to await to the reactor occasionally during state res when dealing
# with large data sets, so that we don't exhaust the reactor. This is done by
# awaiting to the reactor during loops every N iterations.
_AWAIT_AFTER_ITERATIONS = 100


__all__ = [
    "resolve_events_with_store",
]


async def resolve_events_with_store(
    clock: Clock,
    room_id: str,
    room_version: RoomVersion,
    state_sets: Sequence[StateMap[str]],
    event_map: Optional[Dict[str, EventBase]],
    state_res_store: StateResolutionStore,
) -> StateMap[str]:
    """Resolves the state using the v2 state resolution algorithm

    Args:
        clock
        room_id: the room we are working in
        room_version: The room version
        state_sets: List of dicts of (type, state_key) -> event_id,
            which are the different state groups to resolve.
        event_map:
            a dict from event_id to event, for any events that we happen to
            have in flight (eg, those currently being persisted). This will be
            used as a starting point for finding the state we need; any missing
            events will be requested via state_res_store.

            If None, all events will be fetched via state_res_store.
        state_res_store:

    Returns:
        A map from (type, state_key) to event_id.
    """
    logger.debug("Computing conflicted state")

    # We use event_map as a cache, so if it's None we need to initialize it
    if event_map is None:
        event_map = {}

    # First split up the un/conflicted state
    unconflicted_state, conflicted_state = _seperate(state_sets)

    if not conflicted_state:
        return unconflicted_state

    logger.debug("%d conflicted state entries", len(conflicted_state))
    logger.debug("Calculating auth chain difference")

    # Also fetch all auth events that appear in only some of the state sets'
    # auth chains.
    auth_diff = await _get_auth_chain_difference(
        room_id, state_sets, event_map, state_res_store
    )

    full_conflicted_set = set(
        itertools.chain(
            itertools.chain.from_iterable(conflicted_state.values()), auth_diff
        )
    )

    events = await state_res_store.get_events(
        [eid for eid in full_conflicted_set if eid not in event_map],
        allow_rejected=True,
    )
    event_map.update(events)

    # everything in the event map should be in the right room
    for event in event_map.values():
        if event.room_id != room_id:
            raise Exception(
                "Attempting to state-resolve for room %s with event %s which is in %s"
                % (
                    room_id,
                    event.event_id,
                    event.room_id,
                )
            )

    full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}

    logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))

    # Get and sort all the power events (kicks/bans/etc)
    power_events = (
        eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
    )

    sorted_power_events = await _reverse_topological_power_sort(
        clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
    )

    logger.debug("sorted %d power events", len(sorted_power_events))

    # Now sequentially auth each one
    resolved_state = await _iterative_auth_checks(
        clock,
        room_id,
        room_version,
        sorted_power_events,
        unconflicted_state,
        event_map,
        state_res_store,
    )

    logger.debug("resolved power events")

    # OK, so we've now resolved the power events. Now sort the remaining
    # events using the mainline of the resolved power level.

    set_power_events = set(sorted_power_events)
    leftover_events = [
        ev_id for ev_id in full_conflicted_set if ev_id not in set_power_events
    ]

    logger.debug("sorting %d remaining events", len(leftover_events))

    pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
    leftover_events = await _mainline_sort(
        clock, room_id, leftover_events, pl, event_map, state_res_store
    )

    logger.debug("resolving remaining events")

    resolved_state = await _iterative_auth_checks(
        clock,
        room_id,
        room_version,
        leftover_events,
        resolved_state,
        event_map,
        state_res_store,
    )

    logger.debug("resolved")

    # We make sure that unconflicted state always still applies.
    resolved_state.update(unconflicted_state)

    logger.debug("done")

    return resolved_state


async def _get_power_level_for_sender(
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
) -> int:
    """Return the power level of the sender of the given event according to
    their auth events.

    Args:
        room_id
        event_id
        event_map
        state_res_store

    Returns:
        The power level.
    """
    event = await _get_event(room_id, event_id, event_map, state_res_store)

    pl = None
    for aid in event.auth_event_ids():
        aev = await _get_event(
            room_id, aid, event_map, state_res_store, allow_none=True
        )
        if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
            pl = aev
            break

    if pl is None:
        # Couldn't find power level. Check if they're the creator of the room
        for aid in event.auth_event_ids():
            aev = await _get_event(
                room_id, aid, event_map, state_res_store, allow_none=True
            )
            if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
                if aev.content.get("creator") == event.sender:
                    return 100
                break
        return 0

    level = pl.content.get("users", {}).get(event.sender)
    if level is None:
        level = pl.content.get("users_default", 0)

    if level is None:
        return 0
    else:
        return int(level)
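The fallback chain above (the sender's entry in `users`, then `users_default`, then 0) can be sketched standalone. This is a minimal illustration with hypothetical `m.room.power_levels` content dicts, not real Synapse events:

```python
def power_level_for(content: dict, sender: str) -> int:
    """Resolve a sender's power level from a power_levels content dict,
    mirroring the users -> users_default -> 0 fallback chain."""
    level = content.get("users", {}).get(sender)
    if level is None:
        level = content.get("users_default", 0)
    return int(level) if level is not None else 0


content = {"users": {"@alice:example.org": 100}, "users_default": 5}
assert power_level_for(content, "@alice:example.org") == 100  # explicit entry
assert power_level_for(content, "@bob:example.org") == 5      # users_default
assert power_level_for({}, "@bob:example.org") == 0           # no info at all
```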


async def _get_auth_chain_difference(
    room_id: str,
    state_sets: Sequence[StateMap[str]],
    unpersisted_events: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
) -> Set[str]:
    """Compare the auth chains of each state set and return the set of events
    that only appear in some, but not all of the auth chains.

    Args:
        state_sets: The input state sets we are trying to resolve across.
        unpersisted_events: A map from event ID to EventBase containing all unpersisted
            events involved in this resolution.
        state_res_store:

    Returns:
        The auth difference of the given state sets, as a set of event IDs.
    """
    # The `StateResolutionStore.get_auth_chain_difference` function assumes that
    # all events passed to it (and their auth chains) have been persisted
    # previously. We need to manually handle any other events that are yet to be
    # persisted.
    #
    # We do this in three steps:
    #    1. Compute the set of unpersisted events belonging to the auth difference.
    #    2. Replace any unpersisted events in the state_sets with their auth events,
    #       recursively, until the state_sets contain only persisted events.
    #       Then we call `store.get_auth_chain_difference` as normal, which computes
    #       the set of persisted events belonging to the auth difference.
    #    3. Add the results of 1 and 2 together.

    # Map from event ID in `unpersisted_events` to their auth event IDs, and their
    # auth event IDs if they appear in `unpersisted_events`. This is the intersection
    # of the event's auth chain with the events in `unpersisted_events` *plus* their
    # auth event IDs.
    events_to_auth_chain: Dict[str, Set[str]] = {}
    for event in unpersisted_events.values():
        chain = {event.event_id}
        events_to_auth_chain[event.event_id] = chain

        to_search = [event]
        while to_search:
            for auth_id in to_search.pop().auth_event_ids():
                chain.add(auth_id)
                auth_event = unpersisted_events.get(auth_id)
                if auth_event:
                    to_search.append(auth_event)

    # We now 1) calculate the auth chain difference for the unpersisted events
    # and 2) work out the state sets to pass to the store.
    #
    # Note: If there are no `unpersisted_events` (which is the common case), we can do a
    # much simpler calculation.
    if unpersisted_events:
        # The list of state sets to pass to the store, where each state set is a set
        # of the event ids making up the state. This is similar to `state_sets`,
        # except that (a) we only have event ids, not the complete
        # ((type, state_key)->event_id) mappings; and (b) we have stripped out
        # unpersisted events and replaced them with the persisted events in
        # their auth chain.
        state_sets_ids: List[Set[str]] = []

        # For each state set, the unpersisted event IDs reachable (by their auth
        # chain) from the events in that set.
        unpersisted_set_ids: List[Set[str]] = []

        for state_set in state_sets:
            set_ids: Set[str] = set()
            state_sets_ids.append(set_ids)

            unpersisted_ids: Set[str] = set()
            unpersisted_set_ids.append(unpersisted_ids)

            for event_id in state_set.values():
                event_chain = events_to_auth_chain.get(event_id)
                if event_chain is not None:
                    # We have an unpersisted event. We add all the auth
                    # events that it references which have been persisted.
                    set_ids.update(
                        e for e in event_chain if e not in unpersisted_events
                    )

                    # We also add the full chain of unpersisted event IDs
                    # referenced by this state set, so that we can work out the
                    # auth chain difference of the unpersisted events.
                    unpersisted_ids.update(
                        e for e in event_chain if e in unpersisted_events
                    )
                else:
                    set_ids.add(event_id)

        # The auth chain difference of the unpersisted events of the state sets
        # is calculated by taking the difference between the union and the
        # intersection.
        union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
        intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])

        auth_difference_unpersisted_part: StrCollection = union - intersection
    else:
        auth_difference_unpersisted_part = ()
        state_sets_ids = [set(state_set.values()) for state_set in state_sets]

    difference = await state_res_store.get_auth_chain_difference(
        room_id, state_sets_ids
    )
    difference.update(auth_difference_unpersisted_part)

    return difference
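At its core, the auth-difference calculation above is set algebra: the events in the union of the state sets' auth chains, minus the events in their intersection. A toy sketch with made-up event IDs:

```python
from typing import List, Set


def auth_chain_difference(chains: List[Set[str]]) -> Set[str]:
    """Events in some, but not all, of the given auth chains:
    union minus intersection."""
    union = set().union(*chains)
    intersection = set(chains[0]).intersection(*chains[1:])
    return union - intersection


# Two state sets whose auth chains share the create event but disagree on
# which power_levels event applies:
chains = [{"$create", "$pl1"}, {"$create", "$pl2"}]
assert auth_chain_difference(chains) == {"$pl1", "$pl2"}
```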


def _seperate(
    state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
    """Return the unconflicted and conflicted state. This is different than in
    the original algorithm, as this defines a key to be conflicted if one of
    the state sets doesn't have that key.

    Args:
        state_sets

    Returns:
        A tuple of unconflicted and conflicted state. The conflicted state dict
        is a map from type/state_key to set of event IDs
    """
    unconflicted_state = {}
    conflicted_state = {}

    for key in set(itertools.chain.from_iterable(state_sets)):
        event_ids = {state_set.get(key) for state_set in state_sets}
        if len(event_ids) == 1:
            unconflicted_state[key] = event_ids.pop()
        else:
            event_ids.discard(None)
            conflicted_state[key] = event_ids

    # mypy doesn't understand that discarding None above means that conflicted
    # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
    return unconflicted_state, conflicted_state  # type: ignore[return-value]
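The separation rule can be sketched standalone: a key is unconflicted only when every state set maps it to the same event ID; otherwise it is conflicted, including when some set lacks the key entirely (which `set.get` surfaces as `None`). The event IDs here are made up for illustration:

```python
import itertools


def separate(state_sets):
    """Split state maps into unconflicted and conflicted entries."""
    unconflicted, conflicted = {}, {}
    for key in set(itertools.chain.from_iterable(state_sets)):
        event_ids = {s.get(key) for s in state_sets}
        if len(event_ids) == 1:
            unconflicted[key] = event_ids.pop()
        else:
            event_ids.discard(None)  # a missing key still counts as a conflict
            conflicted[key] = event_ids
    return unconflicted, conflicted


a = {("m.room.create", ""): "$create", ("m.room.topic", ""): "$t1"}
b = {("m.room.create", ""): "$create", ("m.room.topic", ""): "$t2"}
unc, conf = separate([a, b])
assert unc == {("m.room.create", ""): "$create"}
assert conf == {("m.room.topic", ""): {"$t1", "$t2"}}
```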


def _is_power_event(event: EventBase) -> bool:
    """Return whether or not the event is a "power event", as defined by the
    v2 state resolution algorithm

    Args:
        event

    Returns:
        True if the event is a power event.
    """
    if (event.type, event.state_key) in (
        (EventTypes.PowerLevels, ""),
        (EventTypes.JoinRules, ""),
        (EventTypes.Create, ""),
    ):
        return True

    if event.type == EventTypes.Member:
        if event.membership in ("leave", "ban"):
            return event.sender != event.state_key

    return False
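The rule above boils down to: power levels, join rules, and create events always qualify, while membership events qualify only when one user removes another (a kick or ban). A standalone sketch, where `FakeEvent` is a hypothetical stand-in for `EventBase`:

```python
from typing import NamedTuple


class FakeEvent(NamedTuple):
    type: str
    state_key: str
    sender: str
    membership: str = ""


def is_power_event(ev: FakeEvent) -> bool:
    if (ev.type, ev.state_key) in (
        ("m.room.power_levels", ""),
        ("m.room.join_rules", ""),
        ("m.room.create", ""),
    ):
        return True
    if ev.type == "m.room.member" and ev.membership in ("leave", "ban"):
        return ev.sender != ev.state_key  # acting on someone else
    return False


assert is_power_event(FakeEvent("m.room.power_levels", "", "@a:x"))
# Banning another user is a power event...
assert is_power_event(FakeEvent("m.room.member", "@b:x", "@a:x", "ban"))
# ...but leaving on your own is not.
assert not is_power_event(FakeEvent("m.room.member", "@a:x", "@a:x", "leave"))
```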


async def _add_event_and_auth_chain_to_graph(
    graph: Dict[str, Set[str]],
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
    full_conflicted_set: Set[str],
) -> None:
    """Helper function for _reverse_topological_power_sort that adds the event
    and its auth chain (that is in the auth diff) to the graph

    Args:
        graph: A map from event ID to the event's auth event IDs
        room_id: the room we are working in
        event_id: Event to add to the graph
        event_map
        state_res_store
        full_conflicted_set: Set of event IDs that are in the full conflicted set.
    """
    state = [event_id]
    while state:
        eid = state.pop()
        graph.setdefault(eid, set())

        event = await _get_event(room_id, eid, event_map, state_res_store)
        for aid in event.auth_event_ids():
            if aid in full_conflicted_set:
                if aid not in graph:
                    state.append(aid)

                graph.setdefault(eid, set()).add(aid)


async def _reverse_topological_power_sort(
    clock: Clock,
    room_id: str,
    event_ids: Iterable[str],
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
    full_conflicted_set: Set[str],
) -> List[str]:
    """Returns a list of the event_ids sorted by reverse topological ordering,
    and then by power level and origin_server_ts

    Args:
        clock
        room_id: the room we are working in
        event_ids: The events to sort
        event_map
        state_res_store
        full_conflicted_set: Set of event IDs that are in the full conflicted set.

    Returns:
        The sorted list
    """

    graph: Dict[str, Set[str]] = {}
    for idx, event_id in enumerate(event_ids, start=1):
        await _add_event_and_auth_chain_to_graph(
            graph, room_id, event_id, event_map, state_res_store, full_conflicted_set
        )

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    event_to_pl = {}
    for idx, event_id in enumerate(graph, start=1):
        pl = await _get_power_level_for_sender(
            room_id, event_id, event_map, state_res_store
        )
        event_to_pl[event_id] = pl

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    def _get_power_order(event_id: str) -> Tuple[int, int, str]:
        ev = event_map[event_id]
        pl = event_to_pl[event_id]

        return -pl, ev.origin_server_ts, event_id

    # Note: graph is modified during the sort
    it = lexicographical_topological_sort(graph, key=_get_power_order)
    sorted_events = list(it)

    return sorted_events


async def _iterative_auth_checks(
    clock: Clock,
    room_id: str,
    room_version: RoomVersion,
    event_ids: List[str],
    base_state: StateMap[str],
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
) -> MutableStateMap[str]:
    """Sequentially apply auth checks to each event in given list, updating the
    state as it goes along.

    Args:
        clock
        room_id
        room_version
        event_ids: Ordered list of events to apply auth checks to
        base_state: The set of state to start with
        event_map
        state_res_store

    Returns:
        Returns the final updated state
    """
    resolved_state = dict(base_state)

    for idx, event_id in enumerate(event_ids, start=1):
        event = event_map[event_id]

        auth_events = {}
        for aid in event.auth_event_ids():
            ev = await _get_event(
                room_id, aid, event_map, state_res_store, allow_none=True
            )

            if not ev:
                logger.warning(
                    "auth_event id %s for event %s is missing", aid, event_id
                )
            else:
                if ev.rejected_reason is None:
                    auth_events[(ev.type, ev.state_key)] = ev

        for key in event_auth.auth_types_for_event(room_version, event):
            if key in resolved_state:
                ev_id = resolved_state[key]
                ev = await _get_event(room_id, ev_id, event_map, state_res_store)

                if ev.rejected_reason is None:
                    auth_events[key] = event_map[ev_id]

        if event.rejected_reason is not None:
            # Do not admit previously rejected events into state.
            # TODO: This isn't spec compliant. Events that were previously rejected due
            #       to failing auth checks at their state, but pass auth checks during
            #       state resolution should be accepted. Synapse does not handle the
            #       change of rejection status well, so we preserve the previous
            #       rejection status for now.
            #
            #       Note that events rejected for non-state reasons, such as having the
            #       wrong auth events, should remain rejected.
            #
            #       https://spec.matrix.org/v1.2/rooms/v9/#rejected-events
            #       https://github.com/matrix-org/synapse/issues/13797
            continue

        try:
            event_auth.check_state_dependent_auth_rules(
                event,
                auth_events.values(),
            )

            resolved_state[(event.type, event.state_key)] = event_id
        except AuthError:
            pass

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    return resolved_state


async def _mainline_sort(
    clock: Clock,
    room_id: str,
    event_ids: List[str],
    resolved_power_event_id: Optional[str],
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
) -> List[str]:
    """Returns a sorted list of event_ids sorted by mainline ordering based on
    the given event resolved_power_event_id

    Args:
        clock
        room_id: room we're working in
        event_ids: Events to sort
        resolved_power_event_id: The final resolved power level event ID
        event_map
        state_res_store

    Returns:
        The sorted list
    """
    if not event_ids:
        # It's possible for there to be no event IDs here to sort, so we can
        # skip calculating the mainline in that case.
        return []

    mainline = []
    pl = resolved_power_event_id
    idx = 0

    while pl:
        mainline.append(pl)
        pl_ev = await _get_event(room_id, pl, event_map, state_res_store)
        auth_events = pl_ev.auth_event_ids()
        pl = None
        for aid in auth_events:
            ev = await _get_event(
                room_id, aid, event_map, state_res_store, allow_none=True
            )
            if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
                pl = aid
                break

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

        idx += 1

    mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}

    event_ids = list(event_ids)

    order_map = {}
    for idx, ev_id in enumerate(event_ids, start=1):
        depth = await _get_mainline_depth_for_event(
            clock, event_map[ev_id], mainline_map, event_map, state_res_store
        )
        order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    event_ids.sort(key=lambda ev_id: order_map[ev_id])

    return event_ids
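The two halves of the mainline ordering can be illustrated standalone: the walked-back chain of power level events becomes a depth map (deepest ancestor = 1), and events are then sorted by `(mainline depth, origin_server_ts, event_id)`. The mainline and event data below are made up for illustration:

```python
# Chain of power_levels events, newest first, as built by the walk above.
mainline = ["$pl3", "$pl2", "$pl1"]
mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}
assert mainline_map == {"$pl1": 1, "$pl2": 2, "$pl3": 3}

# Hypothetical events keyed by (closest mainline depth, origin_server_ts, id).
order_map = {
    "$msg_a": (3, 100, "$msg_a"),  # descends from $pl3
    "$msg_b": (1, 50, "$msg_b"),   # descends from $pl1
    "$msg_c": (1, 200, "$msg_c"),  # descends from $pl1, later timestamp
}
event_ids = sorted(order_map, key=lambda e: order_map[e])
assert event_ids == ["$msg_b", "$msg_c", "$msg_a"]
```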


async def _get_mainline_depth_for_event(
    clock: Clock,
    event: EventBase,
    mainline_map: Dict[str, int],
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
) -> int:
    """Get the mainline depth for the given event based on the mainline map

    Args:
        event
        mainline_map: Map from event_id to mainline depth for events in the mainline.
        event_map
        state_res_store

    Returns:
        The mainline depth
    """

    room_id = event.room_id
    tmp_event: Optional[EventBase] = event

    # We do an iterative search, replacing `event` with the power level in its
    # auth events (if any)
    idx = 0
    while tmp_event:
        depth = mainline_map.get(tmp_event.event_id)
        if depth is not None:
            return depth

        auth_events = tmp_event.auth_event_ids()
        tmp_event = None

        for aid in auth_events:
            aev = await _get_event(
                room_id, aid, event_map, state_res_store, allow_none=True
            )
            if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
                tmp_event = aev
                break

        idx += 1
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    # Didn't find a power level auth event, so we just return 0
    return 0


@overload
async def _get_event(
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
    allow_none: Literal[False] = False,
) -> EventBase:
    ...


@overload
async def _get_event(
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
    allow_none: Literal[True],
) -> Optional[EventBase]:
    ...


async def _get_event(
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
    allow_none: bool = False,
) -> Optional[EventBase]:
    """Helper function to look up event in event_map, falling back to looking
    it up in the store

    Args:
        room_id
        event_id
        event_map
        state_res_store
        allow_none: if the event is not found, return None rather than raising
            an exception

    Returns:
        The event, or None if the event does not exist (and allow_none is True).
    """
    if event_id not in event_map:
        events = await state_res_store.get_events([event_id], allow_rejected=True)
        event_map.update(events)
    event = event_map.get(event_id)

    if event is None:
        if allow_none:
            return None
        raise Exception("Unknown event %s" % (event_id,))

    if event.room_id != room_id:
        raise Exception(
            "In state res for room %s, event %s is in %s"
            % (room_id, event_id, event.room_id)
        )
    return event


def lexicographical_topological_sort(
    graph: Dict[str, Set[str]], key: Callable[[str], Any]
) -> Generator[str, None, None]:
    """Performs a lexicographic reverse topological sort on the graph.

    This returns a reverse topological sort (i.e. if node A references B then B
    appears before A in the sort), with ties broken lexicographically based on
    the return value of the `key` function.

    NOTE: `graph` is modified during the sort.

    Args:
        graph: A representation of the graph where each node is a key in the
            dict and its value is the set of the node's edges.
        key: A function that takes a node and returns a value that is comparable
            and used to order nodes

    Yields:
        The next node in the topological sort
    """

    # Note, this is basically Kahn's algorithm except we look at nodes with no
    # outgoing edges, c.f.
    # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
    outdegree_map = graph
    reverse_graph: Dict[str, Set[str]] = {}

    # List of nodes with zero outdegree. Each entry is actually a tuple of
    # `(key(node), node)` so that sorting does the right thing
    zero_outdegree = []

    for node, edges in graph.items():
        if len(edges) == 0:
            zero_outdegree.append((key(node), node))

        reverse_graph.setdefault(node, set())
        for edge in edges:
            reverse_graph.setdefault(edge, set()).add(node)

    # heapq is a built in implementation of a sorted queue.
    heapq.heapify(zero_outdegree)

    while zero_outdegree:
        _, node = heapq.heappop(zero_outdegree)

        for parent in reverse_graph[node]:
            out = outdegree_map[parent]
            out.discard(node)
            if len(out) == 0:
                heapq.heappush(zero_outdegree, (key(parent), parent))

        yield node
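A concrete run of the sort on a toy auth graph makes the ordering guarantee visible: each node maps to the set of nodes it references, so a node is emitted only after everything it references, with ties broken by `key`. The function is repeated here so the snippet runs on its own; the event IDs are made up:

```python
import heapq
from typing import Any, Callable, Dict, Generator, Set


def lexicographical_topological_sort(
    graph: Dict[str, Set[str]], key: Callable[[str], Any]
) -> Generator[str, None, None]:
    # Kahn's algorithm over nodes with no outgoing edges; `graph` is modified.
    outdegree_map = graph
    reverse_graph: Dict[str, Set[str]] = {}
    zero_outdegree = []
    for node, edges in graph.items():
        if len(edges) == 0:
            zero_outdegree.append((key(node), node))
        reverse_graph.setdefault(node, set())
        for edge in edges:
            reverse_graph.setdefault(edge, set()).add(node)
    heapq.heapify(zero_outdegree)
    while zero_outdegree:
        _, node = heapq.heappop(zero_outdegree)
        for parent in reverse_graph[node]:
            out = outdegree_map[parent]
            out.discard(node)
            if len(out) == 0:
                heapq.heappush(zero_outdegree, (key(parent), parent))
        yield node


# "$b" and "$c" both reference "$a"; "$d" references "$b".
graph = {"$a": set(), "$b": {"$a"}, "$c": {"$a"}, "$d": {"$b"}}
order = list(lexicographical_topological_sort(graph, key=lambda n: n))
assert order == ["$a", "$b", "$c", "$d"]  # referenced-first, ties lexicographic
```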