You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

515 lines
16 KiB

  1. // Copyright 2022 The Matrix.org Foundation C.I.C.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //! An implementation of Matrix push rules.
  15. //!
//! The `Cow<_>` type is used extensively within this module to allow creating
//! the base rules as constants (Rust constants currently can't require
//! explicit allocation).
  19. //!
  20. //! ---
  21. //!
  22. //! Push rules is the system used to determine which events trigger a push (and a
  23. //! bump in notification counts).
  24. //!
  25. //! This consists of a list of "push rules" for each user, where a push rule is a
//! pair of "conditions" and "actions". When a user receives an event, Synapse
//! iterates over the list of push rules until it finds one where all the conditions
  28. //! match the event, at which point "actions" describe the outcome (e.g. notify,
  29. //! highlight, etc).
  30. //!
  31. //! Push rules are split up into 5 different "kinds" (aka "priority classes"), which
  32. //! are run in order:
  33. //! 1. Override — highest priority rules, e.g. always ignore notices
  34. //! 2. Content — content specific rules, e.g. @ notifications
  35. //! 3. Room — per room rules, e.g. enable/disable notifications for all messages
  36. //! in a room
  37. //! 4. Sender — per sender rules, e.g. never notify for messages from a given
  38. //! user
  39. //! 5. Underride — the lowest priority "default" rules, e.g. notify for every
  40. //! message.
  41. //!
//! The set of "base rules" is the list of rules that every user has by default. A
//! user can modify their copy of the push rules in one of three ways:
  44. //! 1. Adding a new push rule of a certain kind
  45. //! 2. Changing the actions of a base rule
  46. //! 3. Enabling/disabling a base rule.
  47. //!
  48. //! The base rules are split into whether they come before or after a particular
  49. //! kind, so the order of push rule evaluation would be: base rules for before
  50. //! "override" kind, user defined "override" rules, base rules after "override"
  51. //! kind, etc, etc.
  52. use std::borrow::Cow;
  53. use std::collections::{BTreeMap, HashMap, HashSet};
  54. use anyhow::{Context, Error};
  55. use log::warn;
  56. use pyo3::prelude::*;
  57. use pythonize::{depythonize, pythonize};
  58. use serde::de::Error as _;
  59. use serde::{Deserialize, Serialize};
  60. use serde_json::Value;
  61. use self::evaluator::PushRuleEvaluator;
  62. mod base_rules;
  63. pub mod evaluator;
  64. pub mod utils;
  65. /// Called when registering modules with python.
  66. pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
  67. let child_module = PyModule::new(py, "push")?;
  68. child_module.add_class::<PushRule>()?;
  69. child_module.add_class::<PushRules>()?;
  70. child_module.add_class::<FilteredPushRules>()?;
  71. child_module.add_class::<PushRuleEvaluator>()?;
  72. child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?;
  73. m.add_submodule(child_module)?;
  74. // We need to manually add the module to sys.modules to make `from
  75. // synapse.synapse_rust import push` work.
  76. py.import("sys")?
  77. .getattr("modules")?
  78. .set_item("synapse.synapse_rust.push", child_module)?;
  79. Ok(())
  80. }
  81. #[pyfunction]
  82. fn get_base_rule_ids() -> HashSet<&'static str> {
  83. base_rules::BASE_RULES_BY_ID.keys().copied().collect()
  84. }
  85. /// A single push rule for a user.
  86. #[derive(Debug, Clone)]
  87. #[pyclass(frozen)]
  88. pub struct PushRule {
  89. /// A unique ID for this rule
  90. pub rule_id: Cow<'static, str>,
  91. /// The "kind" of push rule this is (see `PRIORITY_CLASS_MAP` in Python)
  92. #[pyo3(get)]
  93. pub priority_class: i32,
  94. /// The conditions that must all match for actions to be applied
  95. pub conditions: Cow<'static, [Condition]>,
  96. /// The actions to apply if all conditions are met
  97. pub actions: Cow<'static, [Action]>,
  98. /// Whether this is a base rule
  99. #[pyo3(get)]
  100. pub default: bool,
  101. /// Whether this is enabled by default
  102. #[pyo3(get)]
  103. pub default_enabled: bool,
  104. }
  105. #[pymethods]
  106. impl PushRule {
  107. #[staticmethod]
  108. pub fn from_db(
  109. rule_id: String,
  110. priority_class: i32,
  111. conditions: &str,
  112. actions: &str,
  113. ) -> Result<PushRule, Error> {
  114. let conditions = serde_json::from_str(conditions).context("parsing conditions")?;
  115. let actions = serde_json::from_str(actions).context("parsing actions")?;
  116. Ok(PushRule {
  117. rule_id: Cow::Owned(rule_id),
  118. priority_class,
  119. conditions,
  120. actions,
  121. default: false,
  122. default_enabled: true,
  123. })
  124. }
  125. #[getter]
  126. fn rule_id(&self) -> &str {
  127. &self.rule_id
  128. }
  129. #[getter]
  130. fn actions(&self) -> Vec<Action> {
  131. self.actions.clone().into_owned()
  132. }
  133. #[getter]
  134. fn conditions(&self) -> Vec<Condition> {
  135. self.conditions.clone().into_owned()
  136. }
  137. fn __repr__(&self) -> String {
  138. format!(
  139. "<PushRule rule_id={}, conditions={:?}, actions={:?}>",
  140. self.rule_id, self.conditions, self.actions
  141. )
  142. }
  143. }
  144. /// The "action" Synapse should perform for a matching push rule.
  145. #[derive(Debug, Clone, PartialEq, Eq)]
  146. pub enum Action {
  147. DontNotify,
  148. Notify,
  149. Coalesce,
  150. SetTweak(SetTweak),
  151. // An unrecognized custom action.
  152. Unknown(Value),
  153. }
  154. impl IntoPy<PyObject> for Action {
  155. fn into_py(self, py: Python<'_>) -> PyObject {
  156. // When we pass the `Action` struct to Python we want it to be converted
  157. // to a dict. We use `pythonize`, which converts the struct using the
  158. // `serde` serialization.
  159. pythonize(py, &self).expect("valid action")
  160. }
  161. }
  162. /// The body of a `SetTweak` push action.
  163. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
  164. pub struct SetTweak {
  165. set_tweak: Cow<'static, str>,
  166. #[serde(skip_serializing_if = "Option::is_none")]
  167. value: Option<TweakValue>,
  168. // This picks up any other fields that may have been added by clients.
  169. // These get added when we convert the `Action` to a python object.
  170. #[serde(flatten)]
  171. other_keys: Value,
  172. }
  173. /// The value of a `set_tweak`.
  174. ///
  175. /// We need this (rather than using `TweakValue` directly) so that we can use
  176. /// `&'static str` in the value when defining the constant base rules.
  177. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
  178. #[serde(untagged)]
  179. pub enum TweakValue {
  180. String(Cow<'static, str>),
  181. Other(Value),
  182. }
  183. impl Serialize for Action {
  184. fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
  185. where
  186. S: serde::Serializer,
  187. {
  188. match self {
  189. Action::DontNotify => serializer.serialize_str("dont_notify"),
  190. Action::Notify => serializer.serialize_str("notify"),
  191. Action::Coalesce => serializer.serialize_str("coalesce"),
  192. Action::SetTweak(tweak) => tweak.serialize(serializer),
  193. Action::Unknown(value) => value.serialize(serializer),
  194. }
  195. }
  196. }
  197. /// Simple helper class for deserializing Action from JSON.
  198. #[derive(Deserialize)]
  199. #[serde(untagged)]
  200. enum ActionDeserializeHelper {
  201. Str(String),
  202. SetTweak(SetTweak),
  203. Unknown(Value),
  204. }
  205. impl<'de> Deserialize<'de> for Action {
  206. fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
  207. where
  208. D: serde::Deserializer<'de>,
  209. {
  210. let helper: ActionDeserializeHelper = Deserialize::deserialize(deserializer)?;
  211. match helper {
  212. ActionDeserializeHelper::Str(s) => match &*s {
  213. "dont_notify" => Ok(Action::DontNotify),
  214. "notify" => Ok(Action::Notify),
  215. "coalesce" => Ok(Action::Coalesce),
  216. _ => Err(D::Error::custom("unrecognized action")),
  217. },
  218. ActionDeserializeHelper::SetTweak(set_tweak) => Ok(Action::SetTweak(set_tweak)),
  219. ActionDeserializeHelper::Unknown(value) => Ok(Action::Unknown(value)),
  220. }
  221. }
  222. }
  223. /// A condition used in push rules to match against an event.
  224. ///
  225. /// We need this split as `serde` doesn't give us the ability to have a
  226. /// "catchall" variant in tagged enums.
  227. #[derive(Serialize, Deserialize, Debug, Clone)]
  228. #[serde(untagged)]
  229. pub enum Condition {
  230. /// A recognized condition that we can match against
  231. Known(KnownCondition),
  232. /// An unrecognized condition that we ignore.
  233. Unknown(Value),
  234. }
  235. /// The set of "known" conditions that we can handle.
  236. #[derive(Serialize, Deserialize, Debug, Clone)]
  237. #[serde(rename_all = "snake_case")]
  238. #[serde(tag = "kind")]
  239. pub enum KnownCondition {
  240. EventMatch(EventMatchCondition),
  241. ContainsDisplayName,
  242. RoomMemberCount {
  243. #[serde(skip_serializing_if = "Option::is_none")]
  244. is: Option<Cow<'static, str>>,
  245. },
  246. SenderNotificationPermission {
  247. key: Cow<'static, str>,
  248. },
  249. #[serde(rename = "org.matrix.msc3772.relation_match")]
  250. RelationMatch {
  251. rel_type: Cow<'static, str>,
  252. #[serde(skip_serializing_if = "Option::is_none", rename = "type")]
  253. event_type_pattern: Option<Cow<'static, str>>,
  254. #[serde(skip_serializing_if = "Option::is_none")]
  255. sender: Option<Cow<'static, str>>,
  256. #[serde(skip_serializing_if = "Option::is_none")]
  257. sender_type: Option<Cow<'static, str>>,
  258. },
  259. }
  260. impl IntoPy<PyObject> for Condition {
  261. fn into_py(self, py: Python<'_>) -> PyObject {
  262. pythonize(py, &self).expect("valid condition")
  263. }
  264. }
  265. impl<'source> FromPyObject<'source> for Condition {
  266. fn extract(ob: &'source PyAny) -> PyResult<Self> {
  267. Ok(depythonize(ob)?)
  268. }
  269. }
  270. /// The body of a [`Condition::EventMatch`]
  271. #[derive(Serialize, Deserialize, Debug, Clone)]
  272. pub struct EventMatchCondition {
  273. pub key: Cow<'static, str>,
  274. #[serde(skip_serializing_if = "Option::is_none")]
  275. pub pattern: Option<Cow<'static, str>>,
  276. #[serde(skip_serializing_if = "Option::is_none")]
  277. pub pattern_type: Option<Cow<'static, str>>,
  278. }
  279. /// The collection of push rules for a user.
  280. #[derive(Debug, Clone, Default)]
  281. #[pyclass(frozen)]
  282. pub struct PushRules {
  283. /// Custom push rules that override a base rule.
  284. overridden_base_rules: HashMap<Cow<'static, str>, PushRule>,
  285. /// Custom rules that come between the prepend/append override base rules.
  286. override_rules: Vec<PushRule>,
  287. /// Custom rules that come before the base content rules.
  288. content: Vec<PushRule>,
  289. /// Custom rules that come before the base room rules.
  290. room: Vec<PushRule>,
  291. /// Custom rules that come before the base sender rules.
  292. sender: Vec<PushRule>,
  293. /// Custom rules that come before the base underride rules.
  294. underride: Vec<PushRule>,
  295. }
  296. #[pymethods]
  297. impl PushRules {
  298. #[new]
  299. pub fn new(rules: Vec<PushRule>) -> PushRules {
  300. let mut push_rules: PushRules = Default::default();
  301. for rule in rules {
  302. if let Some(&o) = base_rules::BASE_RULES_BY_ID.get(&*rule.rule_id) {
  303. push_rules.overridden_base_rules.insert(
  304. rule.rule_id.clone(),
  305. PushRule {
  306. actions: rule.actions.clone(),
  307. ..o.clone()
  308. },
  309. );
  310. continue;
  311. }
  312. match rule.priority_class {
  313. 5 => push_rules.override_rules.push(rule),
  314. 4 => push_rules.content.push(rule),
  315. 3 => push_rules.room.push(rule),
  316. 2 => push_rules.sender.push(rule),
  317. 1 => push_rules.underride.push(rule),
  318. _ => {
  319. warn!(
  320. "Unrecognized priority class for rule {}: {}",
  321. rule.rule_id, rule.priority_class
  322. );
  323. }
  324. }
  325. }
  326. push_rules
  327. }
  328. /// Returns the list of all rules, including base rules, in the order they
  329. /// should be executed in.
  330. fn rules(&self) -> Vec<PushRule> {
  331. self.iter().cloned().collect()
  332. }
  333. }
  334. impl PushRules {
  335. /// Iterates over all the rules, including base rules, in the order they
  336. /// should be executed in.
  337. pub fn iter(&self) -> impl Iterator<Item = &PushRule> {
  338. base_rules::BASE_PREPEND_OVERRIDE_RULES
  339. .iter()
  340. .chain(self.override_rules.iter())
  341. .chain(base_rules::BASE_APPEND_OVERRIDE_RULES.iter())
  342. .chain(self.content.iter())
  343. .chain(base_rules::BASE_APPEND_CONTENT_RULES.iter())
  344. .chain(self.room.iter())
  345. .chain(self.sender.iter())
  346. .chain(self.underride.iter())
  347. .chain(base_rules::BASE_APPEND_UNDERRIDE_RULES.iter())
  348. .map(|rule| {
  349. self.overridden_base_rules
  350. .get(&*rule.rule_id)
  351. .unwrap_or(rule)
  352. })
  353. }
  354. }
  355. /// A wrapper around `PushRules` that checks the enabled state of rules and
  356. /// filters out disabled experimental rules.
  357. #[derive(Debug, Clone, Default)]
  358. #[pyclass(frozen)]
  359. pub struct FilteredPushRules {
  360. push_rules: PushRules,
  361. enabled_map: BTreeMap<String, bool>,
  362. msc3786_enabled: bool,
  363. msc3772_enabled: bool,
  364. }
  365. #[pymethods]
  366. impl FilteredPushRules {
  367. #[new]
  368. pub fn py_new(
  369. push_rules: PushRules,
  370. enabled_map: BTreeMap<String, bool>,
  371. msc3786_enabled: bool,
  372. msc3772_enabled: bool,
  373. ) -> Self {
  374. Self {
  375. push_rules,
  376. enabled_map,
  377. msc3786_enabled,
  378. msc3772_enabled,
  379. }
  380. }
  381. /// Returns the list of all rules and their enabled state, including base
  382. /// rules, in the order they should be executed in.
  383. fn rules(&self) -> Vec<(PushRule, bool)> {
  384. self.iter().map(|(r, e)| (r.clone(), e)).collect()
  385. }
  386. }
  387. impl FilteredPushRules {
  388. /// Iterates over all the rules and their enabled state, including base
  389. /// rules, in the order they should be executed in.
  390. fn iter(&self) -> impl Iterator<Item = (&PushRule, bool)> {
  391. self.push_rules
  392. .iter()
  393. .filter(|rule| {
  394. // Ignore disabled experimental push rules
  395. if !self.msc3786_enabled
  396. && rule.rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
  397. {
  398. return false;
  399. }
  400. if !self.msc3772_enabled
  401. && rule.rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
  402. {
  403. return false;
  404. }
  405. true
  406. })
  407. .map(|r| {
  408. let enabled = *self
  409. .enabled_map
  410. .get(&*r.rule_id)
  411. .unwrap_or(&r.default_enabled);
  412. (r, enabled)
  413. })
  414. }
  415. }
  416. #[test]
  417. fn test_serialize_condition() {
  418. let condition = Condition::Known(KnownCondition::EventMatch(EventMatchCondition {
  419. key: "content.body".into(),
  420. pattern: Some("coffee".into()),
  421. pattern_type: None,
  422. }));
  423. let json = serde_json::to_string(&condition).unwrap();
  424. assert_eq!(
  425. json,
  426. r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#
  427. )
  428. }
  429. #[test]
  430. fn test_deserialize_condition() {
  431. let json = r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#;
  432. let _: Condition = serde_json::from_str(json).unwrap();
  433. }
  434. #[test]
  435. fn test_deserialize_custom_condition() {
  436. let json = r#"{"kind":"custom_tag"}"#;
  437. let condition: Condition = serde_json::from_str(json).unwrap();
  438. assert!(matches!(condition, Condition::Unknown(_)));
  439. let new_json = serde_json::to_string(&condition).unwrap();
  440. assert_eq!(json, new_json);
  441. }
  442. #[test]
  443. fn test_deserialize_action() {
  444. let _: Action = serde_json::from_str(r#""notify""#).unwrap();
  445. let _: Action = serde_json::from_str(r#""dont_notify""#).unwrap();
  446. let _: Action = serde_json::from_str(r#""coalesce""#).unwrap();
  447. let _: Action = serde_json::from_str(r#"{"set_tweak": "highlight"}"#).unwrap();
  448. }
  449. #[test]
  450. fn test_custom_action() {
  451. let json = r#"{"some_custom":"action_fields"}"#;
  452. let action: Action = serde_json::from_str(json).unwrap();
  453. assert!(matches!(action, Action::Unknown(_)));
  454. let new_json = serde_json::to_string(&action).unwrap();
  455. assert_eq!(json, new_json);
  456. }