You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

639 lines
23 KiB

  1. # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from typing import List, Tuple, cast
  16. from immutabledict import immutabledict
  17. from twisted.test.proto_helpers import MemoryReactor
  18. from synapse.api.constants import EventTypes, Membership
  19. from synapse.api.room_versions import RoomVersions
  20. from synapse.events import EventBase
  21. from synapse.server import HomeServer
  22. from synapse.types import JsonDict, RoomID, StateMap, UserID
  23. from synapse.types.state import StateFilter
  24. from synapse.util import Clock
  25. from tests.unittest import HomeserverTestCase
# Module-level logger for this test module; not referenced by the tests below.
logger = logging.getLogger(__name__)
  27. class StateStoreTestCase(HomeserverTestCase):
  28. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  29. self.store = hs.get_datastores().main
  30. self.storage = hs.get_storage_controllers()
  31. self.state_datastore = self.storage.state.stores.state
  32. self.event_builder_factory = hs.get_event_builder_factory()
  33. self.event_creation_handler = hs.get_event_creation_handler()
  34. self.u_alice = UserID.from_string("@alice:test")
  35. self.u_bob = UserID.from_string("@bob:test")
  36. self.room = RoomID.from_string("!abc123:test")
  37. self.get_success(
  38. self.store.store_room(
  39. self.room.to_string(),
  40. room_creator_user_id="@creator:text",
  41. is_public=True,
  42. room_version=RoomVersions.V1,
  43. )
  44. )
  45. def inject_state_event(
  46. self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
  47. ) -> EventBase:
  48. builder = self.event_builder_factory.for_room_version(
  49. RoomVersions.V1,
  50. {
  51. "type": typ,
  52. "sender": sender.to_string(),
  53. "state_key": state_key,
  54. "room_id": room.to_string(),
  55. "content": content,
  56. },
  57. )
  58. event, unpersisted_context = self.get_success(
  59. self.event_creation_handler.create_new_client_event(builder)
  60. )
  61. context = self.get_success(unpersisted_context.persist(event))
  62. assert self.storage.persistence is not None
  63. self.get_success(self.storage.persistence.persist_event(event, context))
  64. return event
  65. def assertStateMapEqual(
  66. self, s1: StateMap[EventBase], s2: StateMap[EventBase]
  67. ) -> None:
  68. for t in s1:
  69. # just compare event IDs for simplicity
  70. self.assertEqual(s1[t].event_id, s2[t].event_id)
  71. self.assertEqual(len(s1), len(s2))
  72. def test_get_state_groups_ids(self) -> None:
  73. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  74. e2 = self.inject_state_event(
  75. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  76. )
  77. state_group_map = self.get_success(
  78. self.storage.state.get_state_groups_ids(
  79. self.room.to_string(), [e2.event_id]
  80. )
  81. )
  82. self.assertEqual(len(state_group_map), 1)
  83. state_map = list(state_group_map.values())[0]
  84. self.assertDictEqual(
  85. state_map,
  86. {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
  87. )
  88. def test_get_state_groups(self) -> None:
  89. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  90. e2 = self.inject_state_event(
  91. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  92. )
  93. state_group_map = self.get_success(
  94. self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
  95. )
  96. self.assertEqual(len(state_group_map), 1)
  97. state_list = list(state_group_map.values())[0]
  98. self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
  99. def test_get_state_for_event(self) -> None:
  100. # this defaults to a linear DAG as each new injection defaults to whatever
  101. # forward extremities are currently in the DB for this room.
  102. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  103. e2 = self.inject_state_event(
  104. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  105. )
  106. e3 = self.inject_state_event(
  107. self.room,
  108. self.u_alice,
  109. EventTypes.Member,
  110. self.u_alice.to_string(),
  111. {"membership": Membership.JOIN},
  112. )
  113. e4 = self.inject_state_event(
  114. self.room,
  115. self.u_bob,
  116. EventTypes.Member,
  117. self.u_bob.to_string(),
  118. {"membership": Membership.JOIN},
  119. )
  120. e5 = self.inject_state_event(
  121. self.room,
  122. self.u_bob,
  123. EventTypes.Member,
  124. self.u_bob.to_string(),
  125. {"membership": Membership.LEAVE},
  126. )
  127. # check we get the full state as of the final event
  128. state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
  129. self.assertIsNotNone(e4)
  130. self.assertStateMapEqual(
  131. {
  132. (e1.type, e1.state_key): e1,
  133. (e2.type, e2.state_key): e2,
  134. (e3.type, e3.state_key): e3,
  135. # e4 is overwritten by e5
  136. (e5.type, e5.state_key): e5,
  137. },
  138. state,
  139. )
  140. # check we can filter to the m.room.name event (with a '' state key)
  141. state = self.get_success(
  142. self.storage.state.get_state_for_event(
  143. e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
  144. )
  145. )
  146. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  147. # check we can filter to the m.room.name event (with a wildcard None state key)
  148. state = self.get_success(
  149. self.storage.state.get_state_for_event(
  150. e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
  151. )
  152. )
  153. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  154. # check we can grab the m.room.member events (with a wildcard None state key)
  155. state = self.get_success(
  156. self.storage.state.get_state_for_event(
  157. e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
  158. )
  159. )
  160. self.assertStateMapEqual(
  161. {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
  162. )
  163. # check we can grab a specific room member without filtering out the
  164. # other event types
  165. state = self.get_success(
  166. self.storage.state.get_state_for_event(
  167. e5.event_id,
  168. state_filter=StateFilter(
  169. types=immutabledict(
  170. {EventTypes.Member: frozenset({self.u_alice.to_string()})}
  171. ),
  172. include_others=True,
  173. ),
  174. )
  175. )
  176. self.assertStateMapEqual(
  177. {
  178. (e1.type, e1.state_key): e1,
  179. (e2.type, e2.state_key): e2,
  180. (e3.type, e3.state_key): e3,
  181. },
  182. state,
  183. )
  184. # check that we can grab everything except members
  185. state = self.get_success(
  186. self.storage.state.get_state_for_event(
  187. e5.event_id,
  188. state_filter=StateFilter(
  189. types=immutabledict({EventTypes.Member: frozenset()}),
  190. include_others=True,
  191. ),
  192. )
  193. )
  194. self.assertStateMapEqual(
  195. {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
  196. )
  197. #######################################################
  198. # _get_state_for_group_using_cache tests against a full cache
  199. #######################################################
  200. room_id = self.room.to_string()
  201. group_ids = self.get_success(
  202. self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
  203. )
  204. group = list(group_ids.keys())[0]
  205. # test _get_state_for_group_using_cache correctly filters out members
  206. # with types=[]
  207. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  208. self.state_datastore._state_group_cache,
  209. group,
  210. state_filter=StateFilter(
  211. types=immutabledict({EventTypes.Member: frozenset()}),
  212. include_others=True,
  213. ),
  214. )
  215. self.assertEqual(is_all, True)
  216. self.assertDictEqual(
  217. {
  218. (e1.type, e1.state_key): e1.event_id,
  219. (e2.type, e2.state_key): e2.event_id,
  220. },
  221. state_dict,
  222. )
  223. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  224. self.state_datastore._state_group_members_cache,
  225. group,
  226. state_filter=StateFilter(
  227. types=immutabledict({EventTypes.Member: frozenset()}),
  228. include_others=True,
  229. ),
  230. )
  231. self.assertEqual(is_all, True)
  232. self.assertDictEqual({}, state_dict)
  233. # test _get_state_for_group_using_cache correctly filters in members
  234. # with wildcard types
  235. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  236. self.state_datastore._state_group_cache,
  237. group,
  238. state_filter=StateFilter(
  239. types=immutabledict({EventTypes.Member: None}), include_others=True
  240. ),
  241. )
  242. self.assertEqual(is_all, True)
  243. self.assertDictEqual(
  244. {
  245. (e1.type, e1.state_key): e1.event_id,
  246. (e2.type, e2.state_key): e2.event_id,
  247. },
  248. state_dict,
  249. )
  250. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  251. self.state_datastore._state_group_members_cache,
  252. group,
  253. state_filter=StateFilter(
  254. types=immutabledict({EventTypes.Member: None}), include_others=True
  255. ),
  256. )
  257. self.assertEqual(is_all, True)
  258. self.assertDictEqual(
  259. {
  260. (e3.type, e3.state_key): e3.event_id,
  261. # e4 is overwritten by e5
  262. (e5.type, e5.state_key): e5.event_id,
  263. },
  264. state_dict,
  265. )
  266. # test _get_state_for_group_using_cache correctly filters in members
  267. # with specific types
  268. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  269. self.state_datastore._state_group_cache,
  270. group,
  271. state_filter=StateFilter(
  272. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  273. include_others=True,
  274. ),
  275. )
  276. self.assertEqual(is_all, True)
  277. self.assertDictEqual(
  278. {
  279. (e1.type, e1.state_key): e1.event_id,
  280. (e2.type, e2.state_key): e2.event_id,
  281. },
  282. state_dict,
  283. )
  284. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  285. self.state_datastore._state_group_members_cache,
  286. group,
  287. state_filter=StateFilter(
  288. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  289. include_others=True,
  290. ),
  291. )
  292. self.assertEqual(is_all, True)
  293. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  294. # test _get_state_for_group_using_cache correctly filters in members
  295. # with specific types
  296. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  297. self.state_datastore._state_group_members_cache,
  298. group,
  299. state_filter=StateFilter(
  300. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  301. include_others=False,
  302. ),
  303. )
  304. self.assertEqual(is_all, True)
  305. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  306. #######################################################
  307. # deliberately remove e2 (room name) from the _state_group_cache
  308. cache_entry = self.state_datastore._state_group_cache.get(group)
  309. state_dict_ids = cache_entry.value
  310. self.assertEqual(cache_entry.full, True)
  311. self.assertEqual(cache_entry.known_absent, set())
  312. self.assertDictEqual(
  313. state_dict_ids,
  314. {
  315. (e1.type, e1.state_key): e1.event_id,
  316. (e2.type, e2.state_key): e2.event_id,
  317. },
  318. )
  319. state_dict_ids.pop((e2.type, e2.state_key))
  320. self.state_datastore._state_group_cache.invalidate(group)
  321. self.state_datastore._state_group_cache.update(
  322. sequence=self.state_datastore._state_group_cache.sequence,
  323. key=group,
  324. value=state_dict_ids,
  325. # list fetched keys so it knows it's partial
  326. fetched_keys=((e1.type, e1.state_key),),
  327. )
  328. cache_entry = self.state_datastore._state_group_cache.get(group)
  329. state_dict_ids = cache_entry.value
  330. self.assertEqual(cache_entry.full, False)
  331. self.assertEqual(cache_entry.known_absent, set())
  332. self.assertDictEqual(state_dict_ids, {})
  333. ############################################
  334. # test that things work with a partial cache
  335. # test _get_state_for_group_using_cache correctly filters out members
  336. # with types=[]
  337. room_id = self.room.to_string()
  338. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  339. self.state_datastore._state_group_cache,
  340. group,
  341. state_filter=StateFilter(
  342. types=immutabledict({EventTypes.Member: frozenset()}),
  343. include_others=True,
  344. ),
  345. )
  346. self.assertEqual(is_all, False)
  347. self.assertDictEqual({}, state_dict)
  348. room_id = self.room.to_string()
  349. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  350. self.state_datastore._state_group_members_cache,
  351. group,
  352. state_filter=StateFilter(
  353. types=immutabledict({EventTypes.Member: frozenset()}),
  354. include_others=True,
  355. ),
  356. )
  357. self.assertEqual(is_all, True)
  358. self.assertDictEqual({}, state_dict)
  359. # test _get_state_for_group_using_cache correctly filters in members
  360. # wildcard types
  361. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  362. self.state_datastore._state_group_cache,
  363. group,
  364. state_filter=StateFilter(
  365. types=immutabledict({EventTypes.Member: None}), include_others=True
  366. ),
  367. )
  368. self.assertEqual(is_all, False)
  369. self.assertDictEqual({}, state_dict)
  370. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  371. self.state_datastore._state_group_members_cache,
  372. group,
  373. state_filter=StateFilter(
  374. types=immutabledict({EventTypes.Member: None}), include_others=True
  375. ),
  376. )
  377. self.assertEqual(is_all, True)
  378. self.assertDictEqual(
  379. {
  380. (e3.type, e3.state_key): e3.event_id,
  381. (e5.type, e5.state_key): e5.event_id,
  382. },
  383. state_dict,
  384. )
  385. # test _get_state_for_group_using_cache correctly filters in members
  386. # with specific types
  387. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  388. self.state_datastore._state_group_cache,
  389. group,
  390. state_filter=StateFilter(
  391. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  392. include_others=True,
  393. ),
  394. )
  395. self.assertEqual(is_all, False)
  396. self.assertDictEqual({}, state_dict)
  397. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  398. self.state_datastore._state_group_members_cache,
  399. group,
  400. state_filter=StateFilter(
  401. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  402. include_others=True,
  403. ),
  404. )
  405. self.assertEqual(is_all, True)
  406. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  407. # test _get_state_for_group_using_cache correctly filters in members
  408. # with specific types
  409. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  410. self.state_datastore._state_group_cache,
  411. group,
  412. state_filter=StateFilter(
  413. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  414. include_others=False,
  415. ),
  416. )
  417. self.assertEqual(is_all, False)
  418. self.assertDictEqual({}, state_dict)
  419. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  420. self.state_datastore._state_group_members_cache,
  421. group,
  422. state_filter=StateFilter(
  423. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  424. include_others=False,
  425. ),
  426. )
  427. self.assertEqual(is_all, True)
  428. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  429. def test_batched_state_group_storing(self) -> None:
  430. creation_event = self.inject_state_event(
  431. self.room, self.u_alice, EventTypes.Create, "", {}
  432. )
  433. state_to_event = self.get_success(
  434. self.storage.state.get_state_groups(
  435. self.room.to_string(), [creation_event.event_id]
  436. )
  437. )
  438. current_state_group = list(state_to_event.keys())[0]
  439. # create some unpersisted events and event contexts to store against room
  440. events_and_context = []
  441. builder = self.event_builder_factory.for_room_version(
  442. RoomVersions.V1,
  443. {
  444. "type": EventTypes.Name,
  445. "sender": self.u_alice.to_string(),
  446. "state_key": "",
  447. "room_id": self.room.to_string(),
  448. "content": {"name": "first rename of room"},
  449. },
  450. )
  451. event1, unpersisted_context1 = self.get_success(
  452. self.event_creation_handler.create_new_client_event(builder)
  453. )
  454. events_and_context.append((event1, unpersisted_context1))
  455. builder2 = self.event_builder_factory.for_room_version(
  456. RoomVersions.V1,
  457. {
  458. "type": EventTypes.JoinRules,
  459. "sender": self.u_alice.to_string(),
  460. "state_key": "",
  461. "room_id": self.room.to_string(),
  462. "content": {"join_rule": "private"},
  463. },
  464. )
  465. event2, unpersisted_context2 = self.get_success(
  466. self.event_creation_handler.create_new_client_event(builder2)
  467. )
  468. events_and_context.append((event2, unpersisted_context2))
  469. builder3 = self.event_builder_factory.for_room_version(
  470. RoomVersions.V1,
  471. {
  472. "type": EventTypes.Message,
  473. "sender": self.u_alice.to_string(),
  474. "room_id": self.room.to_string(),
  475. "content": {"body": "hello from event 3", "msgtype": "m.text"},
  476. },
  477. )
  478. event3, unpersisted_context3 = self.get_success(
  479. self.event_creation_handler.create_new_client_event(builder3)
  480. )
  481. events_and_context.append((event3, unpersisted_context3))
  482. builder4 = self.event_builder_factory.for_room_version(
  483. RoomVersions.V1,
  484. {
  485. "type": EventTypes.JoinRules,
  486. "sender": self.u_alice.to_string(),
  487. "state_key": "",
  488. "room_id": self.room.to_string(),
  489. "content": {"join_rule": "public"},
  490. },
  491. )
  492. event4, unpersisted_context4 = self.get_success(
  493. self.event_creation_handler.create_new_client_event(builder4)
  494. )
  495. events_and_context.append((event4, unpersisted_context4))
  496. processed_events_and_context = self.get_success(
  497. self.hs.get_datastores().state.store_state_deltas_for_batched(
  498. events_and_context, self.room.to_string(), current_state_group
  499. )
  500. )
  501. # check that only state events are in state_groups, and all state events are in state_groups
  502. res = cast(
  503. List[Tuple[str]],
  504. self.get_success(
  505. self.store.db_pool.simple_select_list(
  506. table="state_groups",
  507. keyvalues=None,
  508. retcols=("event_id",),
  509. )
  510. ),
  511. )
  512. events = []
  513. for result in res:
  514. self.assertNotIn(event3.event_id, result) # XXX
  515. events.append(result[0])
  516. for event, _ in processed_events_and_context:
  517. if event.is_state():
  518. self.assertIn(event.event_id, events)
  519. # check that each unique state has state group in state_groups_state and that the
  520. # type/state key is correct, and check that each state event's state group
  521. # has an entry and prev event in state_group_edges
  522. for event, context in processed_events_and_context:
  523. if event.is_state():
  524. state = cast(
  525. List[Tuple[str, str]],
  526. self.get_success(
  527. self.store.db_pool.simple_select_list(
  528. table="state_groups_state",
  529. keyvalues={"state_group": context.state_group_after_event},
  530. retcols=("type", "state_key"),
  531. )
  532. ),
  533. )
  534. self.assertEqual(event.type, state[0][0])
  535. self.assertEqual(event.state_key, state[0][1])
  536. groups = cast(
  537. List[Tuple[str]],
  538. self.get_success(
  539. self.store.db_pool.simple_select_list(
  540. table="state_group_edges",
  541. keyvalues={
  542. "state_group": str(context.state_group_after_event)
  543. },
  544. retcols=("prev_state_group",),
  545. )
  546. ),
  547. )
  548. self.assertEqual(context.state_group_before_event, groups[0][0])