Browse Source

Allow selecting "prejoin" events by state keys (#14642)

* Declare new config

* Parse new config

* Read new config

* Don't use trial/our TestCase where it's not needed

Before:

```
$ time trial tests/events/test_utils.py > /dev/null

real	0m2.277s
user	0m2.186s
sys	0m0.083s
```

After:
```
$ time trial tests/events/test_utils.py > /dev/null

real	0m0.566s
user	0m0.508s
sys	0m0.056s
```

* Helper to upsert to event fields

without exceeding size limits.

* Use helper when adding invite/knock state

Now that we allow admins to include events in prejoin room state with
arbitrary state keys, be a good Matrix citizen and ensure they don't
accidentally create an oversized event.

* Changelog

* Move StateFilter tests

should have done this in #14668

* Add extra methods to StateFilter

* Use StateFilter

* Ensure test file enforces typed defs; alphabetise

* Workaround surprising get_current_state_ids

* Whoops, fix mypy
tags/v1.74.0rc1
David Robertson 1 year ago
committed by GitHub
parent
commit
e2a1adbf5d
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 983 additions and 695 deletions
  1. +1
    -0
      changelog.d/14642.feature
  2. +39
    -18
      docs/usage/configuration/config_documentation.md
  3. +9
    -3
      mypy.ini
  4. +3
    -0
      synapse/config/_util.py
  5. +42
    -21
      synapse/config/api.py
  6. +31
    -1
      synapse/events/utils.py
  7. +18
    -11
      synapse/handlers/message.py
  8. +19
    -14
      synapse/storage/databases/main/events_worker.py
  9. +18
    -0
      synapse/types/state.py
  10. +145
    -0
      tests/config/test_api.py
  11. +30
    -5
      tests/events/test_utils.py
  12. +1
    -622
      tests/storage/test_state.py
  13. +0
    -0
     
  14. +627
    -0
      tests/types/test_state.py

+ 1
- 0
changelog.d/14642.feature View File

@@ -0,0 +1 @@
Allow selecting "prejoin" events by state keys in addition to event types.

+ 39
- 18
docs/usage/configuration/config_documentation.md View File

@@ -2501,32 +2501,53 @@ Config settings related to the client/server API
---
### `room_prejoin_state`

Controls for the state that is shared with users who receive an invite
to a room. By default, the following state event types are shared with users who
receive invites to the room:
- m.room.join_rules
- m.room.canonical_alias
- m.room.avatar
- m.room.encryption
- m.room.name
- m.room.create
- m.room.topic
This setting controls the state that is shared with users upon receiving an
invite to a room, or in reply to a knock on a room. By default, the following
state events are shared with users:

- `m.room.join_rules`
- `m.room.canonical_alias`
- `m.room.avatar`
- `m.room.encryption`
- `m.room.name`
- `m.room.create`
- `m.room.topic`

To change the default behavior, use the following sub-options:
* `disable_default_event_types`: set to true to disable the above defaults. If this
is enabled, only the event types listed in `additional_event_types` are shared.
Defaults to false.
* `additional_event_types`: Additional state event types to share with users when they are invited
to a room. By default, this list is empty (so only the default event types are shared).
* `disable_default_event_types`: boolean. Set to `true` to disable the above
defaults. If this is enabled, only the event types listed in
`additional_event_types` are shared. Defaults to `false`.
* `additional_event_types`: A list of additional state events to include in the
events to be shared. By default, this list is empty (so only the default event
types are shared).

Each entry in this list should be either a single string or a list of two
strings.
* A standalone string `t` represents all events with type `t` (i.e.
with no restrictions on state keys).
* A pair of strings `[t, s]` represents a single event with type `t` and
state key `s`. The same type can appear in two entries with different state
keys: in this situation, both state keys are included in prejoin state.

Example configuration:
```yaml
room_prejoin_state:
disable_default_event_types: true
disable_default_event_types: false
additional_event_types:
- org.example.custom.event.type
- m.room.join_rules
# Share all events of type `org.example.custom.event.typeA`
- org.example.custom.event.typeA
# Share only events of type `org.example.custom.event.typeB` whose
# state_key is "foo"
- ["org.example.custom.event.typeB", "foo"]
# Share only events of type `org.example.custom.event.typeC` whose
# state_key is "bar" or "baz"
- ["org.example.custom.event.typeC", "bar"]
- ["org.example.custom.event.typeC", "baz"]
```

*Changed in Synapse 1.74:* admins can filter the events in prejoin state based
on their state key.

---
### `track_puppeted_user_ips`



+ 9
- 3
mypy.ini View File

@@ -89,6 +89,12 @@ disallow_untyped_defs = False
[mypy-tests.*]
disallow_untyped_defs = False

[mypy-tests.config.test_api]
disallow_untyped_defs = True

[mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True

[mypy-tests.handlers.test_sso]
disallow_untyped_defs = True

@@ -101,7 +107,7 @@ disallow_untyped_defs = True
[mypy-tests.push.test_bulk_push_rule_evaluator]
disallow_untyped_defs = True

[mypy-tests.test_server]
[mypy-tests.rest.*]
disallow_untyped_defs = True

[mypy-tests.state.test_profile]
@@ -110,10 +116,10 @@ disallow_untyped_defs = True
[mypy-tests.storage.*]
disallow_untyped_defs = True

[mypy-tests.rest.*]
[mypy-tests.test_server]
disallow_untyped_defs = True

[mypy-tests.federation.transport.test_client]
[mypy-tests.types.*]
disallow_untyped_defs = True

[mypy-tests.util.caches.*]


+ 3
- 0
synapse/config/_util.py View File

@@ -33,6 +33,9 @@ def validate_config(
config: the configuration value to be validated
config_path: the path within the config file. This will be used as a basis
for the error message.

Raises:
ConfigError, if validation fails.
"""
try:
jsonschema.validate(config, json_schema)


+ 42
- 21
synapse/config/api.py View File

@@ -13,12 +13,13 @@
# limitations under the License.

import logging
from typing import Any, Iterable
from typing import Any, Iterable, Optional, Tuple

from synapse.api.constants import EventTypes
from synapse.config._base import Config, ConfigError
from synapse.config._util import validate_config
from synapse.types import JsonDict
from synapse.types.state import StateFilter

logger = logging.getLogger(__name__)

@@ -26,16 +27,20 @@ logger = logging.getLogger(__name__)
class ApiConfig(Config):
section = "api"

room_prejoin_state: StateFilter
track_puppetted_users_ips: bool

def read_config(self, config: JsonDict, **kwargs: Any) -> None:
validate_config(_MAIN_SCHEMA, config, ())
self.room_prejoin_state = list(self._get_prejoin_state_types(config))
self.room_prejoin_state = StateFilter.from_types(
self._get_prejoin_state_entries(config)
)
self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)

def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
"""Get the event types to include in the prejoin state

Parses the config and returns an iterable of the event types to be included.
"""
def _get_prejoin_state_entries(
self, config: JsonDict
) -> Iterable[Tuple[str, Optional[str]]]:
"""Get the event types and state keys to include in the prejoin state."""
room_prejoin_state_config = config.get("room_prejoin_state") or {}

# backwards-compatibility support for room_invite_state_types
@@ -50,33 +55,39 @@ class ApiConfig(Config):

logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)

yield from config["room_invite_state_types"]
for event_type in config["room_invite_state_types"]:
yield event_type, None
return

if not room_prejoin_state_config.get("disable_default_event_types"):
yield from _DEFAULT_PREJOIN_STATE_TYPES
yield from _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS

yield from room_prejoin_state_config.get("additional_event_types", [])
for entry in room_prejoin_state_config.get("additional_event_types", []):
if isinstance(entry, str):
yield entry, None
else:
yield entry


_ROOM_INVITE_STATE_TYPES_WARNING = """\
WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
and replaced with 'room_prejoin_state'. New features may not work correctly
unless 'room_invite_state_types' is removed. See the sample configuration file for
details of 'room_prejoin_state'.
unless 'room_invite_state_types' is removed. See the config documentation at
https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_prejoin_state
for details of 'room_prejoin_state'.
--------------------------------------------------------------------------------
"""

_DEFAULT_PREJOIN_STATE_TYPES = [
EventTypes.JoinRules,
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.RoomEncryption,
EventTypes.Name,
_DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS = [
(EventTypes.JoinRules, ""),
(EventTypes.CanonicalAlias, ""),
(EventTypes.RoomAvatar, ""),
(EventTypes.RoomEncryption, ""),
(EventTypes.Name, ""),
# Per MSC1772.
EventTypes.Create,
(EventTypes.Create, ""),
# Per MSC3173.
EventTypes.Topic,
(EventTypes.Topic, ""),
]


@@ -90,7 +101,17 @@ _ROOM_PREJOIN_STATE_CONFIG_SCHEMA = {
"disable_default_event_types": {"type": "boolean"},
"additional_event_types": {
"type": "array",
"items": {"type": "string"},
"items": {
"oneOf": [
{"type": "string"},
{
"type": "array",
"items": {"type": "string"},
"minItems": 2,
"maxItems": 2,
},
],
},
},
},
},


+ 31
- 1
synapse/events/utils.py View File

@@ -28,8 +28,14 @@ from typing import (
)

import attr
from canonicaljson import encode_canonical_json

from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.constants import (
MAX_PDU_SIZE,
EventContentFields,
EventTypes,
RelationTypes,
)
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.types import JsonDict
@@ -674,3 +680,27 @@ def validate_canonicaljson(value: Any) -> None:
elif not isinstance(value, (bool, str)) and value is not None:
# Other potential JSON values (bool, None, str) are safe.
raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON)


def maybe_upsert_event_field(
event: EventBase, container: JsonDict, key: str, value: object
) -> bool:
"""Upsert an event field, but only if this doesn't make the event too large.

Returns true iff the upsert took place.
"""
if key in container:
old_value: object = container[key]
container[key] = value
# NB: here and below, we assume that passing a non-None `time_now` argument to
# get_pdu_json doesn't increase the size of the encoded result.
upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
if not upsert_okay:
container[key] = old_value
else:
container[key] = value
upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
if not upsert_okay:
del container[key]

return upsert_okay

+ 18
- 11
synapse/handlers/message.py View File

@@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.events.utils import maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.logging import opentracing
@@ -1739,12 +1740,15 @@ class EventCreationHandler:

if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
event.unsigned[
"invite_room_state"
] = await self.store.get_stripped_room_state_from_event_context(
context,
self.room_prejoin_state_types,
membership_user_id=event.sender,
maybe_upsert_event_field(
event,
event.unsigned,
"invite_room_state",
await self.store.get_stripped_room_state_from_event_context(
context,
self.room_prejoin_state_types,
membership_user_id=event.sender,
),
)

invitee = UserID.from_string(event.state_key)
@@ -1762,11 +1766,14 @@ class EventCreationHandler:
event.signatures.update(returned_invite.signatures)

if event.content["membership"] == Membership.KNOCK:
event.unsigned[
"knock_room_state"
] = await self.store.get_stripped_room_state_from_event_context(
context,
self.room_prejoin_state_types,
maybe_upsert_event_field(
event,
event.unsigned,
"knock_room_state",
await self.store.get_stripped_room_state_from_event_context(
context,
self.room_prejoin_state_types,
),
)

if event.type == EventTypes.Redaction:


+ 19
- 14
synapse/storage/databases/main/events_worker.py View File

@@ -16,11 +16,11 @@ import logging
import threading
import weakref
from enum import Enum, auto
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Collection,
Container,
Dict,
Iterable,
List,
@@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import (
)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
@@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_stripped_room_state_from_event_context(
self,
context: EventContext,
state_types_to_include: Container[str],
state_keys_to_include: StateFilter,
membership_user_id: Optional[str] = None,
) -> List[JsonDict]:
"""
@@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore):

Args:
context: The event context to retrieve state of the room from.
state_types_to_include: The type of state events to include.
state_keys_to_include: The state events to include, for each event type.
membership_user_id: An optional user ID to include the stripped membership state
events of. This is useful when generating the stripped state of a room for
invites. We want to send membership events of the inviter, so that the
@@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
A list of dictionaries, each representing a stripped state event from the room.
"""
current_state_ids = await context.get_current_state_ids()
if membership_user_id:
types = chain(
state_keys_to_include.to_types(),
[(EventTypes.Member, membership_user_id)],
)
filter = StateFilter.from_types(types)
else:
filter = state_keys_to_include
selected_state_ids = await context.get_current_state_ids(filter)

# We know this event is not an outlier, so this must be
# non-None.
assert current_state_ids is not None

# The state to include
state_to_include_ids = [
e_id
for k, e_id in current_state_ids.items()
if k[0] in state_types_to_include
or (membership_user_id and k == (EventTypes.Member, membership_user_id))
]
assert selected_state_ids is not None

# Confusingly, get_current_state_events may return events that are discarded by
# the filter, if they're in context._state_delta_due_to_event. Strip these away.
selected_state_ids = filter.filter_state(selected_state_ids)

state_to_include = await self.get_events(state_to_include_ids)
state_to_include = await self.get_events(selected_state_ids.values())

return [
{


+ 18
- 0
synapse/types/state.py View File

@@ -118,6 +118,15 @@ class StateFilter:
)
)

def to_types(self) -> Iterable[Tuple[str, Optional[str]]]:
"""The inverse to `from_types`."""
for (event_type, state_keys) in self.types.items():
if state_keys is None:
yield event_type, None
else:
for state_key in state_keys:
yield event_type, state_key

@staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member
@@ -343,6 +352,15 @@ class StateFilter:
for s in state_keys
]

def wildcard_types(self) -> List[str]:
"""Returns a list of event types which require us to fetch all state keys.
This will be empty unless `has_wildcards` returns True.

Returns:
A list of event types.
"""
return [t for t, state_keys in self.types.items() if state_keys is None]

def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching


+ 145
- 0
tests/config/test_api.py View File

@@ -0,0 +1,145 @@
from unittest import TestCase as StdlibTestCase

import yaml

from synapse.config import ConfigError
from synapse.config.api import ApiConfig
from synapse.types.state import StateFilter

DEFAULT_PREJOIN_STATE_PAIRS = {
("m.room.join_rules", ""),
("m.room.canonical_alias", ""),
("m.room.avatar", ""),
("m.room.encryption", ""),
("m.room.name", ""),
("m.room.create", ""),
("m.room.topic", ""),
}


class TestRoomPrejoinState(StdlibTestCase):
def read_config(self, source: str) -> ApiConfig:
config = ApiConfig()
config.read_config(yaml.safe_load(source))
return config

def test_no_prejoin_state(self) -> None:
config = self.read_config("foo: bar")
self.assertFalse(config.room_prejoin_state.has_wildcards())
self.assertEqual(
set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS
)

def test_disable_default_event_types(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
"""
)
self.assertEqual(config.room_prejoin_state, StateFilter.none())

def test_event_without_state_key(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- foo
"""
)
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
self.assertEqual(config.room_prejoin_state.concrete_types(), [])

def test_event_with_specific_state_key(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- [foo, bar]
"""
)
self.assertFalse(config.room_prejoin_state.has_wildcards())
self.assertEqual(
set(config.room_prejoin_state.concrete_types()),
{("foo", "bar")},
)

def test_repeated_event_with_specific_state_key(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- [foo, bar]
- [foo, baz]
"""
)
self.assertFalse(config.room_prejoin_state.has_wildcards())
self.assertEqual(
set(config.room_prejoin_state.concrete_types()),
{("foo", "bar"), ("foo", "baz")},
)

def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- [foo, bar]
- foo
"""
)
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
self.assertEqual(config.room_prejoin_state.concrete_types(), [])

config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- foo
- [foo, bar]
"""
)
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
self.assertEqual(config.room_prejoin_state.concrete_types(), [])

def test_bad_event_type_entry_raises(self) -> None:
with self.assertRaises(ConfigError):
self.read_config(
"""
room_prejoin_state:
additional_event_types:
- []
"""
)

with self.assertRaises(ConfigError):
self.read_config(
"""
room_prejoin_state:
additional_event_types:
- [a]
"""
)

with self.assertRaises(ConfigError):
self.read_config(
"""
room_prejoin_state:
additional_event_types:
- [a, b, c]
"""
)

with self.assertRaises(ConfigError):
self.read_config(
"""
room_prejoin_state:
additional_event_types:
- [true, 1.23]
"""
)

+ 30
- 5
tests/events/test_utils.py View File

@@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest as stdlib_unittest

from synapse.api.constants import EventContentFields
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.utils import (
SerializeEventConfig,
copy_and_fixup_power_levels_contents,
maybe_upsert_event_field,
prune_event,
serialize_event,
)
from synapse.util.frozenutils import freeze

from tests import unittest


def MockEvent(**kwargs):
if "event_id" not in kwargs:
@@ -34,7 +35,31 @@ def MockEvent(**kwargs):
return make_event_from_dict(kwargs)


class PruneEventTestCase(unittest.TestCase):
class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
def test_update_okay(self) -> None:
event = make_event_from_dict({"event_id": "$1234"})
success = maybe_upsert_event_field(event, event.unsigned, "key", "value")
self.assertTrue(success)
self.assertEqual(event.unsigned["key"], "value")

def test_update_not_okay(self) -> None:
event = make_event_from_dict({"event_id": "$1234"})
LARGE_STRING = "a" * 100_000
success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
self.assertFalse(success)
self.assertNotIn("key", event.unsigned)

def test_update_not_okay_leaves_original_value(self) -> None:
event = make_event_from_dict(
{"event_id": "$1234", "unsigned": {"key": "value"}}
)
LARGE_STRING = "a" * 100_000
success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
self.assertFalse(success)
self.assertEqual(event.unsigned["key"], "value")


class PruneEventTestCase(stdlib_unittest.TestCase):
def run_test(self, evdict, matchdict, **kwargs):
"""
Asserts that a new event constructed with `evdict` will look like
@@ -391,7 +416,7 @@ class PruneEventTestCase(unittest.TestCase):
)


class SerializeEventTestCase(unittest.TestCase):
class SerializeEventTestCase(stdlib_unittest.TestCase):
def serialize(self, ev, fields):
return serialize_event(
ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
@@ -513,7 +538,7 @@ class SerializeEventTestCase(unittest.TestCase):
)


class CopyPowerLevelsContentTestCase(unittest.TestCase):
class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
def setUp(self) -> None:
self.test_content = {
"ban": 50,


+ 1
- 622
tests/storage/test_state.py View File

@@ -26,7 +26,7 @@ from synapse.types import JsonDict, RoomID, StateMap, UserID
from synapse.types.state import StateFilter
from synapse.util import Clock

from tests.unittest import HomeserverTestCase, TestCase
from tests.unittest import HomeserverTestCase

logger = logging.getLogger(__name__)

@@ -494,624 +494,3 @@ class StateStoreTestCase(HomeserverTestCase):

self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)


class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
) -> None:
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)

def test_state_filter_difference_no_include_other_minus_no_include_other(
self,
) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze({EventTypes.Create: None}, include_others=False),
)

# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=False,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)

# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)

# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.CanonicalAlias: {""}},
include_others=False,
),
)

# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)

# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)

def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Create: None,
EventTypes.Member: set(),
EventTypes.CanonicalAlias: set(),
},
include_others=True,
),
)

# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
# This also shows that the resultant state filter is normalised.
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
StateFilter(types=frozendict(), include_others=True),
)

# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter(
types=frozendict(),
include_others=True,
),
)

# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.CanonicalAlias: {""},
EventTypes.Member: set(),
},
include_others=True,
),
)

# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)

# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)

def test_state_filter_difference_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others
flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)

# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
)

# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)

# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)

# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
EventTypes.Create: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
)

# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)

def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)

# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=True,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)

# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)

# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)

# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
},
include_others=False,
),
)

# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)

def test_state_filter_difference_simple_cases(self) -> None:
"""
Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests.
"""

self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())

self.assert_difference(
StateFilter.all(),
StateFilter.none(),
StateFilter.all(),
)


class StateFilterTestCase(TestCase):
def test_return_expanded(self) -> None:
"""
Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate).
"""

self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())

self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())

# Concrete-only state filters stay the same
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
),
)

# Concrete-only state filters stay the same
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{"some.other.state.type": {""}}, include_others=False
).return_expanded(),
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
)

# Concrete-only state filters stay the same
# (Case: member-only filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
),
)

# Wildcard member-only state filters stay the same
self.assertEqual(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)

# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:test", "@alicia:test"}},
include_others=True,
),
)

# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
"yet.another.state.type": {"wombat"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)

+ 0
- 0
View File


+ 627
- 0
tests/types/test_state.py View File

@@ -0,0 +1,627 @@
from frozendict import frozendict

from synapse.api.constants import EventTypes
from synapse.types.state import StateFilter

from tests.unittest import TestCase


class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
) -> None:
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)

def test_state_filter_difference_no_include_other_minus_no_include_other(
self,
) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze({EventTypes.Create: None}, include_others=False),
)

# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=False,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)

# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)

# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.CanonicalAlias: {""}},
include_others=False,
),
)

# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)

# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)

def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Create: None,
EventTypes.Member: set(),
EventTypes.CanonicalAlias: set(),
},
include_others=True,
),
)

# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
# This also shows that the resultant state filter is normalised.
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
StateFilter(types=frozendict(), include_others=True),
)

# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter(
types=frozendict(),
include_others=True,
),
)

# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.CanonicalAlias: {""},
EventTypes.Member: set(),
},
include_others=True,
),
)

# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)

# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)

def test_state_filter_difference_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others
flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)

# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
)

# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)

# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)

# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
EventTypes.Create: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
)

# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)

def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)

# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=True,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)

# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)

# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)

# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
},
include_others=False,
),
)

# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)

def test_state_filter_difference_simple_cases(self) -> None:
"""
Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests.
"""

self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())

self.assert_difference(
StateFilter.all(),
StateFilter.none(),
StateFilter.all(),
)


class StateFilterTestCase(TestCase):
def test_return_expanded(self) -> None:
"""
Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate).
"""

self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())

self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())

# Concrete-only state filters stay the same
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
),
)

# Concrete-only state filters stay the same
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{"some.other.state.type": {""}}, include_others=False
).return_expanded(),
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
)

# Concrete-only state filters stay the same
# (Case: member-only filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
),
)

# Wildcard member-only state filters stay the same
self.assertEqual(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)

# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:test", "@alicia:test"}},
include_others=True,
),
)

# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
"yet.another.state.type": {"wombat"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)

Loading…
Cancel
Save