Browse Source

Support for MSC3758: exact_event_match push condition (#14964)

This specifies to search for an exact value match, instead of
string globbing. It only works across non-compound JSON values
(null, boolean, integer, and strings).
tags/v1.78.0rc1
Patrick Cloke 1 year ago
committed by GitHub
parent
commit
14be78d492
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 356 additions and 41 deletions
  1. +1
    -0
      changelog.d/14964.feature
  2. +53
    -12
      rust/benches/evaluator.rs
  3. +54
    -15
      rust/src/push/evaluator.rs
  4. +83
    -0
      rust/src/push/mod.rs
  5. +4
    -3
      stubs/synapse/synapse_rust/push.pyi
  6. +5
    -0
      synapse/config/experimental.py
  7. +11
    -7
      synapse/push/bulk_push_rule_evaluator.py
  8. +2
    -0
      synapse/types/__init__.py
  9. +143
    -4
      tests/push/test_push_rule_evaluator.py

+ 1
- 0
changelog.d/14964.feature View File

@@ -0,0 +1 @@
Implement the experimental `exact_event_match` push rule condition from [MSC3758](https://github.com/matrix-org/matrix-spec-proposals/pull/3758).

+ 53
- 12
rust/benches/evaluator.rs View File

@@ -16,6 +16,7 @@
use std::collections::BTreeSet;
use synapse::push::{
evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules,
SimpleJsonValue,
};
use test::Bencher;

@@ -24,9 +25,18 @@ extern crate test;
#[bench]
fn bench_match_exact(b: &mut Bencher) {
let flattened_keys = [
("type".to_string(), "m.text".to_string()),
("room_id".to_string(), "!room:server".to_string()),
("content.body".to_string(), "test message".to_string()),
(
"type".to_string(),
SimpleJsonValue::Str("m.text".to_string()),
),
(
"room_id".to_string(),
SimpleJsonValue::Str("!room:server".to_string()),
),
(
"content.body".to_string(),
SimpleJsonValue::Str("test message".to_string()),
),
]
.into_iter()
.collect();
@@ -43,6 +53,7 @@ fn bench_match_exact(b: &mut Bencher) {
true,
vec![],
false,
false,
)
.unwrap();

@@ -63,9 +74,18 @@ fn bench_match_exact(b: &mut Bencher) {
#[bench]
fn bench_match_word(b: &mut Bencher) {
let flattened_keys = [
("type".to_string(), "m.text".to_string()),
("room_id".to_string(), "!room:server".to_string()),
("content.body".to_string(), "test message".to_string()),
(
"type".to_string(),
SimpleJsonValue::Str("m.text".to_string()),
),
(
"room_id".to_string(),
SimpleJsonValue::Str("!room:server".to_string()),
),
(
"content.body".to_string(),
SimpleJsonValue::Str("test message".to_string()),
),
]
.into_iter()
.collect();
@@ -82,6 +102,7 @@ fn bench_match_word(b: &mut Bencher) {
true,
vec![],
false,
false,
)
.unwrap();

@@ -102,9 +123,18 @@ fn bench_match_word(b: &mut Bencher) {
#[bench]
fn bench_match_word_miss(b: &mut Bencher) {
let flattened_keys = [
("type".to_string(), "m.text".to_string()),
("room_id".to_string(), "!room:server".to_string()),
("content.body".to_string(), "test message".to_string()),
(
"type".to_string(),
SimpleJsonValue::Str("m.text".to_string()),
),
(
"room_id".to_string(),
SimpleJsonValue::Str("!room:server".to_string()),
),
(
"content.body".to_string(),
SimpleJsonValue::Str("test message".to_string()),
),
]
.into_iter()
.collect();
@@ -121,6 +151,7 @@ fn bench_match_word_miss(b: &mut Bencher) {
true,
vec![],
false,
false,
)
.unwrap();

@@ -141,9 +172,18 @@ fn bench_match_word_miss(b: &mut Bencher) {
#[bench]
fn bench_eval_message(b: &mut Bencher) {
let flattened_keys = [
("type".to_string(), "m.text".to_string()),
("room_id".to_string(), "!room:server".to_string()),
("content.body".to_string(), "test message".to_string()),
(
"type".to_string(),
SimpleJsonValue::Str("m.text".to_string()),
),
(
"room_id".to_string(),
SimpleJsonValue::Str("!room:server".to_string()),
),
(
"content.body".to_string(),
SimpleJsonValue::Str("test message".to_string()),
),
]
.into_iter()
.collect();
@@ -160,6 +200,7 @@ fn bench_eval_message(b: &mut Bencher) {
true,
vec![],
false,
false,
)
.unwrap();



+ 54
- 15
rust/src/push/evaluator.rs View File

@@ -22,8 +22,8 @@ use regex::Regex;

use super::{
utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType},
Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition,
RelatedEventMatchCondition,
Action, Condition, EventMatchCondition, ExactEventMatchCondition, FilteredPushRules,
KnownCondition, RelatedEventMatchCondition, SimpleJsonValue,
};

lazy_static! {
@@ -61,9 +61,9 @@ impl RoomVersionFeatures {
/// Allows running a set of push rules against a particular event.
#[pyclass]
pub struct PushRuleEvaluator {
/// A mapping of "flattened" keys to string values in the event, e.g.
/// A mapping of "flattened" keys to simple JSON values in the event, e.g.
/// includes things like "type" and "content.msgtype".
flattened_keys: BTreeMap<String, String>,
flattened_keys: BTreeMap<String, SimpleJsonValue>,

/// The "content.body", if any.
body: String,
@@ -87,7 +87,7 @@ pub struct PushRuleEvaluator {

/// The related events, indexed by relation type. Flattened in the same manner as
/// `flattened_keys`.
related_events_flattened: BTreeMap<String, BTreeMap<String, String>>,
related_events_flattened: BTreeMap<String, BTreeMap<String, SimpleJsonValue>>,

/// If msc3664, push rules for related events, is enabled.
related_event_match_enabled: bool,
@@ -98,6 +98,9 @@ pub struct PushRuleEvaluator {
/// If MSC3931 (room version feature flags) is enabled. Usually controlled by the same
/// flag as MSC1767 (extensible events core).
msc3931_enabled: bool,

/// If MSC3758 (exact_event_match push rule condition) is enabled.
msc3758_exact_event_match: bool,
}

#[pymethods]
@@ -106,22 +109,23 @@ impl PushRuleEvaluator {
#[allow(clippy::too_many_arguments)]
#[new]
pub fn py_new(
flattened_keys: BTreeMap<String, String>,
flattened_keys: BTreeMap<String, SimpleJsonValue>,
has_mentions: bool,
user_mentions: BTreeSet<String>,
room_mention: bool,
room_member_count: u64,
sender_power_level: Option<i64>,
notification_power_levels: BTreeMap<String, i64>,
related_events_flattened: BTreeMap<String, BTreeMap<String, String>>,
related_events_flattened: BTreeMap<String, BTreeMap<String, SimpleJsonValue>>,
related_event_match_enabled: bool,
room_version_feature_flags: Vec<String>,
msc3931_enabled: bool,
msc3758_exact_event_match: bool,
) -> Result<Self, Error> {
let body = flattened_keys
.get("content.body")
.cloned()
.unwrap_or_default();
let body = match flattened_keys.get("content.body") {
Some(SimpleJsonValue::Str(s)) => s.clone(),
_ => String::new(),
};

Ok(PushRuleEvaluator {
flattened_keys,
@@ -136,6 +140,7 @@ impl PushRuleEvaluator {
related_event_match_enabled,
room_version_feature_flags,
msc3931_enabled,
msc3758_exact_event_match,
})
}

@@ -252,6 +257,9 @@ impl PushRuleEvaluator {
KnownCondition::EventMatch(event_match) => {
self.match_event_match(event_match, user_id)?
}
KnownCondition::ExactEventMatch(exact_event_match) => {
self.match_exact_event_match(exact_event_match)?
}
KnownCondition::RelatedEventMatch(event_match) => {
self.match_related_event_match(event_match, user_id)?
}
@@ -337,7 +345,9 @@ impl PushRuleEvaluator {
return Ok(false);
};

let haystack = if let Some(haystack) = self.flattened_keys.get(&*event_match.key) {
let haystack = if let Some(SimpleJsonValue::Str(haystack)) =
self.flattened_keys.get(&*event_match.key)
{
haystack
} else {
return Ok(false);
@@ -355,6 +365,27 @@ impl PushRuleEvaluator {
compiled_pattern.is_match(haystack)
}

/// Evaluates a `exact_event_match` condition. (MSC3758)
fn match_exact_event_match(
&self,
exact_event_match: &ExactEventMatchCondition,
) -> Result<bool, Error> {
// First check if the feature is enabled.
if !self.msc3758_exact_event_match {
return Ok(false);
}

let value = &exact_event_match.value;

let haystack = if let Some(haystack) = self.flattened_keys.get(&*exact_event_match.key) {
haystack
} else {
return Ok(false);
};

Ok(haystack == &**value)
}

/// Evaluates a `related_event_match` condition. (MSC3664)
fn match_related_event_match(
&self,
@@ -410,7 +441,7 @@ impl PushRuleEvaluator {
return Ok(false);
};

let haystack = if let Some(haystack) = event.get(&**key) {
let haystack = if let Some(SimpleJsonValue::Str(haystack)) = event.get(&**key) {
haystack
} else {
return Ok(false);
@@ -455,7 +486,10 @@ impl PushRuleEvaluator {
#[test]
fn push_rule_evaluator() {
let mut flattened_keys = BTreeMap::new();
flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
flattened_keys.insert(
"content.body".to_string(),
SimpleJsonValue::Str("foo bar bob hello".to_string()),
);
let evaluator = PushRuleEvaluator::py_new(
flattened_keys,
false,
@@ -468,6 +502,7 @@ fn push_rule_evaluator() {
true,
vec![],
true,
true,
)
.unwrap();

@@ -482,7 +517,10 @@ fn test_requires_room_version_supports_condition() {
use crate::push::{PushRule, PushRules};

let mut flattened_keys = BTreeMap::new();
flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
flattened_keys.insert(
"content.body".to_string(),
SimpleJsonValue::Str("foo bar bob hello".to_string()),
);
let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()];
let evaluator = PushRuleEvaluator::py_new(
flattened_keys,
@@ -496,6 +534,7 @@ fn test_requires_room_version_supports_condition() {
false,
flags,
true,
true,
)
.unwrap();



+ 83
- 0
rust/src/push/mod.rs View File

@@ -56,7 +56,9 @@ use std::collections::{BTreeMap, HashMap, HashSet};

use anyhow::{Context, Error};
use log::warn;
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyLong, PyString};
use pythonize::{depythonize, pythonize};
use serde::de::Error as _;
use serde::{Deserialize, Serialize};
@@ -248,6 +250,36 @@ impl<'de> Deserialize<'de> for Action {
}
}

/// A simple JSON values (string, int, boolean, or null).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(untagged)]
pub enum SimpleJsonValue {
Str(String),
Int(i64),
Bool(bool),
Null,
}

impl<'source> FromPyObject<'source> for SimpleJsonValue {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if let Ok(s) = <PyString as pyo3::PyTryFrom>::try_from(ob) {
Ok(SimpleJsonValue::Str(s.to_string()))
// A bool *is* an int, ensure we try bool first.
} else if let Ok(b) = <PyBool as pyo3::PyTryFrom>::try_from(ob) {
Ok(SimpleJsonValue::Bool(b.extract()?))
} else if let Ok(i) = <PyLong as pyo3::PyTryFrom>::try_from(ob) {
Ok(SimpleJsonValue::Int(i.extract()?))
} else if ob.is_none() {
Ok(SimpleJsonValue::Null)
} else {
Err(PyTypeError::new_err(format!(
"Can't convert from {} to SimpleJsonValue",
ob.get_type().name()?
)))
}
}
}

/// A condition used in push rules to match against an event.
///
/// We need this split as `serde` doesn't give us the ability to have a
@@ -267,6 +299,8 @@ pub enum Condition {
#[serde(tag = "kind")]
pub enum KnownCondition {
EventMatch(EventMatchCondition),
#[serde(rename = "com.beeper.msc3758.exact_event_match")]
ExactEventMatch(ExactEventMatchCondition),
#[serde(rename = "im.nheko.msc3664.related_event_match")]
RelatedEventMatch(RelatedEventMatchCondition),
#[serde(rename = "org.matrix.msc3952.is_user_mention")]
@@ -309,6 +343,13 @@ pub struct EventMatchCondition {
pub pattern_type: Option<Cow<'static, str>>,
}

/// The body of a [`Condition::ExactEventMatch`]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ExactEventMatchCondition {
pub key: Cow<'static, str>,
pub value: Cow<'static, SimpleJsonValue>,
}

/// The body of a [`Condition::RelatedEventMatch`]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RelatedEventMatchCondition {
@@ -542,6 +583,48 @@ fn test_deserialize_unstable_msc3931_condition() {
));
}

#[test]
fn test_deserialize_unstable_msc3758_condition() {
// A string condition should work.
let json =
r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":"foo"}"#;

let condition: Condition = serde_json::from_str(json).unwrap();
assert!(matches!(
condition,
Condition::Known(KnownCondition::ExactEventMatch(_))
));

// A boolean condition should work.
let json =
r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":true}"#;

let condition: Condition = serde_json::from_str(json).unwrap();
assert!(matches!(
condition,
Condition::Known(KnownCondition::ExactEventMatch(_))
));

// An integer condition should work.
let json = r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":1}"#;

let condition: Condition = serde_json::from_str(json).unwrap();
assert!(matches!(
condition,
Condition::Known(KnownCondition::ExactEventMatch(_))
));

// A null condition should work
let json =
r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":null}"#;

let condition: Condition = serde_json::from_str(json).unwrap();
assert!(matches!(
condition,
Condition::Known(KnownCondition::ExactEventMatch(_))
));
}

#[test]
fn test_deserialize_unstable_msc3952_user_condition() {
let json = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#;


+ 4
- 3
stubs/synapse/synapse_rust/push.pyi View File

@@ -14,7 +14,7 @@

from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union

from synapse.types import JsonDict
from synapse.types import JsonDict, SimpleJsonValue

class PushRule:
@property
@@ -56,17 +56,18 @@ def get_base_rule_ids() -> Collection[str]: ...
class PushRuleEvaluator:
def __init__(
self,
flattened_keys: Mapping[str, str],
flattened_keys: Mapping[str, SimpleJsonValue],
has_mentions: bool,
user_mentions: Set[str],
room_mention: bool,
room_member_count: int,
sender_power_level: Optional[int],
notification_power_levels: Mapping[str, int],
related_events_flattened: Mapping[str, Mapping[str, str]],
related_events_flattened: Mapping[str, Mapping[str, SimpleJsonValue]],
related_event_match_enabled: bool,
room_version_feature_flags: Tuple[str, ...],
msc3931_enabled: bool,
msc3758_exact_event_match: bool,
): ...
def run(
self,


+ 5
- 0
synapse/config/experimental.py View File

@@ -169,6 +169,11 @@ class ExperimentalConfig(Config):
# MSC3925: do not replace events with their edits
self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False)

# MSC3758: exact_event_match push rule condition
self.msc3758_exact_event_match = experimental.get(
"msc3758_exact_event_match", False
)

# MSC3873: Disambiguate event_match keys.
self.msc3783_escape_event_match_key = experimental.get(
"msc3783_escape_event_match_key", False


+ 11
- 7
synapse/push/bulk_push_rule_evaluator.py View File

@@ -43,6 +43,7 @@ from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
from synapse.types import SimpleJsonValue
from synapse.types.state import StateFilter
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
@@ -256,13 +257,15 @@ class BulkPushRuleEvaluator:

return pl_event.content if pl_event else {}, sender_level

async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]:
async def _related_events(
self, event: EventBase
) -> Dict[str, Dict[str, SimpleJsonValue]]:
"""Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation

Returns:
Mapping of relation type to flattened events.
"""
related_events: Dict[str, Dict[str, str]] = {}
related_events: Dict[str, Dict[str, SimpleJsonValue]] = {}
if self._related_event_match_enabled:
related_event_id = event.content.get("m.relates_to", {}).get("event_id")
relation_type = event.content.get("m.relates_to", {}).get("rel_type")
@@ -425,6 +428,7 @@ class BulkPushRuleEvaluator:
self._related_event_match_enabled,
event.room_version.msc3931_push_features,
self.hs.config.experimental.msc1767_enabled, # MSC3931 flag
self.hs.config.experimental.msc3758_exact_event_match,
)

users = rules_by_user.keys()
@@ -501,15 +505,15 @@ StateGroup = Union[object, int]
def _flatten_dict(
d: Union[EventBase, Mapping[str, Any]],
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None,
result: Optional[Dict[str, SimpleJsonValue]] = None,
*,
msc3783_escape_event_match_key: bool = False,
) -> Dict[str, str]:
) -> Dict[str, SimpleJsonValue]:
"""
Given a JSON dictionary (or event) which might contain sub dictionaries,
flatten it into a single layer dictionary by combining the keys & sub-keys.

Any (non-dictionary), non-string value is dropped.
String, integer, boolean, and null values are kept. All others are dropped.

Transforms:

@@ -538,8 +542,8 @@ def _flatten_dict(
# nested fields.
key = key.replace("\\", "\\\\").replace(".", "\\.")

if isinstance(value, str):
result[".".join(prefix + [key])] = value.lower()
if isinstance(value, (bool, str)) or type(value) is int or value is None:
result[".".join(prefix + [key])] = value
elif isinstance(value, Mapping):
# do not set `room_version` due to recursion considerations below
_flatten_dict(


+ 2
- 0
synapse/types/__init__.py View File

@@ -69,6 +69,8 @@ StateMap = Mapping[StateKey, T]
MutableStateMap = MutableMapping[StateKey, T]

# JSON types. These could be made stronger, but will do for now.
# A "simple" (canonical) JSON value.
SimpleJsonValue = Optional[Union[str, int, bool]]
# A JSON-serialisable dict.
JsonDict = Dict[str, Any]
# A JSON-serialisable mapping; roughly speaking an immutable JSONDict.


+ 143
- 4
tests/push/test_push_rule_evaluator.py View File

@@ -57,7 +57,7 @@ class FlattenDictTestCase(unittest.TestCase):
)

def test_non_string(self) -> None:
"""Non-string items are dropped."""
"""Booleans, ints, and nulls should be kept while other items are dropped."""
input: Dict[str, Any] = {
"woo": "woo",
"foo": True,
@@ -66,7 +66,9 @@ class FlattenDictTestCase(unittest.TestCase):
"fuzz": [],
"boo": {},
}
self.assertEqual({"woo": "woo"}, _flatten_dict(input))
self.assertEqual(
{"woo": "woo", "foo": True, "bar": 1, "baz": None}, _flatten_dict(input)
)

def test_event(self) -> None:
"""Events can also be flattened."""
@@ -86,9 +88,9 @@ class FlattenDictTestCase(unittest.TestCase):
)
expected = {
"content.msgtype": "m.text",
"content.body": "hello world!",
"content.body": "Hello world!",
"content.format": "org.matrix.custom.html",
"content.formatted_body": "<h1>hello world!</h1>",
"content.formatted_body": "<h1>Hello world!</h1>",
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
@@ -166,6 +168,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
related_event_match_enabled=True,
room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True,
msc3758_exact_event_match=True,
)

def test_display_name(self) -> None:
@@ -410,6 +413,142 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"pattern should not match before a newline",
)

def test_exact_event_match_string(self) -> None:
"""Check that exact_event_match conditions work as expected for strings."""

# Test against a string value.
condition = {
"kind": "com.beeper.msc3758.exact_event_match",
"key": "content.value",
"value": "foobaz",
}
self._assert_matches(
condition,
{"value": "foobaz"},
"exact value should match",
)
self._assert_not_matches(
condition,
{"value": "FoobaZ"},
"values should match and be case-sensitive",
)
self._assert_not_matches(
condition,
{"value": "test foobaz test"},
"values must exactly match",
)
value: Any
for value in (True, False, 1, 1.1, None, [], {}):
self._assert_not_matches(
condition,
{"value": value},
"incorrect types should not match",
)

# it should work on frozendicts too
self._assert_matches(
condition,
frozendict.frozendict({"value": "foobaz"}),
"values should match on frozendicts",
)

def test_exact_event_match_boolean(self) -> None:
"""Check that exact_event_match conditions work as expected for booleans."""

# Test against a True boolean value.
condition = {
"kind": "com.beeper.msc3758.exact_event_match",
"key": "content.value",
"value": True,
}
self._assert_matches(
condition,
{"value": True},
"exact value should match",
)
self._assert_not_matches(
condition,
{"value": False},
"incorrect values should not match",
)
for value in ("foobaz", 1, 1.1, None, [], {}):
self._assert_not_matches(
condition,
{"value": value},
"incorrect types should not match",
)

# Test against a False boolean value.
condition = {
"kind": "com.beeper.msc3758.exact_event_match",
"key": "content.value",
"value": False,
}
self._assert_matches(
condition,
{"value": False},
"exact value should match",
)
self._assert_not_matches(
condition,
{"value": True},
"incorrect values should not match",
)
# Choose false-y values to ensure there's no type coercion.
for value in ("", 0, 1.1, None, [], {}):
self._assert_not_matches(
condition,
{"value": value},
"incorrect types should not match",
)

def test_exact_event_match_null(self) -> None:
"""Check that exact_event_match conditions work as expected for null."""

condition = {
"kind": "com.beeper.msc3758.exact_event_match",
"key": "content.value",
"value": None,
}
self._assert_matches(
condition,
{"value": None},
"exact value should match",
)
for value in ("foobaz", True, False, 1, 1.1, [], {}):
self._assert_not_matches(
condition,
{"value": value},
"incorrect types should not match",
)

def test_exact_event_match_integer(self) -> None:
"""Check that exact_event_match conditions work as expected for integers."""

condition = {
"kind": "com.beeper.msc3758.exact_event_match",
"key": "content.value",
"value": 1,
}
self._assert_matches(
condition,
{"value": 1},
"exact value should match",
)
value: Any
for value in (1.1, -1, 0):
self._assert_not_matches(
condition,
{"value": value},
"incorrect values should not match",
)
for value in ("1", True, False, None, [], {}):
self._assert_not_matches(
condition,
{"value": value},
"incorrect types should not match",
)

def test_no_body(self) -> None:
"""Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({})


Loading…
Cancel
Save