You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

629 lines
23 KiB

  1. # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from immutabledict import immutabledict
  16. from twisted.test.proto_helpers import MemoryReactor
  17. from synapse.api.constants import EventTypes, Membership
  18. from synapse.api.room_versions import RoomVersions
  19. from synapse.events import EventBase
  20. from synapse.server import HomeServer
  21. from synapse.types import JsonDict, RoomID, StateMap, UserID
  22. from synapse.types.state import StateFilter
  23. from synapse.util import Clock
  24. from tests.unittest import HomeserverTestCase
  25. logger = logging.getLogger(__name__)
  26. class StateStoreTestCase(HomeserverTestCase):
  27. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  28. self.store = hs.get_datastores().main
  29. self.storage = hs.get_storage_controllers()
  30. self.state_datastore = self.storage.state.stores.state
  31. self.event_builder_factory = hs.get_event_builder_factory()
  32. self.event_creation_handler = hs.get_event_creation_handler()
  33. self.u_alice = UserID.from_string("@alice:test")
  34. self.u_bob = UserID.from_string("@bob:test")
  35. self.room = RoomID.from_string("!abc123:test")
  36. self.get_success(
  37. self.store.store_room(
  38. self.room.to_string(),
  39. room_creator_user_id="@creator:text",
  40. is_public=True,
  41. room_version=RoomVersions.V1,
  42. )
  43. )
  44. def inject_state_event(
  45. self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
  46. ) -> EventBase:
  47. builder = self.event_builder_factory.for_room_version(
  48. RoomVersions.V1,
  49. {
  50. "type": typ,
  51. "sender": sender.to_string(),
  52. "state_key": state_key,
  53. "room_id": room.to_string(),
  54. "content": content,
  55. },
  56. )
  57. event, unpersisted_context = self.get_success(
  58. self.event_creation_handler.create_new_client_event(builder)
  59. )
  60. context = self.get_success(unpersisted_context.persist(event))
  61. assert self.storage.persistence is not None
  62. self.get_success(self.storage.persistence.persist_event(event, context))
  63. return event
  64. def assertStateMapEqual(
  65. self, s1: StateMap[EventBase], s2: StateMap[EventBase]
  66. ) -> None:
  67. for t in s1:
  68. # just compare event IDs for simplicity
  69. self.assertEqual(s1[t].event_id, s2[t].event_id)
  70. self.assertEqual(len(s1), len(s2))
  71. def test_get_state_groups_ids(self) -> None:
  72. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  73. e2 = self.inject_state_event(
  74. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  75. )
  76. state_group_map = self.get_success(
  77. self.storage.state.get_state_groups_ids(
  78. self.room.to_string(), [e2.event_id]
  79. )
  80. )
  81. self.assertEqual(len(state_group_map), 1)
  82. state_map = list(state_group_map.values())[0]
  83. self.assertDictEqual(
  84. state_map,
  85. {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
  86. )
  87. def test_get_state_groups(self) -> None:
  88. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  89. e2 = self.inject_state_event(
  90. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  91. )
  92. state_group_map = self.get_success(
  93. self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
  94. )
  95. self.assertEqual(len(state_group_map), 1)
  96. state_list = list(state_group_map.values())[0]
  97. self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
  98. def test_get_state_for_event(self) -> None:
  99. # this defaults to a linear DAG as each new injection defaults to whatever
  100. # forward extremities are currently in the DB for this room.
  101. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  102. e2 = self.inject_state_event(
  103. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  104. )
  105. e3 = self.inject_state_event(
  106. self.room,
  107. self.u_alice,
  108. EventTypes.Member,
  109. self.u_alice.to_string(),
  110. {"membership": Membership.JOIN},
  111. )
  112. e4 = self.inject_state_event(
  113. self.room,
  114. self.u_bob,
  115. EventTypes.Member,
  116. self.u_bob.to_string(),
  117. {"membership": Membership.JOIN},
  118. )
  119. e5 = self.inject_state_event(
  120. self.room,
  121. self.u_bob,
  122. EventTypes.Member,
  123. self.u_bob.to_string(),
  124. {"membership": Membership.LEAVE},
  125. )
  126. # check we get the full state as of the final event
  127. state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
  128. self.assertIsNotNone(e4)
  129. self.assertStateMapEqual(
  130. {
  131. (e1.type, e1.state_key): e1,
  132. (e2.type, e2.state_key): e2,
  133. (e3.type, e3.state_key): e3,
  134. # e4 is overwritten by e5
  135. (e5.type, e5.state_key): e5,
  136. },
  137. state,
  138. )
  139. # check we can filter to the m.room.name event (with a '' state key)
  140. state = self.get_success(
  141. self.storage.state.get_state_for_event(
  142. e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
  143. )
  144. )
  145. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  146. # check we can filter to the m.room.name event (with a wildcard None state key)
  147. state = self.get_success(
  148. self.storage.state.get_state_for_event(
  149. e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
  150. )
  151. )
  152. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  153. # check we can grab the m.room.member events (with a wildcard None state key)
  154. state = self.get_success(
  155. self.storage.state.get_state_for_event(
  156. e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
  157. )
  158. )
  159. self.assertStateMapEqual(
  160. {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
  161. )
  162. # check we can grab a specific room member without filtering out the
  163. # other event types
  164. state = self.get_success(
  165. self.storage.state.get_state_for_event(
  166. e5.event_id,
  167. state_filter=StateFilter(
  168. types=immutabledict(
  169. {EventTypes.Member: frozenset({self.u_alice.to_string()})}
  170. ),
  171. include_others=True,
  172. ),
  173. )
  174. )
  175. self.assertStateMapEqual(
  176. {
  177. (e1.type, e1.state_key): e1,
  178. (e2.type, e2.state_key): e2,
  179. (e3.type, e3.state_key): e3,
  180. },
  181. state,
  182. )
  183. # check that we can grab everything except members
  184. state = self.get_success(
  185. self.storage.state.get_state_for_event(
  186. e5.event_id,
  187. state_filter=StateFilter(
  188. types=immutabledict({EventTypes.Member: frozenset()}),
  189. include_others=True,
  190. ),
  191. )
  192. )
  193. self.assertStateMapEqual(
  194. {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
  195. )
  196. #######################################################
  197. # _get_state_for_group_using_cache tests against a full cache
  198. #######################################################
  199. room_id = self.room.to_string()
  200. group_ids = self.get_success(
  201. self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
  202. )
  203. group = list(group_ids.keys())[0]
  204. # test _get_state_for_group_using_cache correctly filters out members
  205. # with types=[]
  206. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  207. self.state_datastore._state_group_cache,
  208. group,
  209. state_filter=StateFilter(
  210. types=immutabledict({EventTypes.Member: frozenset()}),
  211. include_others=True,
  212. ),
  213. )
  214. self.assertEqual(is_all, True)
  215. self.assertDictEqual(
  216. {
  217. (e1.type, e1.state_key): e1.event_id,
  218. (e2.type, e2.state_key): e2.event_id,
  219. },
  220. state_dict,
  221. )
  222. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  223. self.state_datastore._state_group_members_cache,
  224. group,
  225. state_filter=StateFilter(
  226. types=immutabledict({EventTypes.Member: frozenset()}),
  227. include_others=True,
  228. ),
  229. )
  230. self.assertEqual(is_all, True)
  231. self.assertDictEqual({}, state_dict)
  232. # test _get_state_for_group_using_cache correctly filters in members
  233. # with wildcard types
  234. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  235. self.state_datastore._state_group_cache,
  236. group,
  237. state_filter=StateFilter(
  238. types=immutabledict({EventTypes.Member: None}), include_others=True
  239. ),
  240. )
  241. self.assertEqual(is_all, True)
  242. self.assertDictEqual(
  243. {
  244. (e1.type, e1.state_key): e1.event_id,
  245. (e2.type, e2.state_key): e2.event_id,
  246. },
  247. state_dict,
  248. )
  249. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  250. self.state_datastore._state_group_members_cache,
  251. group,
  252. state_filter=StateFilter(
  253. types=immutabledict({EventTypes.Member: None}), include_others=True
  254. ),
  255. )
  256. self.assertEqual(is_all, True)
  257. self.assertDictEqual(
  258. {
  259. (e3.type, e3.state_key): e3.event_id,
  260. # e4 is overwritten by e5
  261. (e5.type, e5.state_key): e5.event_id,
  262. },
  263. state_dict,
  264. )
  265. # test _get_state_for_group_using_cache correctly filters in members
  266. # with specific types
  267. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  268. self.state_datastore._state_group_cache,
  269. group,
  270. state_filter=StateFilter(
  271. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  272. include_others=True,
  273. ),
  274. )
  275. self.assertEqual(is_all, True)
  276. self.assertDictEqual(
  277. {
  278. (e1.type, e1.state_key): e1.event_id,
  279. (e2.type, e2.state_key): e2.event_id,
  280. },
  281. state_dict,
  282. )
  283. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  284. self.state_datastore._state_group_members_cache,
  285. group,
  286. state_filter=StateFilter(
  287. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  288. include_others=True,
  289. ),
  290. )
  291. self.assertEqual(is_all, True)
  292. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  293. # test _get_state_for_group_using_cache correctly filters in members
  294. # with specific types
  295. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  296. self.state_datastore._state_group_members_cache,
  297. group,
  298. state_filter=StateFilter(
  299. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  300. include_others=False,
  301. ),
  302. )
  303. self.assertEqual(is_all, True)
  304. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  305. #######################################################
  306. # deliberately remove e2 (room name) from the _state_group_cache
  307. cache_entry = self.state_datastore._state_group_cache.get(group)
  308. state_dict_ids = cache_entry.value
  309. self.assertEqual(cache_entry.full, True)
  310. self.assertEqual(cache_entry.known_absent, set())
  311. self.assertDictEqual(
  312. state_dict_ids,
  313. {
  314. (e1.type, e1.state_key): e1.event_id,
  315. (e2.type, e2.state_key): e2.event_id,
  316. },
  317. )
  318. state_dict_ids.pop((e2.type, e2.state_key))
  319. self.state_datastore._state_group_cache.invalidate(group)
  320. self.state_datastore._state_group_cache.update(
  321. sequence=self.state_datastore._state_group_cache.sequence,
  322. key=group,
  323. value=state_dict_ids,
  324. # list fetched keys so it knows it's partial
  325. fetched_keys=((e1.type, e1.state_key),),
  326. )
  327. cache_entry = self.state_datastore._state_group_cache.get(group)
  328. state_dict_ids = cache_entry.value
  329. self.assertEqual(cache_entry.full, False)
  330. self.assertEqual(cache_entry.known_absent, set())
  331. self.assertDictEqual(state_dict_ids, {})
  332. ############################################
  333. # test that things work with a partial cache
  334. # test _get_state_for_group_using_cache correctly filters out members
  335. # with types=[]
  336. room_id = self.room.to_string()
  337. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  338. self.state_datastore._state_group_cache,
  339. group,
  340. state_filter=StateFilter(
  341. types=immutabledict({EventTypes.Member: frozenset()}),
  342. include_others=True,
  343. ),
  344. )
  345. self.assertEqual(is_all, False)
  346. self.assertDictEqual({}, state_dict)
  347. room_id = self.room.to_string()
  348. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  349. self.state_datastore._state_group_members_cache,
  350. group,
  351. state_filter=StateFilter(
  352. types=immutabledict({EventTypes.Member: frozenset()}),
  353. include_others=True,
  354. ),
  355. )
  356. self.assertEqual(is_all, True)
  357. self.assertDictEqual({}, state_dict)
  358. # test _get_state_for_group_using_cache correctly filters in members
  359. # wildcard types
  360. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  361. self.state_datastore._state_group_cache,
  362. group,
  363. state_filter=StateFilter(
  364. types=immutabledict({EventTypes.Member: None}), include_others=True
  365. ),
  366. )
  367. self.assertEqual(is_all, False)
  368. self.assertDictEqual({}, state_dict)
  369. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  370. self.state_datastore._state_group_members_cache,
  371. group,
  372. state_filter=StateFilter(
  373. types=immutabledict({EventTypes.Member: None}), include_others=True
  374. ),
  375. )
  376. self.assertEqual(is_all, True)
  377. self.assertDictEqual(
  378. {
  379. (e3.type, e3.state_key): e3.event_id,
  380. (e5.type, e5.state_key): e5.event_id,
  381. },
  382. state_dict,
  383. )
  384. # test _get_state_for_group_using_cache correctly filters in members
  385. # with specific types
  386. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  387. self.state_datastore._state_group_cache,
  388. group,
  389. state_filter=StateFilter(
  390. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  391. include_others=True,
  392. ),
  393. )
  394. self.assertEqual(is_all, False)
  395. self.assertDictEqual({}, state_dict)
  396. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  397. self.state_datastore._state_group_members_cache,
  398. group,
  399. state_filter=StateFilter(
  400. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  401. include_others=True,
  402. ),
  403. )
  404. self.assertEqual(is_all, True)
  405. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  406. # test _get_state_for_group_using_cache correctly filters in members
  407. # with specific types
  408. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  409. self.state_datastore._state_group_cache,
  410. group,
  411. state_filter=StateFilter(
  412. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  413. include_others=False,
  414. ),
  415. )
  416. self.assertEqual(is_all, False)
  417. self.assertDictEqual({}, state_dict)
  418. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  419. self.state_datastore._state_group_members_cache,
  420. group,
  421. state_filter=StateFilter(
  422. types=immutabledict({EventTypes.Member: frozenset({e5.state_key})}),
  423. include_others=False,
  424. ),
  425. )
  426. self.assertEqual(is_all, True)
  427. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  428. def test_batched_state_group_storing(self) -> None:
  429. creation_event = self.inject_state_event(
  430. self.room, self.u_alice, EventTypes.Create, "", {}
  431. )
  432. state_to_event = self.get_success(
  433. self.storage.state.get_state_groups(
  434. self.room.to_string(), [creation_event.event_id]
  435. )
  436. )
  437. current_state_group = list(state_to_event.keys())[0]
  438. # create some unpersisted events and event contexts to store against room
  439. events_and_context = []
  440. builder = self.event_builder_factory.for_room_version(
  441. RoomVersions.V1,
  442. {
  443. "type": EventTypes.Name,
  444. "sender": self.u_alice.to_string(),
  445. "state_key": "",
  446. "room_id": self.room.to_string(),
  447. "content": {"name": "first rename of room"},
  448. },
  449. )
  450. event1, unpersisted_context1 = self.get_success(
  451. self.event_creation_handler.create_new_client_event(builder)
  452. )
  453. events_and_context.append((event1, unpersisted_context1))
  454. builder2 = self.event_builder_factory.for_room_version(
  455. RoomVersions.V1,
  456. {
  457. "type": EventTypes.JoinRules,
  458. "sender": self.u_alice.to_string(),
  459. "state_key": "",
  460. "room_id": self.room.to_string(),
  461. "content": {"join_rule": "private"},
  462. },
  463. )
  464. event2, unpersisted_context2 = self.get_success(
  465. self.event_creation_handler.create_new_client_event(builder2)
  466. )
  467. events_and_context.append((event2, unpersisted_context2))
  468. builder3 = self.event_builder_factory.for_room_version(
  469. RoomVersions.V1,
  470. {
  471. "type": EventTypes.Message,
  472. "sender": self.u_alice.to_string(),
  473. "room_id": self.room.to_string(),
  474. "content": {"body": "hello from event 3", "msgtype": "m.text"},
  475. },
  476. )
  477. event3, unpersisted_context3 = self.get_success(
  478. self.event_creation_handler.create_new_client_event(builder3)
  479. )
  480. events_and_context.append((event3, unpersisted_context3))
  481. builder4 = self.event_builder_factory.for_room_version(
  482. RoomVersions.V1,
  483. {
  484. "type": EventTypes.JoinRules,
  485. "sender": self.u_alice.to_string(),
  486. "state_key": "",
  487. "room_id": self.room.to_string(),
  488. "content": {"join_rule": "public"},
  489. },
  490. )
  491. event4, unpersisted_context4 = self.get_success(
  492. self.event_creation_handler.create_new_client_event(builder4)
  493. )
  494. events_and_context.append((event4, unpersisted_context4))
  495. processed_events_and_context = self.get_success(
  496. self.hs.get_datastores().state.store_state_deltas_for_batched(
  497. events_and_context, self.room.to_string(), current_state_group
  498. )
  499. )
  500. # check that only state events are in state_groups, and all state events are in state_groups
  501. res = self.get_success(
  502. self.store.db_pool.simple_select_list(
  503. table="state_groups",
  504. keyvalues=None,
  505. retcols=("event_id",),
  506. )
  507. )
  508. events = []
  509. for result in res:
  510. self.assertNotIn(event3.event_id, result)
  511. events.append(result.get("event_id"))
  512. for event, _ in processed_events_and_context:
  513. if event.is_state():
  514. self.assertIn(event.event_id, events)
  515. # check that each unique state has state group in state_groups_state and that the
  516. # type/state key is correct, and check that each state event's state group
  517. # has an entry and prev event in state_group_edges
  518. for event, context in processed_events_and_context:
  519. if event.is_state():
  520. state = self.get_success(
  521. self.store.db_pool.simple_select_list(
  522. table="state_groups_state",
  523. keyvalues={"state_group": context.state_group_after_event},
  524. retcols=("type", "state_key"),
  525. )
  526. )
  527. self.assertEqual(event.type, state[0].get("type"))
  528. self.assertEqual(event.state_key, state[0].get("state_key"))
  529. groups = self.get_success(
  530. self.store.db_pool.simple_select_list(
  531. table="state_group_edges",
  532. keyvalues={"state_group": str(context.state_group_after_event)},
  533. retcols=("*",),
  534. )
  535. )
  536. self.assertEqual(
  537. context.state_group_before_event, groups[0].get("prev_state_group")
  538. )