You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

503 lines
16 KiB

  1. // Copyright 2022 The Matrix.org Foundation C.I.C.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //! An implementation of Matrix push rules.
  15. //!
  16. //! The `Cow<_>` type is used extensively within this module to allow creating
  17. //! the base rules as constants (in Rust constants can't require explicit
  18. //! allocation atm).
  19. //!
  20. //! ---
  21. //!
  22. //! Push rules is the system used to determine which events trigger a push (and a
  23. //! bump in notification counts).
  24. //!
  25. //! This consists of a list of "push rules" for each user, where a push rule is a
  26. //! pair of "conditions" and "actions". When a user receives an event Synapse
  27. //! iterates over the list of push rules until it finds one where all the conditions
  28. //! match the event, at which point "actions" describe the outcome (e.g. notify,
  29. //! highlight, etc).
  30. //!
  31. //! Push rules are split up into 5 different "kinds" (aka "priority classes"), which
  32. //! are run in order:
  33. //! 1. Override — highest priority rules, e.g. always ignore notices
  34. //! 2. Content — content specific rules, e.g. @ notifications
  35. //! 3. Room — per room rules, e.g. enable/disable notifications for all messages
  36. //! in a room
  37. //! 4. Sender — per sender rules, e.g. never notify for messages from a given
  38. //! user
  39. //! 5. Underride — the lowest priority "default" rules, e.g. notify for every
  40. //! message.
  41. //!
  42. //! The set of "base rules" are the list of rules that every user has by default. A
  43. //! user can modify their copy of the push rules in one of three ways:
  44. //!
  45. //! 1. Adding a new push rule of a certain kind
  46. //! 2. Changing the actions of a base rule
  47. //! 3. Enabling/disabling a base rule.
  48. //!
  49. //! The base rules are split into whether they come before or after a particular
  50. //! kind, so the order of push rule evaluation would be: base rules for before
  51. //! "override" kind, user defined "override" rules, base rules after "override"
  52. //! kind, etc, etc.
  53. use std::borrow::Cow;
  54. use std::collections::{BTreeMap, HashMap, HashSet};
  55. use anyhow::{Context, Error};
  56. use log::warn;
  57. use pyo3::prelude::*;
  58. use pythonize::pythonize;
  59. use serde::de::Error as _;
  60. use serde::{Deserialize, Serialize};
  61. use serde_json::Value;
  62. mod base_rules;
  63. /// Called when registering modules with python.
  64. pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
  65. let child_module = PyModule::new(py, "push")?;
  66. child_module.add_class::<PushRule>()?;
  67. child_module.add_class::<PushRules>()?;
  68. child_module.add_class::<FilteredPushRules>()?;
  69. child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?;
  70. m.add_submodule(child_module)?;
  71. // We need to manually add the module to sys.modules to make `from
  72. // synapse.synapse_rust import push` work.
  73. py.import("sys")?
  74. .getattr("modules")?
  75. .set_item("synapse.synapse_rust.push", child_module)?;
  76. Ok(())
  77. }
  78. #[pyfunction]
  79. fn get_base_rule_ids() -> HashSet<&'static str> {
  80. base_rules::BASE_RULES_BY_ID.keys().copied().collect()
  81. }
  82. /// A single push rule for a user.
  83. #[derive(Debug, Clone)]
  84. #[pyclass(frozen)]
  85. pub struct PushRule {
  86. /// A unique ID for this rule
  87. pub rule_id: Cow<'static, str>,
  88. /// The "kind" of push rule this is (see `PRIORITY_CLASS_MAP` in Python)
  89. #[pyo3(get)]
  90. pub priority_class: i32,
  91. /// The conditions that must all match for actions to be applied
  92. pub conditions: Cow<'static, [Condition]>,
  93. /// The actions to apply if all conditions are met
  94. pub actions: Cow<'static, [Action]>,
  95. /// Whether this is a base rule
  96. #[pyo3(get)]
  97. pub default: bool,
  98. /// Whether this is enabled by default
  99. #[pyo3(get)]
  100. pub default_enabled: bool,
  101. }
  102. #[pymethods]
  103. impl PushRule {
  104. #[staticmethod]
  105. pub fn from_db(
  106. rule_id: String,
  107. priority_class: i32,
  108. conditions: &str,
  109. actions: &str,
  110. ) -> Result<PushRule, Error> {
  111. let conditions = serde_json::from_str(conditions).context("parsing conditions")?;
  112. let actions = serde_json::from_str(actions).context("parsing actions")?;
  113. Ok(PushRule {
  114. rule_id: Cow::Owned(rule_id),
  115. priority_class,
  116. conditions,
  117. actions,
  118. default: false,
  119. default_enabled: true,
  120. })
  121. }
  122. #[getter]
  123. fn rule_id(&self) -> &str {
  124. &self.rule_id
  125. }
  126. #[getter]
  127. fn actions(&self) -> Vec<Action> {
  128. self.actions.clone().into_owned()
  129. }
  130. #[getter]
  131. fn conditions(&self) -> Vec<Condition> {
  132. self.conditions.clone().into_owned()
  133. }
  134. fn __repr__(&self) -> String {
  135. format!(
  136. "<PushRule rule_id={}, conditions={:?}, actions={:?}>",
  137. self.rule_id, self.conditions, self.actions
  138. )
  139. }
  140. }
  141. /// The "action" Synapse should perform for a matching push rule.
  142. #[derive(Debug, Clone, PartialEq, Eq)]
  143. pub enum Action {
  144. DontNotify,
  145. Notify,
  146. Coalesce,
  147. SetTweak(SetTweak),
  148. // An unrecognized custom action.
  149. Unknown(Value),
  150. }
  151. impl IntoPy<PyObject> for Action {
  152. fn into_py(self, py: Python<'_>) -> PyObject {
  153. // When we pass the `Action` struct to Python we want it to be converted
  154. // to a dict. We use `pythonize`, which converts the struct using the
  155. // `serde` serialization.
  156. pythonize(py, &self).expect("valid action")
  157. }
  158. }
  159. /// The body of a `SetTweak` push action.
  160. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
  161. pub struct SetTweak {
  162. set_tweak: Cow<'static, str>,
  163. #[serde(skip_serializing_if = "Option::is_none")]
  164. value: Option<TweakValue>,
  165. // This picks up any other fields that may have been added by clients.
  166. // These get added when we convert the `Action` to a python object.
  167. #[serde(flatten)]
  168. other_keys: Value,
  169. }
  170. /// The value of a `set_tweak`.
  171. ///
  172. /// We need this (rather than using `TweakValue` directly) so that we can use
  173. /// `&'static str` in the value when defining the constant base rules.
  174. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
  175. #[serde(untagged)]
  176. pub enum TweakValue {
  177. String(Cow<'static, str>),
  178. Other(Value),
  179. }
  180. impl Serialize for Action {
  181. fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
  182. where
  183. S: serde::Serializer,
  184. {
  185. match self {
  186. Action::DontNotify => serializer.serialize_str("dont_notify"),
  187. Action::Notify => serializer.serialize_str("notify"),
  188. Action::Coalesce => serializer.serialize_str("coalesce"),
  189. Action::SetTweak(tweak) => tweak.serialize(serializer),
  190. Action::Unknown(value) => value.serialize(serializer),
  191. }
  192. }
  193. }
  194. /// Simple helper class for deserializing Action from JSON.
  195. #[derive(Deserialize)]
  196. #[serde(untagged)]
  197. enum ActionDeserializeHelper {
  198. Str(String),
  199. SetTweak(SetTweak),
  200. Unknown(Value),
  201. }
  202. impl<'de> Deserialize<'de> for Action {
  203. fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
  204. where
  205. D: serde::Deserializer<'de>,
  206. {
  207. let helper: ActionDeserializeHelper = Deserialize::deserialize(deserializer)?;
  208. match helper {
  209. ActionDeserializeHelper::Str(s) => match &*s {
  210. "dont_notify" => Ok(Action::DontNotify),
  211. "notify" => Ok(Action::Notify),
  212. "coalesce" => Ok(Action::Coalesce),
  213. _ => Err(D::Error::custom("unrecognized action")),
  214. },
  215. ActionDeserializeHelper::SetTweak(set_tweak) => Ok(Action::SetTweak(set_tweak)),
  216. ActionDeserializeHelper::Unknown(value) => Ok(Action::Unknown(value)),
  217. }
  218. }
  219. }
  220. /// A condition used in push rules to match against an event.
  221. ///
  222. /// We need this split as `serde` doesn't give us the ability to have a
  223. /// "catchall" variant in tagged enums.
  224. #[derive(Serialize, Deserialize, Debug, Clone)]
  225. #[serde(untagged)]
  226. pub enum Condition {
  227. /// A recognized condition that we can match against
  228. Known(KnownCondition),
  229. /// An unrecognized condition that we ignore.
  230. Unknown(Value),
  231. }
  232. /// The set of "known" conditions that we can handle.
  233. #[derive(Serialize, Deserialize, Debug, Clone)]
  234. #[serde(rename_all = "snake_case")]
  235. #[serde(tag = "kind")]
  236. pub enum KnownCondition {
  237. EventMatch(EventMatchCondition),
  238. ContainsDisplayName,
  239. RoomMemberCount {
  240. #[serde(skip_serializing_if = "Option::is_none")]
  241. is: Option<Cow<'static, str>>,
  242. },
  243. SenderNotificationPermission {
  244. key: Cow<'static, str>,
  245. },
  246. #[serde(rename = "org.matrix.msc3772.relation_match")]
  247. RelationMatch {
  248. rel_type: Cow<'static, str>,
  249. #[serde(skip_serializing_if = "Option::is_none")]
  250. sender: Option<Cow<'static, str>>,
  251. #[serde(skip_serializing_if = "Option::is_none")]
  252. sender_type: Option<Cow<'static, str>>,
  253. },
  254. }
  255. impl IntoPy<PyObject> for Condition {
  256. fn into_py(self, py: Python<'_>) -> PyObject {
  257. pythonize(py, &self).expect("valid condition")
  258. }
  259. }
  260. /// The body of a [`Condition::EventMatch`]
  261. #[derive(Serialize, Deserialize, Debug, Clone)]
  262. pub struct EventMatchCondition {
  263. key: Cow<'static, str>,
  264. #[serde(skip_serializing_if = "Option::is_none")]
  265. pattern: Option<Cow<'static, str>>,
  266. #[serde(skip_serializing_if = "Option::is_none")]
  267. pattern_type: Option<Cow<'static, str>>,
  268. }
  269. /// The collection of push rules for a user.
  270. #[derive(Debug, Clone, Default)]
  271. #[pyclass(frozen)]
  272. struct PushRules {
  273. /// Custom push rules that override a base rule.
  274. overridden_base_rules: HashMap<Cow<'static, str>, PushRule>,
  275. /// Custom rules that come between the prepend/append override base rules.
  276. override_rules: Vec<PushRule>,
  277. /// Custom rules that come before the base content rules.
  278. content: Vec<PushRule>,
  279. /// Custom rules that come before the base room rules.
  280. room: Vec<PushRule>,
  281. /// Custom rules that come before the base sender rules.
  282. sender: Vec<PushRule>,
  283. /// Custom rules that come before the base underride rules.
  284. underride: Vec<PushRule>,
  285. }
  286. #[pymethods]
  287. impl PushRules {
  288. #[new]
  289. fn new(rules: Vec<PushRule>) -> PushRules {
  290. let mut push_rules: PushRules = Default::default();
  291. for rule in rules {
  292. if let Some(&o) = base_rules::BASE_RULES_BY_ID.get(&*rule.rule_id) {
  293. push_rules.overridden_base_rules.insert(
  294. rule.rule_id.clone(),
  295. PushRule {
  296. actions: rule.actions.clone(),
  297. ..o.clone()
  298. },
  299. );
  300. continue;
  301. }
  302. match rule.priority_class {
  303. 5 => push_rules.override_rules.push(rule),
  304. 4 => push_rules.content.push(rule),
  305. 3 => push_rules.room.push(rule),
  306. 2 => push_rules.sender.push(rule),
  307. 1 => push_rules.underride.push(rule),
  308. _ => {
  309. warn!(
  310. "Unrecognized priority class for rule {}: {}",
  311. rule.rule_id, rule.priority_class
  312. );
  313. }
  314. }
  315. }
  316. push_rules
  317. }
  318. /// Returns the list of all rules, including base rules, in the order they
  319. /// should be executed in.
  320. fn rules(&self) -> Vec<PushRule> {
  321. self.iter().cloned().collect()
  322. }
  323. }
  324. impl PushRules {
  325. /// Iterates over all the rules, including base rules, in the order they
  326. /// should be executed in.
  327. pub fn iter(&self) -> impl Iterator<Item = &PushRule> {
  328. base_rules::BASE_PREPEND_OVERRIDE_RULES
  329. .iter()
  330. .chain(self.override_rules.iter())
  331. .chain(base_rules::BASE_APPEND_OVERRIDE_RULES.iter())
  332. .chain(self.content.iter())
  333. .chain(base_rules::BASE_APPEND_CONTENT_RULES.iter())
  334. .chain(self.room.iter())
  335. .chain(self.sender.iter())
  336. .chain(self.underride.iter())
  337. .chain(base_rules::BASE_APPEND_UNDERRIDE_RULES.iter())
  338. .map(|rule| {
  339. self.overridden_base_rules
  340. .get(&*rule.rule_id)
  341. .unwrap_or(rule)
  342. })
  343. }
  344. }
  345. /// A wrapper around `PushRules` that checks the enabled state of rules and
  346. /// filters out disabled experimental rules.
  347. #[derive(Debug, Clone, Default)]
  348. #[pyclass(frozen)]
  349. pub struct FilteredPushRules {
  350. push_rules: PushRules,
  351. enabled_map: BTreeMap<String, bool>,
  352. msc3786_enabled: bool,
  353. msc3772_enabled: bool,
  354. }
  355. #[pymethods]
  356. impl FilteredPushRules {
  357. #[new]
  358. fn py_new(
  359. push_rules: PushRules,
  360. enabled_map: BTreeMap<String, bool>,
  361. msc3786_enabled: bool,
  362. msc3772_enabled: bool,
  363. ) -> Self {
  364. Self {
  365. push_rules,
  366. enabled_map,
  367. msc3786_enabled,
  368. msc3772_enabled,
  369. }
  370. }
  371. /// Returns the list of all rules and their enabled state, including base
  372. /// rules, in the order they should be executed in.
  373. fn rules(&self) -> Vec<(PushRule, bool)> {
  374. self.iter().map(|(r, e)| (r.clone(), e)).collect()
  375. }
  376. }
  377. impl FilteredPushRules {
  378. /// Iterates over all the rules and their enabled state, including base
  379. /// rules, in the order they should be executed in.
  380. fn iter(&self) -> impl Iterator<Item = (&PushRule, bool)> {
  381. self.push_rules
  382. .iter()
  383. .filter(|rule| {
  384. // Ignore disabled experimental push rules
  385. if !self.msc3786_enabled
  386. && rule.rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
  387. {
  388. return false;
  389. }
  390. if !self.msc3772_enabled
  391. && rule.rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
  392. {
  393. return false;
  394. }
  395. true
  396. })
  397. .map(|r| {
  398. let enabled = *self
  399. .enabled_map
  400. .get(&*r.rule_id)
  401. .unwrap_or(&r.default_enabled);
  402. (r, enabled)
  403. })
  404. }
  405. }
  406. #[test]
  407. fn test_serialize_condition() {
  408. let condition = Condition::Known(KnownCondition::EventMatch(EventMatchCondition {
  409. key: "content.body".into(),
  410. pattern: Some("coffee".into()),
  411. pattern_type: None,
  412. }));
  413. let json = serde_json::to_string(&condition).unwrap();
  414. assert_eq!(
  415. json,
  416. r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#
  417. )
  418. }
  419. #[test]
  420. fn test_deserialize_condition() {
  421. let json = r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#;
  422. let _: Condition = serde_json::from_str(json).unwrap();
  423. }
  424. #[test]
  425. fn test_deserialize_custom_condition() {
  426. let json = r#"{"kind":"custom_tag"}"#;
  427. let condition: Condition = serde_json::from_str(json).unwrap();
  428. assert!(matches!(condition, Condition::Unknown(_)));
  429. let new_json = serde_json::to_string(&condition).unwrap();
  430. assert_eq!(json, new_json);
  431. }
  432. #[test]
  433. fn test_deserialize_action() {
  434. let _: Action = serde_json::from_str(r#""notify""#).unwrap();
  435. let _: Action = serde_json::from_str(r#""dont_notify""#).unwrap();
  436. let _: Action = serde_json::from_str(r#""coalesce""#).unwrap();
  437. let _: Action = serde_json::from_str(r#"{"set_tweak": "highlight"}"#).unwrap();
  438. }
  439. #[test]
  440. fn test_custom_action() {
  441. let json = r#"{"some_custom":"action_fields"}"#;
  442. let action: Action = serde_json::from_str(json).unwrap();
  443. assert!(matches!(action, Action::Unknown(_)));
  444. let new_json = serde_json::to_string(&action).unwrap();
  445. assert_eq!(json, new_json);
  446. }