You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

628 lines
21 KiB

  1. from immutabledict import immutabledict
  2. from synapse.api.constants import EventTypes
  3. from synapse.types.state import StateFilter
  4. from tests.unittest import TestCase
  5. class StateFilterDifferenceTestCase(TestCase):
  6. def assert_difference(
  7. self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
  8. ) -> None:
  9. self.assertEqual(
  10. minuend.approx_difference(subtrahend),
  11. expected,
  12. f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
  13. )
  14. def test_state_filter_difference_no_include_other_minus_no_include_other(
  15. self,
  16. ) -> None:
  17. """
  18. Tests the StateFilter.approx_difference method
  19. where, in a.approx_difference(b), both a and b do not have the
  20. include_others flag set.
  21. """
  22. # (wildcard on state keys) - (wildcard on state keys):
  23. self.assert_difference(
  24. StateFilter.freeze(
  25. {EventTypes.Member: None, EventTypes.Create: None},
  26. include_others=False,
  27. ),
  28. StateFilter.freeze(
  29. {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
  30. include_others=False,
  31. ),
  32. StateFilter.freeze({EventTypes.Create: None}, include_others=False),
  33. )
  34. # (wildcard on state keys) - (specific state keys)
  35. # This one is an over-approximation because we can't represent
  36. # 'all state keys except a few named examples'
  37. self.assert_difference(
  38. StateFilter.freeze({EventTypes.Member: None}, include_others=False),
  39. StateFilter.freeze(
  40. {EventTypes.Member: {"@wombat:spqr"}},
  41. include_others=False,
  42. ),
  43. StateFilter.freeze({EventTypes.Member: None}, include_others=False),
  44. )
  45. # (wildcard on state keys) - (no state keys)
  46. self.assert_difference(
  47. StateFilter.freeze(
  48. {EventTypes.Member: None},
  49. include_others=False,
  50. ),
  51. StateFilter.freeze(
  52. {
  53. EventTypes.Member: set(),
  54. },
  55. include_others=False,
  56. ),
  57. StateFilter.freeze(
  58. {EventTypes.Member: None},
  59. include_others=False,
  60. ),
  61. )
  62. # (specific state keys) - (wildcard on state keys):
  63. self.assert_difference(
  64. StateFilter.freeze(
  65. {
  66. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  67. EventTypes.CanonicalAlias: {""},
  68. },
  69. include_others=False,
  70. ),
  71. StateFilter.freeze(
  72. {EventTypes.Member: None},
  73. include_others=False,
  74. ),
  75. StateFilter.freeze(
  76. {EventTypes.CanonicalAlias: {""}},
  77. include_others=False,
  78. ),
  79. )
  80. # (specific state keys) - (specific state keys)
  81. self.assert_difference(
  82. StateFilter.freeze(
  83. {
  84. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  85. EventTypes.CanonicalAlias: {""},
  86. },
  87. include_others=False,
  88. ),
  89. StateFilter.freeze(
  90. {
  91. EventTypes.Member: {"@wombat:spqr"},
  92. },
  93. include_others=False,
  94. ),
  95. StateFilter.freeze(
  96. {
  97. EventTypes.Member: {"@spqr:spqr"},
  98. EventTypes.CanonicalAlias: {""},
  99. },
  100. include_others=False,
  101. ),
  102. )
  103. # (specific state keys) - (no state keys)
  104. self.assert_difference(
  105. StateFilter.freeze(
  106. {
  107. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  108. EventTypes.CanonicalAlias: {""},
  109. },
  110. include_others=False,
  111. ),
  112. StateFilter.freeze(
  113. {
  114. EventTypes.Member: set(),
  115. },
  116. include_others=False,
  117. ),
  118. StateFilter.freeze(
  119. {
  120. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  121. EventTypes.CanonicalAlias: {""},
  122. },
  123. include_others=False,
  124. ),
  125. )
  126. def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
  127. """
  128. Tests the StateFilter.approx_difference method
  129. where, in a.approx_difference(b), only a has the include_others flag set.
  130. """
  131. # (wildcard on state keys) - (wildcard on state keys):
  132. self.assert_difference(
  133. StateFilter.freeze(
  134. {EventTypes.Member: None, EventTypes.Create: None},
  135. include_others=True,
  136. ),
  137. StateFilter.freeze(
  138. {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
  139. include_others=False,
  140. ),
  141. StateFilter.freeze(
  142. {
  143. EventTypes.Create: None,
  144. EventTypes.Member: set(),
  145. EventTypes.CanonicalAlias: set(),
  146. },
  147. include_others=True,
  148. ),
  149. )
  150. # (wildcard on state keys) - (specific state keys)
  151. # This one is an over-approximation because we can't represent
  152. # 'all state keys except a few named examples'
  153. # This also shows that the resultant state filter is normalised.
  154. self.assert_difference(
  155. StateFilter.freeze({EventTypes.Member: None}, include_others=True),
  156. StateFilter.freeze(
  157. {
  158. EventTypes.Member: {"@wombat:spqr"},
  159. EventTypes.Create: {""},
  160. },
  161. include_others=False,
  162. ),
  163. StateFilter(types=immutabledict(), include_others=True),
  164. )
  165. # (wildcard on state keys) - (no state keys)
  166. self.assert_difference(
  167. StateFilter.freeze(
  168. {EventTypes.Member: None},
  169. include_others=True,
  170. ),
  171. StateFilter.freeze(
  172. {
  173. EventTypes.Member: set(),
  174. },
  175. include_others=False,
  176. ),
  177. StateFilter(
  178. types=immutabledict(),
  179. include_others=True,
  180. ),
  181. )
  182. # (specific state keys) - (wildcard on state keys):
  183. self.assert_difference(
  184. StateFilter.freeze(
  185. {
  186. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  187. EventTypes.CanonicalAlias: {""},
  188. },
  189. include_others=True,
  190. ),
  191. StateFilter.freeze(
  192. {EventTypes.Member: None},
  193. include_others=False,
  194. ),
  195. StateFilter.freeze(
  196. {
  197. EventTypes.CanonicalAlias: {""},
  198. EventTypes.Member: set(),
  199. },
  200. include_others=True,
  201. ),
  202. )
  203. # (specific state keys) - (specific state keys)
  204. self.assert_difference(
  205. StateFilter.freeze(
  206. {
  207. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  208. EventTypes.CanonicalAlias: {""},
  209. },
  210. include_others=True,
  211. ),
  212. StateFilter.freeze(
  213. {
  214. EventTypes.Member: {"@wombat:spqr"},
  215. },
  216. include_others=False,
  217. ),
  218. StateFilter.freeze(
  219. {
  220. EventTypes.Member: {"@spqr:spqr"},
  221. EventTypes.CanonicalAlias: {""},
  222. },
  223. include_others=True,
  224. ),
  225. )
  226. # (specific state keys) - (no state keys)
  227. self.assert_difference(
  228. StateFilter.freeze(
  229. {
  230. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  231. EventTypes.CanonicalAlias: {""},
  232. },
  233. include_others=True,
  234. ),
  235. StateFilter.freeze(
  236. {
  237. EventTypes.Member: set(),
  238. },
  239. include_others=False,
  240. ),
  241. StateFilter.freeze(
  242. {
  243. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  244. EventTypes.CanonicalAlias: {""},
  245. },
  246. include_others=True,
  247. ),
  248. )
  249. def test_state_filter_difference_include_other_minus_include_other(self) -> None:
  250. """
  251. Tests the StateFilter.approx_difference method
  252. where, in a.approx_difference(b), both a and b have the include_others
  253. flag set.
  254. """
  255. # (wildcard on state keys) - (wildcard on state keys):
  256. self.assert_difference(
  257. StateFilter.freeze(
  258. {EventTypes.Member: None, EventTypes.Create: None},
  259. include_others=True,
  260. ),
  261. StateFilter.freeze(
  262. {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
  263. include_others=True,
  264. ),
  265. StateFilter(types=immutabledict(), include_others=False),
  266. )
  267. # (wildcard on state keys) - (specific state keys)
  268. # This one is an over-approximation because we can't represent
  269. # 'all state keys except a few named examples'
  270. self.assert_difference(
  271. StateFilter.freeze({EventTypes.Member: None}, include_others=True),
  272. StateFilter.freeze(
  273. {
  274. EventTypes.Member: {"@wombat:spqr"},
  275. EventTypes.CanonicalAlias: {""},
  276. },
  277. include_others=True,
  278. ),
  279. StateFilter.freeze(
  280. {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
  281. include_others=False,
  282. ),
  283. )
  284. # (wildcard on state keys) - (no state keys)
  285. self.assert_difference(
  286. StateFilter.freeze(
  287. {EventTypes.Member: None},
  288. include_others=True,
  289. ),
  290. StateFilter.freeze(
  291. {
  292. EventTypes.Member: set(),
  293. },
  294. include_others=True,
  295. ),
  296. StateFilter.freeze(
  297. {EventTypes.Member: None},
  298. include_others=False,
  299. ),
  300. )
  301. # (specific state keys) - (wildcard on state keys):
  302. self.assert_difference(
  303. StateFilter.freeze(
  304. {
  305. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  306. EventTypes.CanonicalAlias: {""},
  307. },
  308. include_others=True,
  309. ),
  310. StateFilter.freeze(
  311. {EventTypes.Member: None},
  312. include_others=True,
  313. ),
  314. StateFilter(
  315. types=immutabledict(),
  316. include_others=False,
  317. ),
  318. )
  319. # (specific state keys) - (specific state keys)
  320. # This one is an over-approximation because we can't represent
  321. # 'all state keys except a few named examples'
  322. self.assert_difference(
  323. StateFilter.freeze(
  324. {
  325. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  326. EventTypes.CanonicalAlias: {""},
  327. EventTypes.Create: {""},
  328. },
  329. include_others=True,
  330. ),
  331. StateFilter.freeze(
  332. {
  333. EventTypes.Member: {"@wombat:spqr"},
  334. EventTypes.Create: set(),
  335. },
  336. include_others=True,
  337. ),
  338. StateFilter.freeze(
  339. {
  340. EventTypes.Member: {"@spqr:spqr"},
  341. EventTypes.Create: {""},
  342. },
  343. include_others=False,
  344. ),
  345. )
  346. # (specific state keys) - (no state keys)
  347. self.assert_difference(
  348. StateFilter.freeze(
  349. {
  350. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  351. EventTypes.CanonicalAlias: {""},
  352. },
  353. include_others=True,
  354. ),
  355. StateFilter.freeze(
  356. {
  357. EventTypes.Member: set(),
  358. },
  359. include_others=True,
  360. ),
  361. StateFilter.freeze(
  362. {
  363. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  364. },
  365. include_others=False,
  366. ),
  367. )
  368. def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
  369. """
  370. Tests the StateFilter.approx_difference method
  371. where, in a.approx_difference(b), only b has the include_others flag set.
  372. """
  373. # (wildcard on state keys) - (wildcard on state keys):
  374. self.assert_difference(
  375. StateFilter.freeze(
  376. {EventTypes.Member: None, EventTypes.Create: None},
  377. include_others=False,
  378. ),
  379. StateFilter.freeze(
  380. {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
  381. include_others=True,
  382. ),
  383. StateFilter(types=immutabledict(), include_others=False),
  384. )
  385. # (wildcard on state keys) - (specific state keys)
  386. # This one is an over-approximation because we can't represent
  387. # 'all state keys except a few named examples'
  388. self.assert_difference(
  389. StateFilter.freeze({EventTypes.Member: None}, include_others=False),
  390. StateFilter.freeze(
  391. {EventTypes.Member: {"@wombat:spqr"}},
  392. include_others=True,
  393. ),
  394. StateFilter.freeze({EventTypes.Member: None}, include_others=False),
  395. )
  396. # (wildcard on state keys) - (no state keys)
  397. self.assert_difference(
  398. StateFilter.freeze(
  399. {EventTypes.Member: None},
  400. include_others=False,
  401. ),
  402. StateFilter.freeze(
  403. {
  404. EventTypes.Member: set(),
  405. },
  406. include_others=True,
  407. ),
  408. StateFilter.freeze(
  409. {EventTypes.Member: None},
  410. include_others=False,
  411. ),
  412. )
  413. # (specific state keys) - (wildcard on state keys):
  414. self.assert_difference(
  415. StateFilter.freeze(
  416. {
  417. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  418. EventTypes.CanonicalAlias: {""},
  419. },
  420. include_others=False,
  421. ),
  422. StateFilter.freeze(
  423. {EventTypes.Member: None},
  424. include_others=True,
  425. ),
  426. StateFilter(
  427. types=immutabledict(),
  428. include_others=False,
  429. ),
  430. )
  431. # (specific state keys) - (specific state keys)
  432. # This one is an over-approximation because we can't represent
  433. # 'all state keys except a few named examples'
  434. self.assert_difference(
  435. StateFilter.freeze(
  436. {
  437. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  438. EventTypes.CanonicalAlias: {""},
  439. },
  440. include_others=False,
  441. ),
  442. StateFilter.freeze(
  443. {
  444. EventTypes.Member: {"@wombat:spqr"},
  445. },
  446. include_others=True,
  447. ),
  448. StateFilter.freeze(
  449. {
  450. EventTypes.Member: {"@spqr:spqr"},
  451. },
  452. include_others=False,
  453. ),
  454. )
  455. # (specific state keys) - (no state keys)
  456. self.assert_difference(
  457. StateFilter.freeze(
  458. {
  459. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  460. EventTypes.CanonicalAlias: {""},
  461. },
  462. include_others=False,
  463. ),
  464. StateFilter.freeze(
  465. {
  466. EventTypes.Member: set(),
  467. },
  468. include_others=True,
  469. ),
  470. StateFilter.freeze(
  471. {
  472. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  473. },
  474. include_others=False,
  475. ),
  476. )
  477. def test_state_filter_difference_simple_cases(self) -> None:
  478. """
  479. Tests some very simple cases of the StateFilter approx_difference,
  480. that are not explicitly tested by the more in-depth tests.
  481. """
  482. self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
  483. self.assert_difference(
  484. StateFilter.all(),
  485. StateFilter.none(),
  486. StateFilter.all(),
  487. )
  488. class StateFilterTestCase(TestCase):
  489. def test_return_expanded(self) -> None:
  490. """
  491. Tests the behaviour of the return_expanded() function that expands
  492. StateFilters to include more state types (for the sake of cache hit rate).
  493. """
  494. self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
  495. self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
  496. # Concrete-only state filters stay the same
  497. # (Case: mixed filter)
  498. self.assertEqual(
  499. StateFilter.freeze(
  500. {
  501. EventTypes.Member: {"@wombat:test", "@alicia:test"},
  502. "some.other.state.type": {""},
  503. },
  504. include_others=False,
  505. ).return_expanded(),
  506. StateFilter.freeze(
  507. {
  508. EventTypes.Member: {"@wombat:test", "@alicia:test"},
  509. "some.other.state.type": {""},
  510. },
  511. include_others=False,
  512. ),
  513. )
  514. # Concrete-only state filters stay the same
  515. # (Case: non-member-only filter)
  516. self.assertEqual(
  517. StateFilter.freeze(
  518. {"some.other.state.type": {""}}, include_others=False
  519. ).return_expanded(),
  520. StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
  521. )
  522. # Concrete-only state filters stay the same
  523. # (Case: member-only filter)
  524. self.assertEqual(
  525. StateFilter.freeze(
  526. {
  527. EventTypes.Member: {"@wombat:test", "@alicia:test"},
  528. },
  529. include_others=False,
  530. ).return_expanded(),
  531. StateFilter.freeze(
  532. {
  533. EventTypes.Member: {"@wombat:test", "@alicia:test"},
  534. },
  535. include_others=False,
  536. ),
  537. )
  538. # Wildcard member-only state filters stay the same
  539. self.assertEqual(
  540. StateFilter.freeze(
  541. {EventTypes.Member: None},
  542. include_others=False,
  543. ).return_expanded(),
  544. StateFilter.freeze(
  545. {EventTypes.Member: None},
  546. include_others=False,
  547. ),
  548. )
  549. # If there is a wildcard in the non-member portion of the filter,
  550. # it's expanded to include ALL non-member events.
  551. # (Case: mixed filter)
  552. self.assertEqual(
  553. StateFilter.freeze(
  554. {
  555. EventTypes.Member: {"@wombat:test", "@alicia:test"},
  556. "some.other.state.type": None,
  557. },
  558. include_others=False,
  559. ).return_expanded(),
  560. StateFilter.freeze(
  561. {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
  562. include_others=True,
  563. ),
  564. )
  565. # If there is a wildcard in the non-member portion of the filter,
  566. # it's expanded to include ALL non-member events.
  567. # (Case: non-member-only filter)
  568. self.assertEqual(
  569. StateFilter.freeze(
  570. {
  571. "some.other.state.type": None,
  572. },
  573. include_others=False,
  574. ).return_expanded(),
  575. StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
  576. )
  577. self.assertEqual(
  578. StateFilter.freeze(
  579. {
  580. "some.other.state.type": None,
  581. "yet.another.state.type": {"wombat"},
  582. },
  583. include_others=False,
  584. ).return_expanded(),
  585. StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
  586. )