You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

713 lines
24 KiB

  1. # Copyright 2020 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Dict, List, Optional, Set, Union, cast
  15. import frozendict
  16. from twisted.test.proto_helpers import MemoryReactor
  17. import synapse.rest.admin
  18. from synapse.api.constants import EventTypes, HistoryVisibility, Membership
  19. from synapse.api.room_versions import RoomVersions
  20. from synapse.appservice import ApplicationService
  21. from synapse.events import FrozenEvent
  22. from synapse.push.bulk_push_rule_evaluator import _flatten_dict
  23. from synapse.push.httppusher import tweaks_for_actions
  24. from synapse.rest import admin
  25. from synapse.rest.client import login, register, room
  26. from synapse.server import HomeServer
  27. from synapse.storage.databases.main.appservice import _make_exclusive_regex
  28. from synapse.synapse_rust.push import PushRuleEvaluator
  29. from synapse.types import JsonDict, JsonMapping, UserID
  30. from synapse.util import Clock
  31. from tests import unittest
  32. from tests.test_utils.event_injection import create_event, inject_member_event
  class PushRuleEvaluatorTestCase(unittest.TestCase):
      """Tests for individual push rule conditions.

      Each test builds a ``PushRuleEvaluator`` around a fake event (see
      ``_get_evaluator``) and checks the result of ``matches`` for one condition.
      """

      def _get_evaluator(
          self,
          content: JsonMapping,
          *,
          user_mentions: Optional[Set[str]] = None,
          room_mention: bool = False,
          related_events: Optional[JsonDict] = None,
      ) -> PushRuleEvaluator:
          """Build a ``PushRuleEvaluator`` for a fake event with the given content.

          The fake event is an ``m.room.history_visibility`` state event sent by
          ``@user:test`` and created with room version 1. The room member count
          and sender power level are fixed at 0, and no notification power levels
          are supplied.

          Args:
              content: The ``content`` of the fake event.
              user_mentions: User IDs handed to the evaluator as mentioned users;
                  an empty set is used when omitted.
              room_mention: Whether the evaluator is told the room was mentioned.
              related_events: Related events handed to the evaluator; the tests
                  in this class key them by relation type with flattened event
                  fields as values. An empty dict is used when omitted.

          Returns:
              The evaluator for the fake event.
          """
          event = FrozenEvent(
              {
                  "event_id": "$event_id",
                  "type": "m.room.history_visibility",
                  "sender": "@user:test",
                  "state_key": "",
                  "room_id": "#room:test",
                  "content": content,
              },
              RoomVersions.V1,
          )
          room_member_count = 0
          sender_power_level = 0
          power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
          return PushRuleEvaluator(
              _flatten_dict(event),
              user_mentions or set(),
              room_mention,
              room_member_count,
              sender_power_level,
              cast(Dict[str, int], power_levels.get("notifications", {})),
              {} if related_events is None else related_events,
              related_event_match_enabled=True,
              room_version_feature_flags=event.room_version.msc3931_push_features,
              msc3931_enabled=True,
          )

      def test_display_name(self) -> None:
          """Check for a matching display name in the body of the event."""
          evaluator = self._get_evaluator({"body": "foo bar baz"})

          condition = {"kind": "contains_display_name"}

          # Blank names are skipped.
          self.assertFalse(evaluator.matches(condition, "@user:test", ""))

          # Check a display name that doesn't match.
          self.assertFalse(evaluator.matches(condition, "@user:test", "not found"))

          # Check a display name which matches.
          self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))

          # A display name that matches, but not a full word does not result in a match.
          self.assertFalse(evaluator.matches(condition, "@user:test", "ba"))

          # A display name should not be interpreted as a regular expression.
          self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]"))

          # A display name with spaces should work fine.
          self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))

      def test_user_mentions(self) -> None:
          """Check for user mentions."""
          condition = {"kind": "org.matrix.msc3952.is_user_mention"}

          # No mentions shouldn't match.
          evaluator = self._get_evaluator({})
          self.assertFalse(evaluator.matches(condition, "@user:test", None))

          # An empty set shouldn't match
          evaluator = self._get_evaluator({}, user_mentions=set())
          self.assertFalse(evaluator.matches(condition, "@user:test", None))

          # The Matrix ID appearing anywhere in the mentions list should match
          evaluator = self._get_evaluator({}, user_mentions={"@user:test"})
          self.assertTrue(evaluator.matches(condition, "@user:test", None))

          evaluator = self._get_evaluator(
              {}, user_mentions={"@another:test", "@user:test"}
          )
          self.assertTrue(evaluator.matches(condition, "@user:test", None))

          # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
          # since the BulkPushRuleEvaluator is what handles data sanitisation.

      def test_room_mentions(self) -> None:
          """Check for room mentions."""
          condition = {"kind": "org.matrix.msc3952.is_room_mention"}

          # No room mention shouldn't match.
          evaluator = self._get_evaluator({})
          self.assertFalse(evaluator.matches(condition, None, None))

          # Room mention should match.
          evaluator = self._get_evaluator({}, room_mention=True)
          self.assertTrue(evaluator.matches(condition, None, None))

          # A room mention and user mention is valid.
          evaluator = self._get_evaluator(
              {}, user_mentions={"@another:test"}, room_mention=True
          )
          self.assertTrue(evaluator.matches(condition, None, None))

          # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
          # since the BulkPushRuleEvaluator is what handles data sanitisation.

      def _assert_matches(
          self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
      ) -> None:
          """Assert that ``condition`` matches a fake event with ``content``.

          The condition is evaluated for user ``@user:test`` with the display
          name ``display_name``; ``msg`` is used as the assertion failure message.
          """
          evaluator = self._get_evaluator(content)
          self.assertTrue(
              evaluator.matches(condition, "@user:test", "display_name"), msg
          )

      def _assert_not_matches(
          self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
      ) -> None:
          """Assert that ``condition`` does not match a fake event with ``content``.

          The condition is evaluated for user ``@user:test`` with the display
          name ``display_name``; ``msg`` is used as the assertion failure message.
          """
          evaluator = self._get_evaluator(content)
          self.assertFalse(
              evaluator.matches(condition, "@user:test", "display_name"), msg
          )

      def test_event_match_body(self) -> None:
          """Check that event_match conditions on content.body work as expected"""

          # if the key is `content.body`, the pattern matches substrings.

          # non-wildcards should match
          condition = {
              "kind": "event_match",
              "key": "content.body",
              "pattern": "foobaz",
          }
          self._assert_matches(
              condition,
              {"body": "aaa FoobaZ zzz"},
              "patterns should match and be case-insensitive",
          )
          self._assert_not_matches(
              condition,
              {"body": "aa xFoobaZ yy"},
              "pattern should only match at word boundaries",
          )
          self._assert_not_matches(
              condition,
              {"body": "aa foobazx yy"},
              "pattern should only match at word boundaries",
          )

          # wildcards should match
          condition = {
              "kind": "event_match",
              "key": "content.body",
              "pattern": "f?o*baz",
          }

          self._assert_matches(
              condition,
              {"body": "aaa FoobarbaZ zzz"},
              "* should match string and pattern should be case-insensitive",
          )
          self._assert_matches(
              condition, {"body": "aa foobaz yy"}, "* should match 0 characters"
          )
          self._assert_not_matches(
              condition, {"body": "aa fobbaz yy"}, "? should not match 0 characters"
          )
          self._assert_not_matches(
              condition, {"body": "aa fiiobaz yy"}, "? should not match 2 characters"
          )
          self._assert_not_matches(
              condition,
              {"body": "aa xfooxbaz yy"},
              "pattern should only match at word boundaries",
          )
          self._assert_not_matches(
              condition,
              {"body": "aa fooxbazx yy"},
              "pattern should only match at word boundaries",
          )

          # test backslashes
          condition = {
              "kind": "event_match",
              "key": "content.body",
              "pattern": r"f\oobaz",
          }
          self._assert_matches(
              condition,
              {"body": r"F\oobaz"},
              "backslash should match itself",
          )
          condition = {
              "kind": "event_match",
              "key": "content.body",
              "pattern": r"f\?obaz",
          }
          self._assert_matches(
              condition,
              {"body": r"F\oobaz"},
              r"? after \ should match any character",
          )

      def test_event_match_non_body(self) -> None:
          """Check that event_match conditions on other keys work as expected"""

          # if the key is anything other than 'content.body', the pattern must match the
          # whole value.

          # non-wildcards should match
          condition = {
              "kind": "event_match",
              "key": "content.value",
              "pattern": "foobaz",
          }
          self._assert_matches(
              condition,
              {"value": "FoobaZ"},
              "patterns should match and be case-insensitive",
          )
          self._assert_not_matches(
              condition,
              {"value": "xFoobaZ"},
              "pattern should only match at the start/end of the value",
          )
          self._assert_not_matches(
              condition,
              {"value": "FoobaZz"},
              "pattern should only match at the start/end of the value",
          )

          # it should work on frozendicts too
          self._assert_matches(
              condition,
              frozendict.frozendict({"value": "FoobaZ"}),
              "patterns should match on frozendicts",
          )

          # wildcards should match
          condition = {
              "kind": "event_match",
              "key": "content.value",
              "pattern": "f?o*baz",
          }
          self._assert_matches(
              condition,
              {"value": "FoobarbaZ"},
              "* should match string and pattern should be case-insensitive",
          )
          self._assert_matches(
              condition, {"value": "foobaz"}, "* should match 0 characters"
          )
          self._assert_not_matches(
              condition, {"value": "fobbaz"}, "? should not match 0 characters"
          )
          self._assert_not_matches(
              condition, {"value": "fiiobaz"}, "? should not match 2 characters"
          )
          self._assert_not_matches(
              condition,
              {"value": "xfooxbaz"},
              "pattern should only match at the start/end of the value",
          )
          self._assert_not_matches(
              condition,
              {"value": "fooxbazx"},
              "pattern should only match at the start/end of the value",
          )
          self._assert_not_matches(
              condition,
              {"value": "x\nfooxbaz"},
              "pattern should not match after a newline",
          )
          self._assert_not_matches(
              condition,
              {"value": "fooxbaz\nx"},
              "pattern should not match before a newline",
          )

      def test_no_body(self) -> None:
          """Not having a body shouldn't break the evaluator."""
          evaluator = self._get_evaluator({})

          condition = {
              "kind": "contains_display_name",
          }
          self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))

      def test_invalid_body(self) -> None:
          """A non-string body should not break the evaluator."""
          condition = {
              "kind": "contains_display_name",
          }

          for body in (1, True, {"foo": "bar"}):
              evaluator = self._get_evaluator({"body": body})
              self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))

      def test_tweaks_for_actions(self) -> None:
          """
          This tests the behaviour of tweaks_for_actions.
          """

          actions: List[Union[Dict[str, str], str]] = [
              {"set_tweak": "sound", "value": "default"},
              {"set_tweak": "highlight"},
              "notify",
          ]

          # A tweak without an explicit value is expected to come back as True,
          # and the plain "notify" action contributes no tweak.
          self.assertEqual(
              tweaks_for_actions(actions),
              {"sound": "default", "highlight": True},
          )

      def test_related_event_match(self) -> None:
          """Check related_event_match conditions on an event with related events.

          The fake event carries an annotation relation plus an in-reply-to
          relation, and related events are supplied for both relation types.
          """
          evaluator = self._get_evaluator(
              {
                  "m.relates_to": {
                      "event_id": "$parent_event_id",
                      "key": "😀",
                      "rel_type": "m.annotation",
                      "m.in_reply_to": {
                          "event_id": "$parent_event_id",
                      },
                  }
              },
              related_events={
                  "m.in_reply_to": {
                      "event_id": "$parent_event_id",
                      "type": "m.room.message",
                      "sender": "@other_user:test",
                      "room_id": "!room:test",
                      "content.msgtype": "m.text",
                      "content.body": "Original message",
                  },
                  "m.annotation": {
                      "event_id": "$parent_event_id",
                      "type": "m.room.message",
                      "sender": "@other_user:test",
                      "room_id": "!room:test",
                      "content.msgtype": "m.text",
                      "content.body": "Original message",
                  },
              },
          )
          # The pattern equals the sender of the replied-to event.
          self.assertTrue(
              evaluator.matches(
                  {
                      "kind": "im.nheko.msc3664.related_event_match",
                      "key": "sender",
                      "rel_type": "m.in_reply_to",
                      "pattern": "@other_user:test",
                  },
                  "@user:test",
                  "display_name",
              )
          )
          # The pattern differs from the sender of the replied-to event.
          self.assertFalse(
              evaluator.matches(
                  {
                      "kind": "im.nheko.msc3664.related_event_match",
                      "key": "sender",
                      "rel_type": "m.in_reply_to",
                      "pattern": "@user:test",
                  },
                  "@other_user:test",
                  "display_name",
              )
          )
          # The pattern equals the sender of the annotated event.
          self.assertTrue(
              evaluator.matches(
                  {
                      "kind": "im.nheko.msc3664.related_event_match",
                      "key": "sender",
                      "rel_type": "m.annotation",
                      "pattern": "@other_user:test",
                  },
                  "@other_user:test",
                  "display_name",
              )
          )
          # A key without a pattern is expected not to match.
          self.assertFalse(
              evaluator.matches(
                  {
                      "kind": "im.nheko.msc3664.related_event_match",
                      "key": "sender",
                      "rel_type": "m.in_reply_to",
                  },
                  "@user:test",
                  "display_name",
              )
          )
          # With neither key nor pattern, a supplied related event of that
          # relation type is expected to match.
          self.assertTrue(
              evaluator.matches(
                  {
                      "kind": "im.nheko.msc3664.related_event_match",
                      "rel_type": "m.in_reply_to",
                  },
                  "@user:test",
                  "display_name",
              )
          )
          # No related event was supplied for m.replace.
          self.assertFalse(
              evaluator.matches(
                  {
                      "kind": "im.nheko.msc3664.related_event_match",
                      "rel_type": "m.replace",
                  },
                  "@other_user:test",
                  "display_name",
              )
          )

      def test_related_event_match_with_fallback(self) -> None:
          """Check how ``include_fallbacks`` affects related_event_match.

          The supplied in-reply-to related event carries an
          ``im.vector.is_falling_back`` key.
          """
          evaluator = self._get_evaluator(
              {
                  "m.relates_to": {
                      "event_id": "$parent_event_id",
                      "key": "😀",
                      "rel_type": "m.thread",
                      "is_falling_back": True,
                      "m.in_reply_to": {
                          "event_id": "$parent_event_id",
                      },
                  }
              },
              related_events={
                  "m.in_reply_to": {
                      "event_id": "$parent_event_id",
                      "type": "m.room.message",
                      "sender": "@other_user:test",
                      "room_id": "!room:test",
                      "content.msgtype": "m.text",
                      "content.body": "Original message",
                      "im.vector.is_falling_back": "",
                  },
                  "m.thread": {
                      "event_id": "$parent_event_id",
                      "type": "m.room.message",
                      "sender": "@other_user:test",
                      "room_id": "!room:test",
                      "content.msgtype": "m.text",
                      "content.body": "Original message",
                  },
              },
          )
          # Fallbacks explicitly included: the fallback reply should match.
          self.assertTrue(
              evaluator.matches(
                  {
                      "kind": "im.nheko.msc3664.related_event_match",
                      "key": "sender",
                      "rel_type": "m.in_reply_to",
                      "pattern": "@other_user:test",
                      "include_fallbacks": True,
                  },
                  "@user:test",
                  "display_name",
              )
          )
          # Fallbacks explicitly excluded: the fallback reply should not match.
          self.assertFalse(
              evaluator.matches(
                  {
                      "kind": "im.nheko.msc3664.related_event_match",
                      "key": "sender",
                      "rel_type": "m.in_reply_to",
                      "pattern": "@other_user:test",
                      "include_fallbacks": False,
                  },
                  "@user:test",
                  "display_name",
              )
          )
          # ``include_fallbacks`` omitted: expected to behave like False.
          self.assertFalse(
              evaluator.matches(
                  {
                      "kind": "im.nheko.msc3664.related_event_match",
                      "key": "sender",
                      "rel_type": "m.in_reply_to",
                      "pattern": "@other_user:test",
                  },
                  "@user:test",
                  "display_name",
              )
          )

      def test_related_event_match_no_related_event(self) -> None:
          """related_event_match should never match when no related events exist."""
          evaluator = self._get_evaluator(
              {"msgtype": "m.text", "body": "Message without related event"}
          )
          # Key and pattern given.
          self.assertFalse(
              evaluator.matches(
                  {
                      "kind": "im.nheko.msc3664.related_event_match",
                      "key": "sender",
                      "rel_type": "m.in_reply_to",
                      "pattern": "@other_user:test",
                  },
                  "@user:test",
                  "display_name",
              )
          )
          # Key given without a pattern.
          self.assertFalse(
              evaluator.matches(
                  {
                      "kind": "im.nheko.msc3664.related_event_match",
                      "key": "sender",
                      "rel_type": "m.in_reply_to",
                  },
                  "@user:test",
                  "display_name",
              )
          )
          # Only the relation type given.
          self.assertFalse(
              evaluator.matches(
                  {
                      "kind": "im.nheko.msc3664.related_event_match",
                      "rel_type": "m.in_reply_to",
                  },
                  "@user:test",
                  "display_name",
              )
          )
  class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
      """Tests for the bulk push rule evaluator"""

      servlets = [
          synapse.rest.admin.register_servlets_for_client_rest_resource,
          login.register_servlets,
          register.register_servlets,
          room.register_servlets,
      ]

      def prepare(
          self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
      ) -> None:
          """Set up an application service, one of its users, and the evaluator.

          The service is installed directly into the main datastore's
          ``services_cache`` and ``exclusive_user_regex`` so that an appservice
          user can then be registered with the service token.
          """
          # Define an application service so that we can register appservice users
          self._service_token = "some_token"
          self._service = ApplicationService(
              self._service_token,
              "as1",
              "@as.sender:test",
              namespaces={
                  "users": [
                      {"regex": "@_as_.*:test", "exclusive": True},
                      {"regex": "@as.sender:test", "exclusive": True},
                  ]
              },
              msc3202_transaction_extensions=True,
          )
          self.hs.get_datastores().main.services_cache = [self._service]
          self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
              [self._service]
          )

          self._as_user, _ = self.register_appservice_user(
              "_as_user", self._service_token
          )

          self.evaluator = self.hs.get_bulk_push_rule_evaluator()

      def test_ignore_appservice_users(self) -> None:
          """Test that we don't generate push for appservice users"""

          user_id = self.register_user("user", "pass")
          token = self.login("user", "pass")

          room_id = self.helper.create_room_as(user_id, tok=token)
          self.get_success(
              inject_member_event(self.hs, room_id, self._as_user, Membership.JOIN)
          )

          # Only the event is needed; the second element returned alongside it
          # is not used by this test.
          event, _ = self.get_success(
              create_event(
                  self.hs,
                  type=EventTypes.Message,
                  room_id=room_id,
                  sender=user_id,
                  content={"body": "test", "msgtype": "m.text"},
              )
          )

          # Assert the returned push rules do not contain the app service user
          rules = self.get_success(self.evaluator._get_rules_for_event(event))
          self.assertNotIn(self._as_user, rules)

          # Assert that no push actions have been added to the staging table (the
          # sender should not be pushed for the event)
          users_with_push_actions = self.get_success(
              self.hs.get_datastores().main.db_pool.simple_select_onecol(
                  table="event_push_actions_staging",
                  keyvalues={"event_id": event.event_id},
                  retcol="user_id",
                  desc="test_ignore_appservice_users",
              )
          )

          self.assertEqual(len(users_with_push_actions), 0)
  class BulkPushRuleEvaluatorTestCase(unittest.HomeserverTestCase):
      """Checks notification counts produced when events are sent into a room.

      ``prepare`` creates two users and a room owned by the first user whose
      history visibility is set to "joined".
      """

      servlets = [
          admin.register_servlets,
          login.register_servlets,
          room.register_servlets,
      ]

      def prepare(
          self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
      ) -> None:
          """Register and log in two users, then create the room under test."""
          self.main_store = homeserver.get_datastores().main

          self.user_id1 = self.register_user("user1", "password")
          self.tok1 = self.login(self.user_id1, "password")

          self.user_id2 = self.register_user("user2", "password")
          self.tok2 = self.login(self.user_id2, "password")

          self.room_id = self.helper.create_room_as(tok=self.tok1)

          # We want to test history visibility works correctly.
          visibility_content = {"history_visibility": HistoryVisibility.JOINED}
          self.helper.send_state(
              self.room_id,
              EventTypes.RoomHistoryVisibility,
              visibility_content,
              tok=self.tok1,
          )

      def get_notif_count(self, user_id: str) -> int:
          """Return the summed ``notif`` column of ``event_push_actions`` for a user.

          The query coalesces the sum to 0, so a user without rows yields 0.
          """
          pending_lookup = self.main_store.db_pool.simple_select_one_onecol(
              table="event_push_actions",
              keyvalues={"user_id": user_id},
              retcol="COALESCE(SUM(notif), 0)",
              desc="get_staging_notif_count",
          )
          return self.get_success(pending_lookup)

      def test_plain_message(self) -> None:
          """Test that sending a normal message in a room will trigger a
          notification
          """
          recipient = self.user_id2

          # Have user2 join the room.
          self.helper.join(self.room_id, recipient, tok=self.tok2)

          # They start off with no notifications, but get them when messages are
          # sent.
          self.assertEqual(self.get_notif_count(recipient), 0)

          self.create_and_send_event(
              self.room_id, UserID.from_string(self.user_id1)
          )

          self.assertEqual(self.get_notif_count(recipient), 1)

      def test_delayed_message(self) -> None:
          """Test that a delayed message that was from before a user joined
          doesn't cause a notification for the joined user.
          """
          sender = UserID.from_string(self.user_id1)
          recipient = self.user_id2

          # Send a message before user2 joins
          pre_join_event_id = self.create_and_send_event(self.room_id, sender)

          # Have user2 join the room
          self.helper.join(self.room_id, recipient, tok=self.tok2)

          # They start off with no notifications
          self.assertEqual(self.get_notif_count(recipient), 0)

          # Send another message that references the event before the join to
          # simulate a "delayed" event
          self.create_and_send_event(
              self.room_id, sender, prev_event_ids=[pre_join_event_id]
          )

          # user2 should not be notified about it, because they can't see it.
          self.assertEqual(self.get_notif_count(recipient), 0)