You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

774 lines
26 KiB

  1. // Copyright 2022 The Matrix.org Foundation C.I.C.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //! An implementation of Matrix push rules.
  15. //!
  16. //! The `Cow<_>` type is used extensively within this module to allow creating
  17. //! the base rules as constants (in Rust constants can't require explicit
  18. //! allocation atm).
  19. //!
  20. //! ---
  21. //!
  22. //! Push rules is the system used to determine which events trigger a push (and a
  23. //! bump in notification counts).
  24. //!
  25. //! This consists of a list of "push rules" for each user, where a push rule is a
  26. //! pair of "conditions" and "actions". When a user receives an event Synapse
  27. //! iterates over the list of push rules until it finds one where all the conditions
  28. //! match the event, at which point "actions" describe the outcome (e.g. notify,
  29. //! highlight, etc).
  30. //!
  31. //! Push rules are split up into 5 different "kinds" (aka "priority classes"), which
  32. //! are run in order:
  33. //! 1. Override — highest priority rules, e.g. always ignore notices
  34. //! 2. Content — content specific rules, e.g. @ notifications
  35. //! 3. Room — per room rules, e.g. enable/disable notifications for all messages
  36. //! in a room
  37. //! 4. Sender — per sender rules, e.g. never notify for messages from a given
  38. //! user
  39. //! 5. Underride — the lowest priority "default" rules, e.g. notify for every
  40. //! message.
  41. //!
  42. //! The set of "base rules" are the list of rules that every user has by default. A
  43. //! user can modify their copy of the push rules in one of three ways:
  44. //! 1. Adding a new push rule of a certain kind
  45. //! 2. Changing the actions of a base rule
  46. //! 3. Enabling/disabling a base rule.
  47. //!
  48. //! The base rules are split into whether they come before or after a particular
  49. //! kind, so the order of push rule evaluation would be: base rules for before
  50. //! "override" kind, user defined "override" rules, base rules after "override"
  51. //! kind, etc, etc.
  52. use std::borrow::Cow;
  53. use std::collections::{BTreeMap, HashMap, HashSet};
  54. use anyhow::{Context, Error};
  55. use log::warn;
  56. use pyo3::exceptions::PyTypeError;
  57. use pyo3::prelude::*;
  58. use pyo3::types::{PyBool, PyList, PyLong, PyString};
  59. use pythonize::{depythonize, pythonize};
  60. use serde::de::Error as _;
  61. use serde::{Deserialize, Serialize};
  62. use serde_json::Value;
  63. use self::evaluator::PushRuleEvaluator;
  64. mod base_rules;
  65. pub mod evaluator;
  66. pub mod utils;
  67. /// Called when registering modules with python.
  68. pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
  69. let child_module = PyModule::new(py, "push")?;
  70. child_module.add_class::<PushRule>()?;
  71. child_module.add_class::<PushRules>()?;
  72. child_module.add_class::<FilteredPushRules>()?;
  73. child_module.add_class::<PushRuleEvaluator>()?;
  74. child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?;
  75. m.add_submodule(child_module)?;
  76. // We need to manually add the module to sys.modules to make `from
  77. // synapse.synapse_rust import push` work.
  78. py.import("sys")?
  79. .getattr("modules")?
  80. .set_item("synapse.synapse_rust.push", child_module)?;
  81. Ok(())
  82. }
  83. #[pyfunction]
  84. fn get_base_rule_ids() -> HashSet<&'static str> {
  85. base_rules::BASE_RULES_BY_ID.keys().copied().collect()
  86. }
  87. /// A single push rule for a user.
  88. #[derive(Debug, Clone)]
  89. #[pyclass(frozen)]
  90. pub struct PushRule {
  91. /// A unique ID for this rule
  92. pub rule_id: Cow<'static, str>,
  93. /// The "kind" of push rule this is (see `PRIORITY_CLASS_MAP` in Python)
  94. #[pyo3(get)]
  95. pub priority_class: i32,
  96. /// The conditions that must all match for actions to be applied
  97. pub conditions: Cow<'static, [Condition]>,
  98. /// The actions to apply if all conditions are met
  99. pub actions: Cow<'static, [Action]>,
  100. /// Whether this is a base rule
  101. #[pyo3(get)]
  102. pub default: bool,
  103. /// Whether this is enabled by default
  104. #[pyo3(get)]
  105. pub default_enabled: bool,
  106. }
  107. #[pymethods]
  108. impl PushRule {
  109. #[staticmethod]
  110. pub fn from_db(
  111. rule_id: String,
  112. priority_class: i32,
  113. conditions: &str,
  114. actions: &str,
  115. ) -> Result<PushRule, Error> {
  116. let conditions = serde_json::from_str(conditions).context("parsing conditions")?;
  117. let actions = serde_json::from_str(actions).context("parsing actions")?;
  118. Ok(PushRule {
  119. rule_id: Cow::Owned(rule_id),
  120. priority_class,
  121. conditions,
  122. actions,
  123. default: false,
  124. default_enabled: true,
  125. })
  126. }
  127. #[getter]
  128. fn rule_id(&self) -> &str {
  129. &self.rule_id
  130. }
  131. #[getter]
  132. fn actions(&self) -> Vec<Action> {
  133. self.actions.clone().into_owned()
  134. }
  135. #[getter]
  136. fn conditions(&self) -> Vec<Condition> {
  137. self.conditions.clone().into_owned()
  138. }
  139. fn __repr__(&self) -> String {
  140. format!(
  141. "<PushRule rule_id={}, conditions={:?}, actions={:?}>",
  142. self.rule_id, self.conditions, self.actions
  143. )
  144. }
  145. }
  146. /// The "action" Synapse should perform for a matching push rule.
  147. #[derive(Debug, Clone, PartialEq, Eq)]
  148. pub enum Action {
  149. Notify,
  150. SetTweak(SetTweak),
  151. // Legacy actions that should be understood, but are equivalent to no-ops.
  152. DontNotify,
  153. Coalesce,
  154. // An unrecognized custom action.
  155. Unknown(Value),
  156. }
  157. impl IntoPy<PyObject> for Action {
  158. fn into_py(self, py: Python<'_>) -> PyObject {
  159. // When we pass the `Action` struct to Python we want it to be converted
  160. // to a dict. We use `pythonize`, which converts the struct using the
  161. // `serde` serialization.
  162. pythonize(py, &self).expect("valid action")
  163. }
  164. }
  165. /// The body of a `SetTweak` push action.
  166. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
  167. pub struct SetTweak {
  168. set_tweak: Cow<'static, str>,
  169. #[serde(skip_serializing_if = "Option::is_none")]
  170. value: Option<TweakValue>,
  171. // This picks up any other fields that may have been added by clients.
  172. // These get added when we convert the `Action` to a python object.
  173. #[serde(flatten)]
  174. other_keys: Value,
  175. }
  176. /// The value of a `set_tweak`.
  177. ///
  178. /// We need this (rather than using `TweakValue` directly) so that we can use
  179. /// `&'static str` in the value when defining the constant base rules.
  180. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
  181. #[serde(untagged)]
  182. pub enum TweakValue {
  183. String(Cow<'static, str>),
  184. Other(Value),
  185. }
  186. impl Serialize for Action {
  187. fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
  188. where
  189. S: serde::Serializer,
  190. {
  191. match self {
  192. Action::DontNotify => serializer.serialize_str("dont_notify"),
  193. Action::Notify => serializer.serialize_str("notify"),
  194. Action::Coalesce => serializer.serialize_str("coalesce"),
  195. Action::SetTweak(tweak) => tweak.serialize(serializer),
  196. Action::Unknown(value) => value.serialize(serializer),
  197. }
  198. }
  199. }
  200. /// Simple helper class for deserializing Action from JSON.
  201. #[derive(Deserialize)]
  202. #[serde(untagged)]
  203. enum ActionDeserializeHelper {
  204. Str(String),
  205. SetTweak(SetTweak),
  206. Unknown(Value),
  207. }
  208. impl<'de> Deserialize<'de> for Action {
  209. fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
  210. where
  211. D: serde::Deserializer<'de>,
  212. {
  213. let helper: ActionDeserializeHelper = Deserialize::deserialize(deserializer)?;
  214. match helper {
  215. ActionDeserializeHelper::Str(s) => match &*s {
  216. "dont_notify" => Ok(Action::DontNotify),
  217. "notify" => Ok(Action::Notify),
  218. "coalesce" => Ok(Action::Coalesce),
  219. _ => Err(D::Error::custom("unrecognized action")),
  220. },
  221. ActionDeserializeHelper::SetTweak(set_tweak) => Ok(Action::SetTweak(set_tweak)),
  222. ActionDeserializeHelper::Unknown(value) => Ok(Action::Unknown(value)),
  223. }
  224. }
  225. }
  226. /// A simple JSON values (string, int, boolean, or null).
  227. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
  228. #[serde(untagged)]
  229. pub enum SimpleJsonValue {
  230. Str(Cow<'static, str>),
  231. Int(i64),
  232. Bool(bool),
  233. Null,
  234. }
  235. impl<'source> FromPyObject<'source> for SimpleJsonValue {
  236. fn extract(ob: &'source PyAny) -> PyResult<Self> {
  237. if let Ok(s) = <PyString as pyo3::PyTryFrom>::try_from(ob) {
  238. Ok(SimpleJsonValue::Str(Cow::Owned(s.to_string())))
  239. // A bool *is* an int, ensure we try bool first.
  240. } else if let Ok(b) = <PyBool as pyo3::PyTryFrom>::try_from(ob) {
  241. Ok(SimpleJsonValue::Bool(b.extract()?))
  242. } else if let Ok(i) = <PyLong as pyo3::PyTryFrom>::try_from(ob) {
  243. Ok(SimpleJsonValue::Int(i.extract()?))
  244. } else if ob.is_none() {
  245. Ok(SimpleJsonValue::Null)
  246. } else {
  247. Err(PyTypeError::new_err(format!(
  248. "Can't convert from {} to SimpleJsonValue",
  249. ob.get_type().name()?
  250. )))
  251. }
  252. }
  253. }
  254. /// A JSON values (list, string, int, boolean, or null).
  255. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
  256. #[serde(untagged)]
  257. pub enum JsonValue {
  258. Array(Vec<SimpleJsonValue>),
  259. Value(SimpleJsonValue),
  260. }
  261. impl<'source> FromPyObject<'source> for JsonValue {
  262. fn extract(ob: &'source PyAny) -> PyResult<Self> {
  263. if let Ok(l) = <PyList as pyo3::PyTryFrom>::try_from(ob) {
  264. match l.iter().map(SimpleJsonValue::extract).collect() {
  265. Ok(a) => Ok(JsonValue::Array(a)),
  266. Err(e) => Err(PyTypeError::new_err(format!(
  267. "Can't convert to JsonValue::Array: {}",
  268. e
  269. ))),
  270. }
  271. } else if let Ok(v) = SimpleJsonValue::extract(ob) {
  272. Ok(JsonValue::Value(v))
  273. } else {
  274. Err(PyTypeError::new_err(format!(
  275. "Can't convert from {} to JsonValue",
  276. ob.get_type().name()?
  277. )))
  278. }
  279. }
  280. }
  281. /// A condition used in push rules to match against an event.
  282. ///
  283. /// We need this split as `serde` doesn't give us the ability to have a
  284. /// "catchall" variant in tagged enums.
  285. #[derive(Serialize, Deserialize, Debug, Clone)]
  286. #[serde(untagged)]
  287. pub enum Condition {
  288. /// A recognized condition that we can match against
  289. Known(KnownCondition),
  290. /// An unrecognized condition that we ignore.
  291. Unknown(Value),
  292. }
  293. /// The set of "known" conditions that we can handle.
  294. #[derive(Serialize, Deserialize, Debug, Clone)]
  295. #[serde(rename_all = "snake_case")]
  296. #[serde(tag = "kind")]
  297. pub enum KnownCondition {
  298. EventMatch(EventMatchCondition),
  299. // Identical to event_match but gives predefined patterns. Cannot be added by users.
  300. #[serde(skip_deserializing, rename = "event_match")]
  301. EventMatchType(EventMatchTypeCondition),
  302. EventPropertyIs(EventPropertyIsCondition),
  303. #[serde(rename = "im.nheko.msc3664.related_event_match")]
  304. RelatedEventMatch(RelatedEventMatchCondition),
  305. // Identical to related_event_match but gives predefined patterns. Cannot be added by users.
  306. #[serde(skip_deserializing, rename = "im.nheko.msc3664.related_event_match")]
  307. RelatedEventMatchType(RelatedEventMatchTypeCondition),
  308. EventPropertyContains(EventPropertyIsCondition),
  309. // Identical to exact_event_property_contains but gives predefined patterns. Cannot be added by users.
  310. #[serde(skip_deserializing, rename = "event_property_contains")]
  311. ExactEventPropertyContainsType(EventPropertyIsTypeCondition),
  312. ContainsDisplayName,
  313. RoomMemberCount {
  314. #[serde(skip_serializing_if = "Option::is_none")]
  315. is: Option<Cow<'static, str>>,
  316. },
  317. SenderNotificationPermission {
  318. key: Cow<'static, str>,
  319. },
  320. #[serde(rename = "org.matrix.msc3931.room_version_supports")]
  321. RoomVersionSupports {
  322. feature: Cow<'static, str>,
  323. },
  324. }
  325. impl IntoPy<PyObject> for Condition {
  326. fn into_py(self, py: Python<'_>) -> PyObject {
  327. pythonize(py, &self).expect("valid condition")
  328. }
  329. }
  330. impl<'source> FromPyObject<'source> for Condition {
  331. fn extract(ob: &'source PyAny) -> PyResult<Self> {
  332. Ok(depythonize(ob)?)
  333. }
  334. }
  335. /// The body of a [`Condition::EventMatch`] with a pattern.
  336. #[derive(Serialize, Deserialize, Debug, Clone)]
  337. pub struct EventMatchCondition {
  338. pub key: Cow<'static, str>,
  339. pub pattern: Cow<'static, str>,
  340. }
  341. #[derive(Serialize, Debug, Clone)]
  342. #[serde(rename_all = "snake_case")]
  343. pub enum EventMatchPatternType {
  344. UserId,
  345. UserLocalpart,
  346. }
  347. /// The body of a [`Condition::EventMatch`] that uses user_id or user_localpart as a pattern.
  348. #[derive(Serialize, Debug, Clone)]
  349. pub struct EventMatchTypeCondition {
  350. pub key: Cow<'static, str>,
  351. // During serialization, the pattern_type property gets replaced with a
  352. // pattern property of the correct value in synapse.push.clientformat.format_push_rules_for_user.
  353. pub pattern_type: Cow<'static, EventMatchPatternType>,
  354. }
  355. /// The body of a [`Condition::EventPropertyIs`]
  356. #[derive(Serialize, Deserialize, Debug, Clone)]
  357. pub struct EventPropertyIsCondition {
  358. pub key: Cow<'static, str>,
  359. pub value: Cow<'static, SimpleJsonValue>,
  360. }
  361. /// The body of a [`Condition::EventPropertyIs`] that uses user_id or user_localpart as a pattern.
  362. #[derive(Serialize, Debug, Clone)]
  363. pub struct EventPropertyIsTypeCondition {
  364. pub key: Cow<'static, str>,
  365. // During serialization, the pattern_type property gets replaced with a
  366. // pattern property of the correct value in synapse.push.clientformat.format_push_rules_for_user.
  367. pub value_type: Cow<'static, EventMatchPatternType>,
  368. }
  369. /// The body of a [`Condition::RelatedEventMatch`]
  370. #[derive(Serialize, Deserialize, Debug, Clone)]
  371. pub struct RelatedEventMatchCondition {
  372. #[serde(skip_serializing_if = "Option::is_none")]
  373. pub key: Option<Cow<'static, str>>,
  374. #[serde(skip_serializing_if = "Option::is_none")]
  375. pub pattern: Option<Cow<'static, str>>,
  376. pub rel_type: Cow<'static, str>,
  377. #[serde(skip_serializing_if = "Option::is_none")]
  378. pub include_fallbacks: Option<bool>,
  379. }
  380. /// The body of a [`Condition::RelatedEventMatch`] that uses user_id or user_localpart as a pattern.
  381. #[derive(Serialize, Debug, Clone)]
  382. pub struct RelatedEventMatchTypeCondition {
  383. // This is only used if pattern_type exists (and thus key must exist), so is
  384. // a bit simpler than RelatedEventMatchCondition.
  385. pub key: Cow<'static, str>,
  386. pub pattern_type: Cow<'static, EventMatchPatternType>,
  387. pub rel_type: Cow<'static, str>,
  388. #[serde(skip_serializing_if = "Option::is_none")]
  389. pub include_fallbacks: Option<bool>,
  390. }
  391. /// The collection of push rules for a user.
  392. #[derive(Debug, Clone, Default)]
  393. #[pyclass(frozen)]
  394. pub struct PushRules {
  395. /// Custom push rules that override a base rule.
  396. overridden_base_rules: HashMap<Cow<'static, str>, PushRule>,
  397. /// Custom rules that come between the prepend/append override base rules.
  398. override_rules: Vec<PushRule>,
  399. /// Custom rules that come before the base content rules.
  400. content: Vec<PushRule>,
  401. /// Custom rules that come before the base room rules.
  402. room: Vec<PushRule>,
  403. /// Custom rules that come before the base sender rules.
  404. sender: Vec<PushRule>,
  405. /// Custom rules that come before the base underride rules.
  406. underride: Vec<PushRule>,
  407. }
  408. #[pymethods]
  409. impl PushRules {
  410. #[new]
  411. pub fn new(rules: Vec<PushRule>) -> PushRules {
  412. let mut push_rules: PushRules = Default::default();
  413. for rule in rules {
  414. if let Some(&o) = base_rules::BASE_RULES_BY_ID.get(&*rule.rule_id) {
  415. push_rules.overridden_base_rules.insert(
  416. rule.rule_id.clone(),
  417. PushRule {
  418. actions: rule.actions.clone(),
  419. ..o.clone()
  420. },
  421. );
  422. continue;
  423. }
  424. match rule.priority_class {
  425. 5 => push_rules.override_rules.push(rule),
  426. 4 => push_rules.content.push(rule),
  427. 3 => push_rules.room.push(rule),
  428. 2 => push_rules.sender.push(rule),
  429. 1 => push_rules.underride.push(rule),
  430. _ => {
  431. warn!(
  432. "Unrecognized priority class for rule {}: {}",
  433. rule.rule_id, rule.priority_class
  434. );
  435. }
  436. }
  437. }
  438. push_rules
  439. }
  440. /// Returns the list of all rules, including base rules, in the order they
  441. /// should be executed in.
  442. fn rules(&self) -> Vec<PushRule> {
  443. self.iter().cloned().collect()
  444. }
  445. }
  446. impl PushRules {
  447. /// Iterates over all the rules, including base rules, in the order they
  448. /// should be executed in.
  449. pub fn iter(&self) -> impl Iterator<Item = &PushRule> {
  450. base_rules::BASE_PREPEND_OVERRIDE_RULES
  451. .iter()
  452. .chain(self.override_rules.iter())
  453. .chain(base_rules::BASE_APPEND_OVERRIDE_RULES.iter())
  454. .chain(self.content.iter())
  455. .chain(base_rules::BASE_APPEND_CONTENT_RULES.iter())
  456. .chain(self.room.iter())
  457. .chain(self.sender.iter())
  458. .chain(self.underride.iter())
  459. .chain(base_rules::BASE_APPEND_UNDERRIDE_RULES.iter())
  460. .map(|rule| {
  461. self.overridden_base_rules
  462. .get(&*rule.rule_id)
  463. .unwrap_or(rule)
  464. })
  465. }
  466. }
  467. /// A wrapper around `PushRules` that checks the enabled state of rules and
  468. /// filters out disabled experimental rules.
  469. #[derive(Debug, Clone, Default)]
  470. #[pyclass(frozen)]
  471. pub struct FilteredPushRules {
  472. push_rules: PushRules,
  473. enabled_map: BTreeMap<String, bool>,
  474. msc1767_enabled: bool,
  475. msc3381_polls_enabled: bool,
  476. msc3664_enabled: bool,
  477. msc4028_push_encrypted_events: bool,
  478. }
  479. #[pymethods]
  480. impl FilteredPushRules {
  481. #[new]
  482. pub fn py_new(
  483. push_rules: PushRules,
  484. enabled_map: BTreeMap<String, bool>,
  485. msc1767_enabled: bool,
  486. msc3381_polls_enabled: bool,
  487. msc3664_enabled: bool,
  488. msc4028_push_encrypted_events: bool,
  489. ) -> Self {
  490. Self {
  491. push_rules,
  492. enabled_map,
  493. msc1767_enabled,
  494. msc3381_polls_enabled,
  495. msc3664_enabled,
  496. msc4028_push_encrypted_events,
  497. }
  498. }
  499. /// Returns the list of all rules and their enabled state, including base
  500. /// rules, in the order they should be executed in.
  501. fn rules(&self) -> Vec<(PushRule, bool)> {
  502. self.iter().map(|(r, e)| (r.clone(), e)).collect()
  503. }
  504. }
  505. impl FilteredPushRules {
  506. /// Iterates over all the rules and their enabled state, including base
  507. /// rules, in the order they should be executed in.
  508. fn iter(&self) -> impl Iterator<Item = (&PushRule, bool)> {
  509. self.push_rules
  510. .iter()
  511. .filter(|rule| {
  512. // Ignore disabled experimental push rules
  513. if !self.msc1767_enabled
  514. && (rule.rule_id.contains("org.matrix.msc1767")
  515. || rule.rule_id.contains("org.matrix.msc3933"))
  516. {
  517. return false;
  518. }
  519. if !self.msc3664_enabled
  520. && rule.rule_id == "global/override/.im.nheko.msc3664.reply"
  521. {
  522. return false;
  523. }
  524. if !self.msc3381_polls_enabled && rule.rule_id.contains("org.matrix.msc3930") {
  525. return false;
  526. }
  527. if !self.msc4028_push_encrypted_events
  528. && rule.rule_id == "global/override/.org.matrix.msc4028.encrypted_event"
  529. {
  530. return false;
  531. }
  532. true
  533. })
  534. .map(|r| {
  535. let enabled = *self
  536. .enabled_map
  537. .get(&*r.rule_id)
  538. .unwrap_or(&r.default_enabled);
  539. (r, enabled)
  540. })
  541. }
  542. }
  543. #[test]
  544. fn test_serialize_condition() {
  545. let condition = Condition::Known(KnownCondition::EventMatch(EventMatchCondition {
  546. key: "content.body".into(),
  547. pattern: "coffee".into(),
  548. }));
  549. let json = serde_json::to_string(&condition).unwrap();
  550. assert_eq!(
  551. json,
  552. r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#
  553. )
  554. }
  555. #[test]
  556. fn test_deserialize_condition() {
  557. let json = r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#;
  558. let condition: Condition = serde_json::from_str(json).unwrap();
  559. assert!(matches!(
  560. condition,
  561. Condition::Known(KnownCondition::EventMatch(_))
  562. ));
  563. }
  564. #[test]
  565. fn test_serialize_event_match_condition_with_pattern_type() {
  566. let condition = Condition::Known(KnownCondition::EventMatchType(EventMatchTypeCondition {
  567. key: "content.body".into(),
  568. pattern_type: Cow::Owned(EventMatchPatternType::UserId),
  569. }));
  570. let json = serde_json::to_string(&condition).unwrap();
  571. assert_eq!(
  572. json,
  573. r#"{"kind":"event_match","key":"content.body","pattern_type":"user_id"}"#
  574. )
  575. }
  576. #[test]
  577. fn test_cannot_deserialize_event_match_condition_with_pattern_type() {
  578. let json = r#"{"kind":"event_match","key":"content.body","pattern_type":"user_id"}"#;
  579. let condition: Condition = serde_json::from_str(json).unwrap();
  580. assert!(matches!(condition, Condition::Unknown(_)));
  581. }
  582. #[test]
  583. fn test_deserialize_unstable_msc3664_condition() {
  584. let json = r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern":"coffee","rel_type":"m.in_reply_to"}"#;
  585. let condition: Condition = serde_json::from_str(json).unwrap();
  586. assert!(matches!(
  587. condition,
  588. Condition::Known(KnownCondition::RelatedEventMatch(_))
  589. ));
  590. }
  591. #[test]
  592. fn test_serialize_unstable_msc3664_condition_with_pattern_type() {
  593. let condition = Condition::Known(KnownCondition::RelatedEventMatchType(
  594. RelatedEventMatchTypeCondition {
  595. key: "content.body".into(),
  596. pattern_type: Cow::Owned(EventMatchPatternType::UserId),
  597. rel_type: "m.in_reply_to".into(),
  598. include_fallbacks: Some(true),
  599. },
  600. ));
  601. let json = serde_json::to_string(&condition).unwrap();
  602. assert_eq!(
  603. json,
  604. r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern_type":"user_id","rel_type":"m.in_reply_to","include_fallbacks":true}"#
  605. )
  606. }
  607. #[test]
  608. fn test_cannot_deserialize_unstable_msc3664_condition_with_pattern_type() {
  609. let json = r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern_type":"user_id","rel_type":"m.in_reply_to"}"#;
  610. let condition: Condition = serde_json::from_str(json).unwrap();
  611. // Since pattern is optional on RelatedEventMatch it deserializes it to that
  612. // instead of RelatedEventMatchType.
  613. assert!(matches!(
  614. condition,
  615. Condition::Known(KnownCondition::RelatedEventMatch(_))
  616. ));
  617. }
  618. #[test]
  619. fn test_deserialize_unstable_msc3931_condition() {
  620. let json =
  621. r#"{"kind":"org.matrix.msc3931.room_version_supports","feature":"org.example.feature"}"#;
  622. let condition: Condition = serde_json::from_str(json).unwrap();
  623. assert!(matches!(
  624. condition,
  625. Condition::Known(KnownCondition::RoomVersionSupports { feature: _ })
  626. ));
  627. }
  628. #[test]
  629. fn test_deserialize_event_property_is_condition() {
  630. // A string condition should work.
  631. let json = r#"{"kind":"event_property_is","key":"content.value","value":"foo"}"#;
  632. let condition: Condition = serde_json::from_str(json).unwrap();
  633. assert!(matches!(
  634. condition,
  635. Condition::Known(KnownCondition::EventPropertyIs(_))
  636. ));
  637. // A boolean condition should work.
  638. let json = r#"{"kind":"event_property_is","key":"content.value","value":true}"#;
  639. let condition: Condition = serde_json::from_str(json).unwrap();
  640. assert!(matches!(
  641. condition,
  642. Condition::Known(KnownCondition::EventPropertyIs(_))
  643. ));
  644. // An integer condition should work.
  645. let json = r#"{"kind":"event_property_is","key":"content.value","value":1}"#;
  646. let condition: Condition = serde_json::from_str(json).unwrap();
  647. assert!(matches!(
  648. condition,
  649. Condition::Known(KnownCondition::EventPropertyIs(_))
  650. ));
  651. // A null condition should work
  652. let json = r#"{"kind":"event_property_is","key":"content.value","value":null}"#;
  653. let condition: Condition = serde_json::from_str(json).unwrap();
  654. assert!(matches!(
  655. condition,
  656. Condition::Known(KnownCondition::EventPropertyIs(_))
  657. ));
  658. }
  659. #[test]
  660. fn test_deserialize_custom_condition() {
  661. let json = r#"{"kind":"custom_tag"}"#;
  662. let condition: Condition = serde_json::from_str(json).unwrap();
  663. assert!(matches!(condition, Condition::Unknown(_)));
  664. let new_json = serde_json::to_string(&condition).unwrap();
  665. assert_eq!(json, new_json);
  666. }
  667. #[test]
  668. fn test_deserialize_action() {
  669. let _: Action = serde_json::from_str(r#""notify""#).unwrap();
  670. let _: Action = serde_json::from_str(r#""dont_notify""#).unwrap();
  671. let _: Action = serde_json::from_str(r#""coalesce""#).unwrap();
  672. let _: Action = serde_json::from_str(r#"{"set_tweak": "highlight"}"#).unwrap();
  673. }
  674. #[test]
  675. fn test_custom_action() {
  676. let json = r#"{"some_custom":"action_fields"}"#;
  677. let action: Action = serde_json::from_str(json).unwrap();
  678. assert!(matches!(action, Action::Unknown(_)));
  679. let new_json = serde_json::to_string(&action).unwrap();
  680. assert_eq!(json, new_json);
  681. }