You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

440 lines
15 KiB

  1. # Copyright 2020 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Dict, Optional, Set, Tuple, Union
  15. import frozendict
  16. from twisted.test.proto_helpers import MemoryReactor
  17. import synapse.rest.admin
  18. from synapse.api.constants import EventTypes, Membership
  19. from synapse.api.room_versions import RoomVersions
  20. from synapse.appservice import ApplicationService
  21. from synapse.events import FrozenEvent
  22. from synapse.push.bulk_push_rule_evaluator import _flatten_dict
  23. from synapse.push.httppusher import tweaks_for_actions
  24. from synapse.rest.client import login, register, room
  25. from synapse.server import HomeServer
  26. from synapse.storage.databases.main.appservice import _make_exclusive_regex
  27. from synapse.synapse_rust.push import PushRuleEvaluator
  28. from synapse.types import JsonDict
  29. from synapse.util import Clock
  30. from tests import unittest
  31. from tests.test_utils.event_injection import create_event, inject_member_event
  32. class PushRuleEvaluatorTestCase(unittest.TestCase):
  33. def _get_evaluator(
  34. self,
  35. content: JsonDict,
  36. relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None,
  37. relations_match_enabled: bool = False,
  38. ) -> PushRuleEvaluator:
  39. event = FrozenEvent(
  40. {
  41. "event_id": "$event_id",
  42. "type": "m.room.history_visibility",
  43. "sender": "@user:test",
  44. "state_key": "",
  45. "room_id": "#room:test",
  46. "content": content,
  47. },
  48. RoomVersions.V1,
  49. )
  50. room_member_count = 0
  51. sender_power_level = 0
  52. power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
  53. return PushRuleEvaluator(
  54. _flatten_dict(event),
  55. room_member_count,
  56. sender_power_level,
  57. power_levels.get("notifications", {}),
  58. relations or {},
  59. relations_match_enabled,
  60. )
  61. def test_display_name(self) -> None:
  62. """Check for a matching display name in the body of the event."""
  63. evaluator = self._get_evaluator({"body": "foo bar baz"})
  64. condition = {
  65. "kind": "contains_display_name",
  66. }
  67. # Blank names are skipped.
  68. self.assertFalse(evaluator.matches(condition, "@user:test", ""))
  69. # Check a display name that doesn't match.
  70. self.assertFalse(evaluator.matches(condition, "@user:test", "not found"))
  71. # Check a display name which matches.
  72. self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
  73. # A display name that matches, but not a full word does not result in a match.
  74. self.assertFalse(evaluator.matches(condition, "@user:test", "ba"))
  75. # A display name should not be interpreted as a regular expression.
  76. self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]"))
  77. # A display name with spaces should work fine.
  78. self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
  79. def _assert_matches(
  80. self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
  81. ) -> None:
  82. evaluator = self._get_evaluator(content)
  83. self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
  84. def _assert_not_matches(
  85. self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
  86. ) -> None:
  87. evaluator = self._get_evaluator(content)
  88. self.assertFalse(
  89. evaluator.matches(condition, "@user:test", "display_name"), msg
  90. )
  91. def test_event_match_body(self) -> None:
  92. """Check that event_match conditions on content.body work as expected"""
  93. # if the key is `content.body`, the pattern matches substrings.
  94. # non-wildcards should match
  95. condition = {
  96. "kind": "event_match",
  97. "key": "content.body",
  98. "pattern": "foobaz",
  99. }
  100. self._assert_matches(
  101. condition,
  102. {"body": "aaa FoobaZ zzz"},
  103. "patterns should match and be case-insensitive",
  104. )
  105. self._assert_not_matches(
  106. condition,
  107. {"body": "aa xFoobaZ yy"},
  108. "pattern should only match at word boundaries",
  109. )
  110. self._assert_not_matches(
  111. condition,
  112. {"body": "aa foobazx yy"},
  113. "pattern should only match at word boundaries",
  114. )
  115. # wildcards should match
  116. condition = {
  117. "kind": "event_match",
  118. "key": "content.body",
  119. "pattern": "f?o*baz",
  120. }
  121. self._assert_matches(
  122. condition,
  123. {"body": "aaa FoobarbaZ zzz"},
  124. "* should match string and pattern should be case-insensitive",
  125. )
  126. self._assert_matches(
  127. condition, {"body": "aa foobaz yy"}, "* should match 0 characters"
  128. )
  129. self._assert_not_matches(
  130. condition, {"body": "aa fobbaz yy"}, "? should not match 0 characters"
  131. )
  132. self._assert_not_matches(
  133. condition, {"body": "aa fiiobaz yy"}, "? should not match 2 characters"
  134. )
  135. self._assert_not_matches(
  136. condition,
  137. {"body": "aa xfooxbaz yy"},
  138. "pattern should only match at word boundaries",
  139. )
  140. self._assert_not_matches(
  141. condition,
  142. {"body": "aa fooxbazx yy"},
  143. "pattern should only match at word boundaries",
  144. )
  145. # test backslashes
  146. condition = {
  147. "kind": "event_match",
  148. "key": "content.body",
  149. "pattern": r"f\oobaz",
  150. }
  151. self._assert_matches(
  152. condition,
  153. {"body": r"F\oobaz"},
  154. "backslash should match itself",
  155. )
  156. condition = {
  157. "kind": "event_match",
  158. "key": "content.body",
  159. "pattern": r"f\?obaz",
  160. }
  161. self._assert_matches(
  162. condition,
  163. {"body": r"F\oobaz"},
  164. r"? after \ should match any character",
  165. )
  166. def test_event_match_non_body(self) -> None:
  167. """Check that event_match conditions on other keys work as expected"""
  168. # if the key is anything other than 'content.body', the pattern must match the
  169. # whole value.
  170. # non-wildcards should match
  171. condition = {
  172. "kind": "event_match",
  173. "key": "content.value",
  174. "pattern": "foobaz",
  175. }
  176. self._assert_matches(
  177. condition,
  178. {"value": "FoobaZ"},
  179. "patterns should match and be case-insensitive",
  180. )
  181. self._assert_not_matches(
  182. condition,
  183. {"value": "xFoobaZ"},
  184. "pattern should only match at the start/end of the value",
  185. )
  186. self._assert_not_matches(
  187. condition,
  188. {"value": "FoobaZz"},
  189. "pattern should only match at the start/end of the value",
  190. )
  191. # it should work on frozendicts too
  192. self._assert_matches(
  193. condition,
  194. frozendict.frozendict({"value": "FoobaZ"}),
  195. "patterns should match on frozendicts",
  196. )
  197. # wildcards should match
  198. condition = {
  199. "kind": "event_match",
  200. "key": "content.value",
  201. "pattern": "f?o*baz",
  202. }
  203. self._assert_matches(
  204. condition,
  205. {"value": "FoobarbaZ"},
  206. "* should match string and pattern should be case-insensitive",
  207. )
  208. self._assert_matches(
  209. condition, {"value": "foobaz"}, "* should match 0 characters"
  210. )
  211. self._assert_not_matches(
  212. condition, {"value": "fobbaz"}, "? should not match 0 characters"
  213. )
  214. self._assert_not_matches(
  215. condition, {"value": "fiiobaz"}, "? should not match 2 characters"
  216. )
  217. self._assert_not_matches(
  218. condition,
  219. {"value": "xfooxbaz"},
  220. "pattern should only match at the start/end of the value",
  221. )
  222. self._assert_not_matches(
  223. condition,
  224. {"value": "fooxbazx"},
  225. "pattern should only match at the start/end of the value",
  226. )
  227. self._assert_not_matches(
  228. condition,
  229. {"value": "x\nfooxbaz"},
  230. "pattern should not match after a newline",
  231. )
  232. self._assert_not_matches(
  233. condition,
  234. {"value": "fooxbaz\nx"},
  235. "pattern should not match before a newline",
  236. )
  237. def test_no_body(self) -> None:
  238. """Not having a body shouldn't break the evaluator."""
  239. evaluator = self._get_evaluator({})
  240. condition = {
  241. "kind": "contains_display_name",
  242. }
  243. self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
  244. def test_invalid_body(self) -> None:
  245. """A non-string body should not break the evaluator."""
  246. condition = {
  247. "kind": "contains_display_name",
  248. }
  249. for body in (1, True, {"foo": "bar"}):
  250. evaluator = self._get_evaluator({"body": body})
  251. self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
  252. def test_tweaks_for_actions(self) -> None:
  253. """
  254. This tests the behaviour of tweaks_for_actions.
  255. """
  256. actions = [
  257. {"set_tweak": "sound", "value": "default"},
  258. {"set_tweak": "highlight"},
  259. "notify",
  260. ]
  261. self.assertEqual(
  262. tweaks_for_actions(actions),
  263. {"sound": "default", "highlight": True},
  264. )
  265. def test_relation_match(self) -> None:
  266. """Test the relation_match push rule kind."""
  267. # Check if the experimental feature is disabled.
  268. evaluator = self._get_evaluator(
  269. {}, {"m.annotation": {("@user:test", "m.reaction")}}
  270. )
  271. # A push rule evaluator with the experimental rule enabled.
  272. evaluator = self._get_evaluator(
  273. {}, {"m.annotation": {("@user:test", "m.reaction")}}, True
  274. )
  275. # Check just relation type.
  276. condition = {
  277. "kind": "org.matrix.msc3772.relation_match",
  278. "rel_type": "m.annotation",
  279. }
  280. self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
  281. # Check relation type and sender.
  282. condition = {
  283. "kind": "org.matrix.msc3772.relation_match",
  284. "rel_type": "m.annotation",
  285. "sender": "@user:test",
  286. }
  287. self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
  288. condition = {
  289. "kind": "org.matrix.msc3772.relation_match",
  290. "rel_type": "m.annotation",
  291. "sender": "@other:test",
  292. }
  293. self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
  294. # Check relation type and event type.
  295. condition = {
  296. "kind": "org.matrix.msc3772.relation_match",
  297. "rel_type": "m.annotation",
  298. "type": "m.reaction",
  299. }
  300. self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
  301. # Check just sender, this fails since rel_type is required.
  302. condition = {
  303. "kind": "org.matrix.msc3772.relation_match",
  304. "sender": "@user:test",
  305. }
  306. self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
  307. # Check sender glob.
  308. condition = {
  309. "kind": "org.matrix.msc3772.relation_match",
  310. "rel_type": "m.annotation",
  311. "sender": "@*:test",
  312. }
  313. self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
  314. # Check event type glob.
  315. condition = {
  316. "kind": "org.matrix.msc3772.relation_match",
  317. "rel_type": "m.annotation",
  318. "event_type": "*.reaction",
  319. }
  320. self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
  321. class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
  322. """Tests for the bulk push rule evaluator"""
  323. servlets = [
  324. synapse.rest.admin.register_servlets_for_client_rest_resource,
  325. login.register_servlets,
  326. register.register_servlets,
  327. room.register_servlets,
  328. ]
  329. def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
  330. # Define an application service so that we can register appservice users
  331. self._service_token = "some_token"
  332. self._service = ApplicationService(
  333. self._service_token,
  334. "as1",
  335. "@as.sender:test",
  336. namespaces={
  337. "users": [
  338. {"regex": "@_as_.*:test", "exclusive": True},
  339. {"regex": "@as.sender:test", "exclusive": True},
  340. ]
  341. },
  342. msc3202_transaction_extensions=True,
  343. )
  344. self.hs.get_datastores().main.services_cache = [self._service]
  345. self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
  346. [self._service]
  347. )
  348. self._as_user, _ = self.register_appservice_user(
  349. "_as_user", self._service_token
  350. )
  351. self.evaluator = self.hs.get_bulk_push_rule_evaluator()
  352. def test_ignore_appservice_users(self) -> None:
  353. "Test that we don't generate push for appservice users"
  354. user_id = self.register_user("user", "pass")
  355. token = self.login("user", "pass")
  356. room_id = self.helper.create_room_as(user_id, tok=token)
  357. self.get_success(
  358. inject_member_event(self.hs, room_id, self._as_user, Membership.JOIN)
  359. )
  360. event, context = self.get_success(
  361. create_event(
  362. self.hs,
  363. type=EventTypes.Message,
  364. room_id=room_id,
  365. sender=user_id,
  366. content={"body": "test", "msgtype": "m.text"},
  367. )
  368. )
  369. # Assert the returned push rules do not contain the app service user
  370. rules = self.get_success(self.evaluator._get_rules_for_event(event))
  371. self.assertTrue(self._as_user not in rules)
  372. # Assert that no push actions have been added to the staging table (the
  373. # sender should not be pushed for the event)
  374. users_with_push_actions = self.get_success(
  375. self.hs.get_datastores().main.db_pool.simple_select_onecol(
  376. table="event_push_actions_staging",
  377. keyvalues={"event_id": event.event_id},
  378. retcol="user_id",
  379. desc="test_ignore_appservice_users",
  380. )
  381. )
  382. self.assertEqual(len(users_with_push_actions), 0)