You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

588 lines
19 KiB

  1. // Copyright 2022 The Matrix.org Foundation C.I.C.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //! An implementation of Matrix push rules.
  15. //!
  16. //! The `Cow<_>` type is used extensively within this module to allow creating
  17. //! the base rules as constants (in Rust constants can't require explicit
  18. //! allocation atm).
  19. //!
  20. //! ---
  21. //!
  22. //! Push rules is the system used to determine which events trigger a push (and a
  23. //! bump in notification counts).
  24. //!
  25. //! This consists of a list of "push rules" for each user, where a push rule is a
  26. //! pair of "conditions" and "actions". When a user receives an event Synapse
  27. //! iterates over the list of push rules until it finds one where all the conditions
  28. //! match the event, at which point "actions" describe the outcome (e.g. notify,
  29. //! highlight, etc).
  30. //!
  31. //! Push rules are split up into 5 different "kinds" (aka "priority classes"), which
  32. //! are run in order:
  33. //! 1. Override — highest priority rules, e.g. always ignore notices
  34. //! 2. Content — content specific rules, e.g. @ notifications
  35. //! 3. Room — per room rules, e.g. enable/disable notifications for all messages
  36. //! in a room
  37. //! 4. Sender — per sender rules, e.g. never notify for messages from a given
  38. //! user
  39. //! 5. Underride — the lowest priority "default" rules, e.g. notify for every
  40. //! message.
  41. //!
  42. //! The set of "base rules" are the list of rules that every user has by default. A
  43. //! user can modify their copy of the push rules in one of three ways:
  44. //! 1. Adding a new push rule of a certain kind
  45. //! 2. Changing the actions of a base rule
  46. //! 3. Enabling/disabling a base rule.
  47. //!
  48. //! The base rules are split into whether they come before or after a particular
  49. //! kind, so the order of push rule evaluation would be: base rules for before
  50. //! "override" kind, user defined "override" rules, base rules after "override"
  51. //! kind, etc, etc.
  52. use std::borrow::Cow;
  53. use std::collections::{BTreeMap, HashMap, HashSet};
  54. use anyhow::{Context, Error};
  55. use log::warn;
  56. use pyo3::prelude::*;
  57. use pythonize::{depythonize, pythonize};
  58. use serde::de::Error as _;
  59. use serde::{Deserialize, Serialize};
  60. use serde_json::Value;
  61. use self::evaluator::PushRuleEvaluator;
  62. mod base_rules;
  63. pub mod evaluator;
  64. pub mod utils;
  65. /// Called when registering modules with python.
  66. pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
  67. let child_module = PyModule::new(py, "push")?;
  68. child_module.add_class::<PushRule>()?;
  69. child_module.add_class::<PushRules>()?;
  70. child_module.add_class::<FilteredPushRules>()?;
  71. child_module.add_class::<PushRuleEvaluator>()?;
  72. child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?;
  73. m.add_submodule(child_module)?;
  74. // We need to manually add the module to sys.modules to make `from
  75. // synapse.synapse_rust import push` work.
  76. py.import("sys")?
  77. .getattr("modules")?
  78. .set_item("synapse.synapse_rust.push", child_module)?;
  79. Ok(())
  80. }
  81. #[pyfunction]
  82. fn get_base_rule_ids() -> HashSet<&'static str> {
  83. base_rules::BASE_RULES_BY_ID.keys().copied().collect()
  84. }
  85. /// A single push rule for a user.
  86. #[derive(Debug, Clone)]
  87. #[pyclass(frozen)]
  88. pub struct PushRule {
  89. /// A unique ID for this rule
  90. pub rule_id: Cow<'static, str>,
  91. /// The "kind" of push rule this is (see `PRIORITY_CLASS_MAP` in Python)
  92. #[pyo3(get)]
  93. pub priority_class: i32,
  94. /// The conditions that must all match for actions to be applied
  95. pub conditions: Cow<'static, [Condition]>,
  96. /// The actions to apply if all conditions are met
  97. pub actions: Cow<'static, [Action]>,
  98. /// Whether this is a base rule
  99. #[pyo3(get)]
  100. pub default: bool,
  101. /// Whether this is enabled by default
  102. #[pyo3(get)]
  103. pub default_enabled: bool,
  104. }
  105. #[pymethods]
  106. impl PushRule {
  107. #[staticmethod]
  108. pub fn from_db(
  109. rule_id: String,
  110. priority_class: i32,
  111. conditions: &str,
  112. actions: &str,
  113. ) -> Result<PushRule, Error> {
  114. let conditions = serde_json::from_str(conditions).context("parsing conditions")?;
  115. let actions = serde_json::from_str(actions).context("parsing actions")?;
  116. Ok(PushRule {
  117. rule_id: Cow::Owned(rule_id),
  118. priority_class,
  119. conditions,
  120. actions,
  121. default: false,
  122. default_enabled: true,
  123. })
  124. }
  125. #[getter]
  126. fn rule_id(&self) -> &str {
  127. &self.rule_id
  128. }
  129. #[getter]
  130. fn actions(&self) -> Vec<Action> {
  131. self.actions.clone().into_owned()
  132. }
  133. #[getter]
  134. fn conditions(&self) -> Vec<Condition> {
  135. self.conditions.clone().into_owned()
  136. }
  137. fn __repr__(&self) -> String {
  138. format!(
  139. "<PushRule rule_id={}, conditions={:?}, actions={:?}>",
  140. self.rule_id, self.conditions, self.actions
  141. )
  142. }
  143. }
  144. /// The "action" Synapse should perform for a matching push rule.
  145. #[derive(Debug, Clone, PartialEq, Eq)]
  146. pub enum Action {
  147. DontNotify,
  148. Notify,
  149. Coalesce,
  150. SetTweak(SetTweak),
  151. // An unrecognized custom action.
  152. Unknown(Value),
  153. }
  154. impl IntoPy<PyObject> for Action {
  155. fn into_py(self, py: Python<'_>) -> PyObject {
  156. // When we pass the `Action` struct to Python we want it to be converted
  157. // to a dict. We use `pythonize`, which converts the struct using the
  158. // `serde` serialization.
  159. pythonize(py, &self).expect("valid action")
  160. }
  161. }
  162. /// The body of a `SetTweak` push action.
  163. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
  164. pub struct SetTweak {
  165. set_tweak: Cow<'static, str>,
  166. #[serde(skip_serializing_if = "Option::is_none")]
  167. value: Option<TweakValue>,
  168. // This picks up any other fields that may have been added by clients.
  169. // These get added when we convert the `Action` to a python object.
  170. #[serde(flatten)]
  171. other_keys: Value,
  172. }
  173. /// The value of a `set_tweak`.
  174. ///
  175. /// We need this (rather than using `TweakValue` directly) so that we can use
  176. /// `&'static str` in the value when defining the constant base rules.
  177. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
  178. #[serde(untagged)]
  179. pub enum TweakValue {
  180. String(Cow<'static, str>),
  181. Other(Value),
  182. }
  183. impl Serialize for Action {
  184. fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
  185. where
  186. S: serde::Serializer,
  187. {
  188. match self {
  189. Action::DontNotify => serializer.serialize_str("dont_notify"),
  190. Action::Notify => serializer.serialize_str("notify"),
  191. Action::Coalesce => serializer.serialize_str("coalesce"),
  192. Action::SetTweak(tweak) => tweak.serialize(serializer),
  193. Action::Unknown(value) => value.serialize(serializer),
  194. }
  195. }
  196. }
  197. /// Simple helper class for deserializing Action from JSON.
  198. #[derive(Deserialize)]
  199. #[serde(untagged)]
  200. enum ActionDeserializeHelper {
  201. Str(String),
  202. SetTweak(SetTweak),
  203. Unknown(Value),
  204. }
  205. impl<'de> Deserialize<'de> for Action {
  206. fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
  207. where
  208. D: serde::Deserializer<'de>,
  209. {
  210. let helper: ActionDeserializeHelper = Deserialize::deserialize(deserializer)?;
  211. match helper {
  212. ActionDeserializeHelper::Str(s) => match &*s {
  213. "dont_notify" => Ok(Action::DontNotify),
  214. "notify" => Ok(Action::Notify),
  215. "coalesce" => Ok(Action::Coalesce),
  216. _ => Err(D::Error::custom("unrecognized action")),
  217. },
  218. ActionDeserializeHelper::SetTweak(set_tweak) => Ok(Action::SetTweak(set_tweak)),
  219. ActionDeserializeHelper::Unknown(value) => Ok(Action::Unknown(value)),
  220. }
  221. }
  222. }
  223. /// A condition used in push rules to match against an event.
  224. ///
  225. /// We need this split as `serde` doesn't give us the ability to have a
  226. /// "catchall" variant in tagged enums.
  227. #[derive(Serialize, Deserialize, Debug, Clone)]
  228. #[serde(untagged)]
  229. pub enum Condition {
  230. /// A recognized condition that we can match against
  231. Known(KnownCondition),
  232. /// An unrecognized condition that we ignore.
  233. Unknown(Value),
  234. }
  235. /// The set of "known" conditions that we can handle.
  236. #[derive(Serialize, Deserialize, Debug, Clone)]
  237. #[serde(rename_all = "snake_case")]
  238. #[serde(tag = "kind")]
  239. pub enum KnownCondition {
  240. EventMatch(EventMatchCondition),
  241. #[serde(rename = "im.nheko.msc3664.related_event_match")]
  242. RelatedEventMatch(RelatedEventMatchCondition),
  243. #[serde(rename = "org.matrix.msc3952.is_user_mention")]
  244. IsUserMention,
  245. #[serde(rename = "org.matrix.msc3952.is_room_mention")]
  246. IsRoomMention,
  247. ContainsDisplayName,
  248. RoomMemberCount {
  249. #[serde(skip_serializing_if = "Option::is_none")]
  250. is: Option<Cow<'static, str>>,
  251. },
  252. SenderNotificationPermission {
  253. key: Cow<'static, str>,
  254. },
  255. #[serde(rename = "org.matrix.msc3931.room_version_supports")]
  256. RoomVersionSupports {
  257. feature: Cow<'static, str>,
  258. },
  259. }
  260. impl IntoPy<PyObject> for Condition {
  261. fn into_py(self, py: Python<'_>) -> PyObject {
  262. pythonize(py, &self).expect("valid condition")
  263. }
  264. }
  265. impl<'source> FromPyObject<'source> for Condition {
  266. fn extract(ob: &'source PyAny) -> PyResult<Self> {
  267. Ok(depythonize(ob)?)
  268. }
  269. }
  270. /// The body of a [`Condition::EventMatch`]
  271. #[derive(Serialize, Deserialize, Debug, Clone)]
  272. pub struct EventMatchCondition {
  273. pub key: Cow<'static, str>,
  274. #[serde(skip_serializing_if = "Option::is_none")]
  275. pub pattern: Option<Cow<'static, str>>,
  276. #[serde(skip_serializing_if = "Option::is_none")]
  277. pub pattern_type: Option<Cow<'static, str>>,
  278. }
  279. /// The body of a [`Condition::RelatedEventMatch`]
  280. #[derive(Serialize, Deserialize, Debug, Clone)]
  281. pub struct RelatedEventMatchCondition {
  282. #[serde(skip_serializing_if = "Option::is_none")]
  283. pub key: Option<Cow<'static, str>>,
  284. #[serde(skip_serializing_if = "Option::is_none")]
  285. pub pattern: Option<Cow<'static, str>>,
  286. #[serde(skip_serializing_if = "Option::is_none")]
  287. pub pattern_type: Option<Cow<'static, str>>,
  288. pub rel_type: Cow<'static, str>,
  289. #[serde(skip_serializing_if = "Option::is_none")]
  290. pub include_fallbacks: Option<bool>,
  291. }
  292. /// The collection of push rules for a user.
  293. #[derive(Debug, Clone, Default)]
  294. #[pyclass(frozen)]
  295. pub struct PushRules {
  296. /// Custom push rules that override a base rule.
  297. overridden_base_rules: HashMap<Cow<'static, str>, PushRule>,
  298. /// Custom rules that come between the prepend/append override base rules.
  299. override_rules: Vec<PushRule>,
  300. /// Custom rules that come before the base content rules.
  301. content: Vec<PushRule>,
  302. /// Custom rules that come before the base room rules.
  303. room: Vec<PushRule>,
  304. /// Custom rules that come before the base sender rules.
  305. sender: Vec<PushRule>,
  306. /// Custom rules that come before the base underride rules.
  307. underride: Vec<PushRule>,
  308. }
  309. #[pymethods]
  310. impl PushRules {
  311. #[new]
  312. pub fn new(rules: Vec<PushRule>) -> PushRules {
  313. let mut push_rules: PushRules = Default::default();
  314. for rule in rules {
  315. if let Some(&o) = base_rules::BASE_RULES_BY_ID.get(&*rule.rule_id) {
  316. push_rules.overridden_base_rules.insert(
  317. rule.rule_id.clone(),
  318. PushRule {
  319. actions: rule.actions.clone(),
  320. ..o.clone()
  321. },
  322. );
  323. continue;
  324. }
  325. match rule.priority_class {
  326. 5 => push_rules.override_rules.push(rule),
  327. 4 => push_rules.content.push(rule),
  328. 3 => push_rules.room.push(rule),
  329. 2 => push_rules.sender.push(rule),
  330. 1 => push_rules.underride.push(rule),
  331. _ => {
  332. warn!(
  333. "Unrecognized priority class for rule {}: {}",
  334. rule.rule_id, rule.priority_class
  335. );
  336. }
  337. }
  338. }
  339. push_rules
  340. }
  341. /// Returns the list of all rules, including base rules, in the order they
  342. /// should be executed in.
  343. fn rules(&self) -> Vec<PushRule> {
  344. self.iter().cloned().collect()
  345. }
  346. }
  347. impl PushRules {
  348. /// Iterates over all the rules, including base rules, in the order they
  349. /// should be executed in.
  350. pub fn iter(&self) -> impl Iterator<Item = &PushRule> {
  351. base_rules::BASE_PREPEND_OVERRIDE_RULES
  352. .iter()
  353. .chain(self.override_rules.iter())
  354. .chain(base_rules::BASE_APPEND_OVERRIDE_RULES.iter())
  355. .chain(self.content.iter())
  356. .chain(base_rules::BASE_APPEND_CONTENT_RULES.iter())
  357. .chain(self.room.iter())
  358. .chain(self.sender.iter())
  359. .chain(self.underride.iter())
  360. .chain(base_rules::BASE_APPEND_UNDERRIDE_RULES.iter())
  361. .map(|rule| {
  362. self.overridden_base_rules
  363. .get(&*rule.rule_id)
  364. .unwrap_or(rule)
  365. })
  366. }
  367. }
  368. /// A wrapper around `PushRules` that checks the enabled state of rules and
  369. /// filters out disabled experimental rules.
  370. #[derive(Debug, Clone, Default)]
  371. #[pyclass(frozen)]
  372. pub struct FilteredPushRules {
  373. push_rules: PushRules,
  374. enabled_map: BTreeMap<String, bool>,
  375. msc1767_enabled: bool,
  376. msc3381_polls_enabled: bool,
  377. msc3664_enabled: bool,
  378. msc3952_intentional_mentions: bool,
  379. }
  380. #[pymethods]
  381. impl FilteredPushRules {
  382. #[new]
  383. pub fn py_new(
  384. push_rules: PushRules,
  385. enabled_map: BTreeMap<String, bool>,
  386. msc1767_enabled: bool,
  387. msc3381_polls_enabled: bool,
  388. msc3664_enabled: bool,
  389. msc3952_intentional_mentions: bool,
  390. ) -> Self {
  391. Self {
  392. push_rules,
  393. enabled_map,
  394. msc1767_enabled,
  395. msc3381_polls_enabled,
  396. msc3664_enabled,
  397. msc3952_intentional_mentions,
  398. }
  399. }
  400. /// Returns the list of all rules and their enabled state, including base
  401. /// rules, in the order they should be executed in.
  402. fn rules(&self) -> Vec<(PushRule, bool)> {
  403. self.iter().map(|(r, e)| (r.clone(), e)).collect()
  404. }
  405. }
  406. impl FilteredPushRules {
  407. /// Iterates over all the rules and their enabled state, including base
  408. /// rules, in the order they should be executed in.
  409. fn iter(&self) -> impl Iterator<Item = (&PushRule, bool)> {
  410. self.push_rules
  411. .iter()
  412. .filter(|rule| {
  413. // Ignore disabled experimental push rules
  414. if !self.msc1767_enabled && rule.rule_id.contains("org.matrix.msc1767") {
  415. return false;
  416. }
  417. if !self.msc3664_enabled
  418. && rule.rule_id == "global/override/.im.nheko.msc3664.reply"
  419. {
  420. return false;
  421. }
  422. if !self.msc3381_polls_enabled && rule.rule_id.contains("org.matrix.msc3930") {
  423. return false;
  424. }
  425. if !self.msc3952_intentional_mentions && rule.rule_id.contains("org.matrix.msc3952")
  426. {
  427. return false;
  428. }
  429. true
  430. })
  431. .map(|r| {
  432. let enabled = *self
  433. .enabled_map
  434. .get(&*r.rule_id)
  435. .unwrap_or(&r.default_enabled);
  436. (r, enabled)
  437. })
  438. }
  439. }
  #[test]
  fn test_serialize_condition() {
      let event_match = EventMatchCondition {
          key: "content.body".into(),
          pattern: Some("coffee".into()),
          pattern_type: None,
      };
      let condition = Condition::Known(KnownCondition::EventMatch(event_match));

      // `pattern_type` is `None`, so it must be left out of the output.
      let serialized = serde_json::to_string(&condition).unwrap();
      assert_eq!(
          serialized,
          r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#
      )
  }

  #[test]
  fn test_deserialize_condition() {
      let input = r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#;
      serde_json::from_str::<Condition>(input).unwrap();
  }

  #[test]
  fn test_deserialize_unstable_msc3664_condition() {
      let input = r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern":"coffee","rel_type":"m.in_reply_to"}"#;
      let parsed = serde_json::from_str::<Condition>(input).unwrap();
      assert!(matches!(
          parsed,
          Condition::Known(KnownCondition::RelatedEventMatch(_))
      ));
  }

  #[test]
  fn test_deserialize_unstable_msc3931_condition() {
      let input =
          r#"{"kind":"org.matrix.msc3931.room_version_supports","feature":"org.example.feature"}"#;
      let parsed = serde_json::from_str::<Condition>(input).unwrap();
      assert!(matches!(
          parsed,
          Condition::Known(KnownCondition::RoomVersionSupports { .. })
      ));
  }

  #[test]
  fn test_deserialize_unstable_msc3952_user_condition() {
      let input = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#;
      let parsed = serde_json::from_str::<Condition>(input).unwrap();
      assert!(matches!(
          parsed,
          Condition::Known(KnownCondition::IsUserMention)
      ));
  }

  #[test]
  fn test_deserialize_unstable_msc3952_room_condition() {
      let input = r#"{"kind":"org.matrix.msc3952.is_room_mention"}"#;
      let parsed = serde_json::from_str::<Condition>(input).unwrap();
      assert!(matches!(
          parsed,
          Condition::Known(KnownCondition::IsRoomMention)
      ));
  }

  #[test]
  fn test_deserialize_custom_condition() {
      let input = r#"{"kind":"custom_tag"}"#;
      let parsed = serde_json::from_str::<Condition>(input).unwrap();
      assert!(matches!(parsed, Condition::Unknown(_)));

      // Unknown conditions must serialize back to exactly what was read.
      let round_tripped = serde_json::to_string(&parsed).unwrap();
      assert_eq!(input, round_tripped);
  }
  #[test]
  fn test_deserialize_action() {
      // The simple actions are bare strings and must map to their variants.
      assert_eq!(
          serde_json::from_str::<Action>(r#""notify""#).unwrap(),
          Action::Notify
      );
      assert_eq!(
          serde_json::from_str::<Action>(r#""dont_notify""#).unwrap(),
          Action::DontNotify
      );
      assert_eq!(
          serde_json::from_str::<Action>(r#""coalesce""#).unwrap(),
          Action::Coalesce
      );

      // An object with a `set_tweak` key is a tweak, not an unknown action.
      let tweak: Action = serde_json::from_str(r#"{"set_tweak": "highlight"}"#).unwrap();
      assert!(matches!(tweak, Action::SetTweak(_)));
  }

  #[test]
  fn test_deserialize_unrecognized_string_action() {
      // Unlike unknown object actions, unknown string actions are rejected.
      assert!(serde_json::from_str::<Action>(r#""not_a_real_action""#).is_err());
  }

  #[test]
  fn test_set_tweak_round_trip() {
      let json = r#"{"set_tweak":"sound","value":"default"}"#;
      let action: Action = serde_json::from_str(json).unwrap();
      assert!(matches!(action, Action::SetTweak(_)));
      assert_eq!(serde_json::to_string(&action).unwrap(), json);

      // A tweak without a value must not gain a `"value": null` key.
      let bare = r#"{"set_tweak":"highlight"}"#;
      let action: Action = serde_json::from_str(bare).unwrap();
      assert_eq!(serde_json::to_string(&action).unwrap(), bare);
  }

  #[test]
  fn test_custom_action() {
      let json = r#"{"some_custom":"action_fields"}"#;
      let action: Action = serde_json::from_str(json).unwrap();
      assert!(matches!(action, Action::Unknown(_)));
      let new_json = serde_json::to_string(&action).unwrap();
      assert_eq!(json, new_json);
  }