* Declare new config * Parse new config * Read new config * Don't use trial/our TestCase where it's not needed Before: ``` $ time trial tests/events/test_utils.py > /dev/null real 0m2.277s user 0m2.186s sys 0m0.083s ``` After: ``` $ time trial tests/events/test_utils.py > /dev/null real 0m0.566s user 0m0.508s sys 0m0.056s ``` * Helper to upsert to event fields without exceeding size limits. * Use helper when adding invite/knock state Now that we allow admins to include events in prejoin room state with arbitrary state keys, be a good Matrix citizen and ensure they don't accidentally create an oversized event. * Changelog * Move StateFilter tests should have done this in #14668 * Add extra methods to StateFilter * Use StateFilter * Ensure test file enforces typed defs; alphabetise * Workaround surprising get_current_state_ids * Whoops, fix mypytags/v1.74.0rc1
@@ -0,0 +1 @@ | |||
Allow selecting "prejoin" events by state keys in addition to event types. |
@@ -2501,32 +2501,53 @@ Config settings related to the client/server API | |||
--- | |||
### `room_prejoin_state` | |||
Controls for the state that is shared with users who receive an invite | |||
to a room. By default, the following state event types are shared with users who | |||
receive invites to the room: | |||
- m.room.join_rules | |||
- m.room.canonical_alias | |||
- m.room.avatar | |||
- m.room.encryption | |||
- m.room.name | |||
- m.room.create | |||
- m.room.topic | |||
This setting controls the state that is shared with users upon receiving an | |||
invite to a room, or in reply to a knock on a room. By default, the following | |||
state events are shared with users: | |||
- `m.room.join_rules` | |||
- `m.room.canonical_alias` | |||
- `m.room.avatar` | |||
- `m.room.encryption` | |||
- `m.room.name` | |||
- `m.room.create` | |||
- `m.room.topic` | |||
To change the default behavior, use the following sub-options: | |||
* `disable_default_event_types`: set to true to disable the above defaults. If this | |||
is enabled, only the event types listed in `additional_event_types` are shared. | |||
Defaults to false. | |||
* `additional_event_types`: Additional state event types to share with users when they are invited | |||
to a room. By default, this list is empty (so only the default event types are shared). | |||
* `disable_default_event_types`: boolean. Set to `true` to disable the above | |||
defaults. If this is enabled, only the event types listed in | |||
`additional_event_types` are shared. Defaults to `false`. | |||
* `additional_event_types`: A list of additional state events to include in the | |||
events to be shared. By default, this list is empty (so only the default event | |||
types are shared). | |||
Each entry in this list should be either a single string or a list of two | |||
strings. | |||
* A standalone string `t` represents all events with type `t` (i.e. | |||
with no restrictions on state keys). | |||
* A pair of strings `[t, s]` represents a single event with type `t` and | |||
state key `s`. The same type can appear in two entries with different state | |||
keys: in this situation, both state keys are included in prejoin state. | |||
Example configuration: | |||
```yaml | |||
room_prejoin_state: | |||
disable_default_event_types: true | |||
disable_default_event_types: false | |||
additional_event_types: | |||
- org.example.custom.event.type | |||
- m.room.join_rules | |||
# Share all events of type `org.example.custom.event.typeA` | |||
- org.example.custom.event.typeA | |||
# Share only events of type `org.example.custom.event.typeB` whose | |||
# state_key is "foo" | |||
- ["org.example.custom.event.typeB", "foo"] | |||
# Share only events of type `org.example.custom.event.typeC` whose | |||
# state_key is "bar" or "baz" | |||
- ["org.example.custom.event.typeC", "bar"] | |||
- ["org.example.custom.event.typeC", "baz"] | |||
``` | |||
*Changed in Synapse 1.74:* admins can filter the events in prejoin state based | |||
on their state key. | |||
--- | |||
### `track_puppeted_user_ips` | |||
@@ -89,6 +89,12 @@ disallow_untyped_defs = False | |||
[mypy-tests.*] | |||
disallow_untyped_defs = False | |||
[mypy-tests.config.test_api] | |||
disallow_untyped_defs = True | |||
[mypy-tests.federation.transport.test_client] | |||
disallow_untyped_defs = True | |||
[mypy-tests.handlers.test_sso] | |||
disallow_untyped_defs = True | |||
@@ -101,7 +107,7 @@ disallow_untyped_defs = True | |||
[mypy-tests.push.test_bulk_push_rule_evaluator] | |||
disallow_untyped_defs = True | |||
[mypy-tests.test_server] | |||
[mypy-tests.rest.*] | |||
disallow_untyped_defs = True | |||
[mypy-tests.state.test_profile] | |||
@@ -110,10 +116,10 @@ disallow_untyped_defs = True | |||
[mypy-tests.storage.*] | |||
disallow_untyped_defs = True | |||
[mypy-tests.rest.*] | |||
[mypy-tests.test_server] | |||
disallow_untyped_defs = True | |||
[mypy-tests.federation.transport.test_client] | |||
[mypy-tests.types.*] | |||
disallow_untyped_defs = True | |||
[mypy-tests.util.caches.*] | |||
@@ -33,6 +33,9 @@ def validate_config( | |||
config: the configuration value to be validated | |||
config_path: the path within the config file. This will be used as a basis | |||
for the error message. | |||
Raises: | |||
ConfigError, if validation fails. | |||
""" | |||
try: | |||
jsonschema.validate(config, json_schema) | |||
@@ -13,12 +13,13 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Any, Iterable | |||
from typing import Any, Iterable, Optional, Tuple | |||
from synapse.api.constants import EventTypes | |||
from synapse.config._base import Config, ConfigError | |||
from synapse.config._util import validate_config | |||
from synapse.types import JsonDict | |||
from synapse.types.state import StateFilter | |||
logger = logging.getLogger(__name__) | |||
@@ -26,16 +27,20 @@ logger = logging.getLogger(__name__) | |||
class ApiConfig(Config): | |||
section = "api" | |||
room_prejoin_state: StateFilter | |||
track_puppetted_users_ips: bool | |||
def read_config(self, config: JsonDict, **kwargs: Any) -> None: | |||
validate_config(_MAIN_SCHEMA, config, ()) | |||
self.room_prejoin_state = list(self._get_prejoin_state_types(config)) | |||
self.room_prejoin_state = StateFilter.from_types( | |||
self._get_prejoin_state_entries(config) | |||
) | |||
self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False) | |||
def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]: | |||
"""Get the event types to include in the prejoin state | |||
Parses the config and returns an iterable of the event types to be included. | |||
""" | |||
def _get_prejoin_state_entries( | |||
self, config: JsonDict | |||
) -> Iterable[Tuple[str, Optional[str]]]: | |||
"""Get the event types and state keys to include in the prejoin state.""" | |||
room_prejoin_state_config = config.get("room_prejoin_state") or {} | |||
# backwards-compatibility support for room_invite_state_types | |||
@@ -50,33 +55,39 @@ class ApiConfig(Config): | |||
logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING) | |||
yield from config["room_invite_state_types"] | |||
for event_type in config["room_invite_state_types"]: | |||
yield event_type, None | |||
return | |||
if not room_prejoin_state_config.get("disable_default_event_types"): | |||
yield from _DEFAULT_PREJOIN_STATE_TYPES | |||
yield from _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS | |||
yield from room_prejoin_state_config.get("additional_event_types", []) | |||
for entry in room_prejoin_state_config.get("additional_event_types", []): | |||
if isinstance(entry, str): | |||
yield entry, None | |||
else: | |||
yield entry | |||
_ROOM_INVITE_STATE_TYPES_WARNING = """\ | |||
WARNING: The 'room_invite_state_types' configuration setting is now deprecated, | |||
and replaced with 'room_prejoin_state'. New features may not work correctly | |||
unless 'room_invite_state_types' is removed. See the sample configuration file for | |||
details of 'room_prejoin_state'. | |||
unless 'room_invite_state_types' is removed. See the config documentation at | |||
https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_prejoin_state | |||
for details of 'room_prejoin_state'. | |||
-------------------------------------------------------------------------------- | |||
""" | |||
_DEFAULT_PREJOIN_STATE_TYPES = [ | |||
EventTypes.JoinRules, | |||
EventTypes.CanonicalAlias, | |||
EventTypes.RoomAvatar, | |||
EventTypes.RoomEncryption, | |||
EventTypes.Name, | |||
_DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS = [ | |||
(EventTypes.JoinRules, ""), | |||
(EventTypes.CanonicalAlias, ""), | |||
(EventTypes.RoomAvatar, ""), | |||
(EventTypes.RoomEncryption, ""), | |||
(EventTypes.Name, ""), | |||
# Per MSC1772. | |||
EventTypes.Create, | |||
(EventTypes.Create, ""), | |||
# Per MSC3173. | |||
EventTypes.Topic, | |||
(EventTypes.Topic, ""), | |||
] | |||
@@ -90,7 +101,17 @@ _ROOM_PREJOIN_STATE_CONFIG_SCHEMA = { | |||
"disable_default_event_types": {"type": "boolean"}, | |||
"additional_event_types": { | |||
"type": "array", | |||
"items": {"type": "string"}, | |||
"items": { | |||
"oneOf": [ | |||
{"type": "string"}, | |||
{ | |||
"type": "array", | |||
"items": {"type": "string"}, | |||
"minItems": 2, | |||
"maxItems": 2, | |||
}, | |||
], | |||
}, | |||
}, | |||
}, | |||
}, | |||
@@ -28,8 +28,14 @@ from typing import ( | |||
) | |||
import attr | |||
from canonicaljson import encode_canonical_json | |||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes | |||
from synapse.api.constants import ( | |||
MAX_PDU_SIZE, | |||
EventContentFields, | |||
EventTypes, | |||
RelationTypes, | |||
) | |||
from synapse.api.errors import Codes, SynapseError | |||
from synapse.api.room_versions import RoomVersion | |||
from synapse.types import JsonDict | |||
@@ -674,3 +680,27 @@ def validate_canonicaljson(value: Any) -> None: | |||
elif not isinstance(value, (bool, str)) and value is not None: | |||
# Other potential JSON values (bool, None, str) are safe. | |||
raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON) | |||
def maybe_upsert_event_field( | |||
event: EventBase, container: JsonDict, key: str, value: object | |||
) -> bool: | |||
"""Upsert an event field, but only if this doesn't make the event too large. | |||
Returns true iff the upsert took place. | |||
""" | |||
if key in container: | |||
old_value: object = container[key] | |||
container[key] = value | |||
# NB: here and below, we assume that passing a non-None `time_now` argument to | |||
# get_pdu_json doesn't increase the size of the encoded result. | |||
upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE | |||
if not upsert_okay: | |||
container[key] = old_value | |||
else: | |||
container[key] = value | |||
upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE | |||
if not upsert_okay: | |||
del container[key] | |||
return upsert_okay |
@@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version | |||
from synapse.events import EventBase, relation_from_event | |||
from synapse.events.builder import EventBuilder | |||
from synapse.events.snapshot import EventContext | |||
from synapse.events.utils import maybe_upsert_event_field | |||
from synapse.events.validator import EventValidator | |||
from synapse.handlers.directory import DirectoryHandler | |||
from synapse.logging import opentracing | |||
@@ -1739,12 +1740,15 @@ class EventCreationHandler: | |||
if event.type == EventTypes.Member: | |||
if event.content["membership"] == Membership.INVITE: | |||
event.unsigned[ | |||
"invite_room_state" | |||
] = await self.store.get_stripped_room_state_from_event_context( | |||
context, | |||
self.room_prejoin_state_types, | |||
membership_user_id=event.sender, | |||
maybe_upsert_event_field( | |||
event, | |||
event.unsigned, | |||
"invite_room_state", | |||
await self.store.get_stripped_room_state_from_event_context( | |||
context, | |||
self.room_prejoin_state_types, | |||
membership_user_id=event.sender, | |||
), | |||
) | |||
invitee = UserID.from_string(event.state_key) | |||
@@ -1762,11 +1766,14 @@ class EventCreationHandler: | |||
event.signatures.update(returned_invite.signatures) | |||
if event.content["membership"] == Membership.KNOCK: | |||
event.unsigned[ | |||
"knock_room_state" | |||
] = await self.store.get_stripped_room_state_from_event_context( | |||
context, | |||
self.room_prejoin_state_types, | |||
maybe_upsert_event_field( | |||
event, | |||
event.unsigned, | |||
"knock_room_state", | |||
await self.store.get_stripped_room_state_from_event_context( | |||
context, | |||
self.room_prejoin_state_types, | |||
), | |||
) | |||
if event.type == EventTypes.Redaction: | |||
@@ -16,11 +16,11 @@ import logging | |||
import threading | |||
import weakref | |||
from enum import Enum, auto | |||
from itertools import chain | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
Collection, | |||
Container, | |||
Dict, | |||
Iterable, | |||
List, | |||
@@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import ( | |||
) | |||
from synapse.storage.util.sequence import build_sequence_generator | |||
from synapse.types import JsonDict, get_domain_from_id | |||
from synapse.types.state import StateFilter | |||
from synapse.util import unwrapFirstError | |||
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation | |||
from synapse.util.caches.descriptors import cached, cachedList | |||
@@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
async def get_stripped_room_state_from_event_context( | |||
self, | |||
context: EventContext, | |||
state_types_to_include: Container[str], | |||
state_keys_to_include: StateFilter, | |||
membership_user_id: Optional[str] = None, | |||
) -> List[JsonDict]: | |||
""" | |||
@@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
Args: | |||
context: The event context to retrieve state of the room from. | |||
state_types_to_include: The type of state events to include. | |||
state_keys_to_include: The state events to include, for each event type. | |||
membership_user_id: An optional user ID to include the stripped membership state | |||
events of. This is useful when generating the stripped state of a room for | |||
invites. We want to send membership events of the inviter, so that the | |||
@@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore): | |||
Returns: | |||
A list of dictionaries, each representing a stripped state event from the room. | |||
""" | |||
current_state_ids = await context.get_current_state_ids() | |||
if membership_user_id: | |||
types = chain( | |||
state_keys_to_include.to_types(), | |||
[(EventTypes.Member, membership_user_id)], | |||
) | |||
filter = StateFilter.from_types(types) | |||
else: | |||
filter = state_keys_to_include | |||
selected_state_ids = await context.get_current_state_ids(filter) | |||
# We know this event is not an outlier, so this must be | |||
# non-None. | |||
assert current_state_ids is not None | |||
# The state to include | |||
state_to_include_ids = [ | |||
e_id | |||
for k, e_id in current_state_ids.items() | |||
if k[0] in state_types_to_include | |||
or (membership_user_id and k == (EventTypes.Member, membership_user_id)) | |||
] | |||
assert selected_state_ids is not None | |||
# Confusingly, get_current_state_events may return events that are discarded by | |||
# the filter, if they're in context._state_delta_due_to_event. Strip these away. | |||
selected_state_ids = filter.filter_state(selected_state_ids) | |||
state_to_include = await self.get_events(state_to_include_ids) | |||
state_to_include = await self.get_events(selected_state_ids.values()) | |||
return [ | |||
{ | |||
@@ -118,6 +118,15 @@ class StateFilter: | |||
) | |||
) | |||
def to_types(self) -> Iterable[Tuple[str, Optional[str]]]: | |||
"""The inverse to `from_types`.""" | |||
for (event_type, state_keys) in self.types.items(): | |||
if state_keys is None: | |||
yield event_type, None | |||
else: | |||
for state_key in state_keys: | |||
yield event_type, state_key | |||
@staticmethod | |||
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": | |||
"""Creates a filter that returns all non-member events, plus the member | |||
@@ -343,6 +352,15 @@ class StateFilter: | |||
for s in state_keys | |||
] | |||
def wildcard_types(self) -> List[str]: | |||
"""Returns a list of event types which require us to fetch all state keys. | |||
This will be empty unless `has_wildcards` returns True. | |||
Returns: | |||
A list of event types. | |||
""" | |||
return [t for t, state_keys in self.types.items() if state_keys is None] | |||
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: | |||
"""Return the filter split into two: one which assumes it's exclusively | |||
matching against member state, and one which assumes it's matching | |||
@@ -0,0 +1,145 @@ | |||
from unittest import TestCase as StdlibTestCase | |||
import yaml | |||
from synapse.config import ConfigError | |||
from synapse.config.api import ApiConfig | |||
from synapse.types.state import StateFilter | |||
DEFAULT_PREJOIN_STATE_PAIRS = { | |||
("m.room.join_rules", ""), | |||
("m.room.canonical_alias", ""), | |||
("m.room.avatar", ""), | |||
("m.room.encryption", ""), | |||
("m.room.name", ""), | |||
("m.room.create", ""), | |||
("m.room.topic", ""), | |||
} | |||
class TestRoomPrejoinState(StdlibTestCase): | |||
def read_config(self, source: str) -> ApiConfig: | |||
config = ApiConfig() | |||
config.read_config(yaml.safe_load(source)) | |||
return config | |||
def test_no_prejoin_state(self) -> None: | |||
config = self.read_config("foo: bar") | |||
self.assertFalse(config.room_prejoin_state.has_wildcards()) | |||
self.assertEqual( | |||
set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS | |||
) | |||
def test_disable_default_event_types(self) -> None: | |||
config = self.read_config( | |||
""" | |||
room_prejoin_state: | |||
disable_default_event_types: true | |||
""" | |||
) | |||
self.assertEqual(config.room_prejoin_state, StateFilter.none()) | |||
def test_event_without_state_key(self) -> None: | |||
config = self.read_config( | |||
""" | |||
room_prejoin_state: | |||
disable_default_event_types: true | |||
additional_event_types: | |||
- foo | |||
""" | |||
) | |||
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) | |||
self.assertEqual(config.room_prejoin_state.concrete_types(), []) | |||
def test_event_with_specific_state_key(self) -> None: | |||
config = self.read_config( | |||
""" | |||
room_prejoin_state: | |||
disable_default_event_types: true | |||
additional_event_types: | |||
- [foo, bar] | |||
""" | |||
) | |||
self.assertFalse(config.room_prejoin_state.has_wildcards()) | |||
self.assertEqual( | |||
set(config.room_prejoin_state.concrete_types()), | |||
{("foo", "bar")}, | |||
) | |||
def test_repeated_event_with_specific_state_key(self) -> None: | |||
config = self.read_config( | |||
""" | |||
room_prejoin_state: | |||
disable_default_event_types: true | |||
additional_event_types: | |||
- [foo, bar] | |||
- [foo, baz] | |||
""" | |||
) | |||
self.assertFalse(config.room_prejoin_state.has_wildcards()) | |||
self.assertEqual( | |||
set(config.room_prejoin_state.concrete_types()), | |||
{("foo", "bar"), ("foo", "baz")}, | |||
) | |||
def test_no_specific_state_key_overrides_specific_state_key(self) -> None: | |||
config = self.read_config( | |||
""" | |||
room_prejoin_state: | |||
disable_default_event_types: true | |||
additional_event_types: | |||
- [foo, bar] | |||
- foo | |||
""" | |||
) | |||
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) | |||
self.assertEqual(config.room_prejoin_state.concrete_types(), []) | |||
config = self.read_config( | |||
""" | |||
room_prejoin_state: | |||
disable_default_event_types: true | |||
additional_event_types: | |||
- foo | |||
- [foo, bar] | |||
""" | |||
) | |||
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) | |||
self.assertEqual(config.room_prejoin_state.concrete_types(), []) | |||
def test_bad_event_type_entry_raises(self) -> None: | |||
with self.assertRaises(ConfigError): | |||
self.read_config( | |||
""" | |||
room_prejoin_state: | |||
additional_event_types: | |||
- [] | |||
""" | |||
) | |||
with self.assertRaises(ConfigError): | |||
self.read_config( | |||
""" | |||
room_prejoin_state: | |||
additional_event_types: | |||
- [a] | |||
""" | |||
) | |||
with self.assertRaises(ConfigError): | |||
self.read_config( | |||
""" | |||
room_prejoin_state: | |||
additional_event_types: | |||
- [a, b, c] | |||
""" | |||
) | |||
with self.assertRaises(ConfigError): | |||
self.read_config( | |||
""" | |||
room_prejoin_state: | |||
additional_event_types: | |||
- [true, 1.23] | |||
""" | |||
) |
@@ -12,19 +12,20 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import unittest as stdlib_unittest | |||
from synapse.api.constants import EventContentFields | |||
from synapse.api.room_versions import RoomVersions | |||
from synapse.events import make_event_from_dict | |||
from synapse.events.utils import ( | |||
SerializeEventConfig, | |||
copy_and_fixup_power_levels_contents, | |||
maybe_upsert_event_field, | |||
prune_event, | |||
serialize_event, | |||
) | |||
from synapse.util.frozenutils import freeze | |||
from tests import unittest | |||
def MockEvent(**kwargs): | |||
if "event_id" not in kwargs: | |||
@@ -34,7 +35,31 @@ def MockEvent(**kwargs): | |||
return make_event_from_dict(kwargs) | |||
class PruneEventTestCase(unittest.TestCase): | |||
class TestMaybeUpsertEventField(stdlib_unittest.TestCase): | |||
def test_update_okay(self) -> None: | |||
event = make_event_from_dict({"event_id": "$1234"}) | |||
success = maybe_upsert_event_field(event, event.unsigned, "key", "value") | |||
self.assertTrue(success) | |||
self.assertEqual(event.unsigned["key"], "value") | |||
def test_update_not_okay(self) -> None: | |||
event = make_event_from_dict({"event_id": "$1234"}) | |||
LARGE_STRING = "a" * 100_000 | |||
success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING) | |||
self.assertFalse(success) | |||
self.assertNotIn("key", event.unsigned) | |||
def test_update_not_okay_leaves_original_value(self) -> None: | |||
event = make_event_from_dict( | |||
{"event_id": "$1234", "unsigned": {"key": "value"}} | |||
) | |||
LARGE_STRING = "a" * 100_000 | |||
success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING) | |||
self.assertFalse(success) | |||
self.assertEqual(event.unsigned["key"], "value") | |||
class PruneEventTestCase(stdlib_unittest.TestCase): | |||
def run_test(self, evdict, matchdict, **kwargs): | |||
""" | |||
Asserts that a new event constructed with `evdict` will look like | |||
@@ -391,7 +416,7 @@ class PruneEventTestCase(unittest.TestCase): | |||
) | |||
class SerializeEventTestCase(unittest.TestCase): | |||
class SerializeEventTestCase(stdlib_unittest.TestCase): | |||
def serialize(self, ev, fields): | |||
return serialize_event( | |||
ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) | |||
@@ -513,7 +538,7 @@ class SerializeEventTestCase(unittest.TestCase): | |||
) | |||
class CopyPowerLevelsContentTestCase(unittest.TestCase): | |||
class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.test_content = { | |||
"ban": 50, | |||
@@ -26,7 +26,7 @@ from synapse.types import JsonDict, RoomID, StateMap, UserID | |||
from synapse.types.state import StateFilter | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase, TestCase | |||
from tests.unittest import HomeserverTestCase | |||
logger = logging.getLogger(__name__) | |||
@@ -494,624 +494,3 @@ class StateStoreTestCase(HomeserverTestCase): | |||
self.assertEqual(is_all, True) | |||
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) | |||
class StateFilterDifferenceTestCase(TestCase): | |||
def assert_difference( | |||
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter | |||
) -> None: | |||
self.assertEqual( | |||
minuend.approx_difference(subtrahend), | |||
expected, | |||
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", | |||
) | |||
def test_state_filter_difference_no_include_other_minus_no_include_other( | |||
self, | |||
) -> None: | |||
""" | |||
Tests the StateFilter.approx_difference method | |||
where, in a.approx_difference(b), both a and b do not have the | |||
include_others flag set. | |||
""" | |||
# (wildcard on state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.Create: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze({EventTypes.Create: None}, include_others=False), | |||
) | |||
# (wildcard on state keys) - (specific state keys) | |||
# This one is an over-approximation because we can't represent | |||
# 'all state keys except a few named examples' | |||
self.assert_difference( | |||
StateFilter.freeze({EventTypes.Member: None}, include_others=False), | |||
StateFilter.freeze( | |||
{EventTypes.Member: {"@wombat:spqr"}}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze({EventTypes.Member: None}, include_others=False), | |||
) | |||
# (wildcard on state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.CanonicalAlias: {""}}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (specific state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr"}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: | |||
""" | |||
Tests the StateFilter.approx_difference method | |||
where, in a.approx_difference(b), only a has the include_others flag set. | |||
""" | |||
# (wildcard on state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.Create: None}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Create: None, | |||
EventTypes.Member: set(), | |||
EventTypes.CanonicalAlias: set(), | |||
}, | |||
include_others=True, | |||
), | |||
) | |||
# (wildcard on state keys) - (specific state keys) | |||
# This one is an over-approximation because we can't represent | |||
# 'all state keys except a few named examples' | |||
# This also shows that the resultant state filter is normalised. | |||
self.assert_difference( | |||
StateFilter.freeze({EventTypes.Member: None}, include_others=True), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr"}, | |||
EventTypes.Create: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter(types=frozendict(), include_others=True), | |||
) | |||
# (wildcard on state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter( | |||
types=frozendict(), | |||
include_others=True, | |||
), | |||
) | |||
# (specific state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.CanonicalAlias: {""}, | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=True, | |||
), | |||
) | |||
# (specific state keys) - (specific state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr"}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
) | |||
# (specific state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
) | |||
def test_state_filter_difference_include_other_minus_include_other(self) -> None: | |||
""" | |||
Tests the StateFilter.approx_difference method | |||
where, in a.approx_difference(b), both a and b have the include_others | |||
flag set. | |||
""" | |||
# (wildcard on state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.Create: None}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None}, | |||
include_others=True, | |||
), | |||
StateFilter(types=frozendict(), include_others=False), | |||
) | |||
# (wildcard on state keys) - (specific state keys) | |||
# This one is an over-approximation because we can't represent | |||
# 'all state keys except a few named examples' | |||
self.assert_difference( | |||
StateFilter.freeze({EventTypes.Member: None}, include_others=True), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None}, | |||
include_others=False, | |||
), | |||
) | |||
# (wildcard on state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=True, | |||
), | |||
StateFilter( | |||
types=frozendict(), | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (specific state keys) | |||
# This one is an over-approximation because we can't represent | |||
# 'all state keys except a few named examples' | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
EventTypes.Create: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr"}, | |||
EventTypes.Create: set(), | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@spqr:spqr"}, | |||
EventTypes.Create: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: | |||
""" | |||
Tests the StateFilter.approx_difference method | |||
where, in a.approx_difference(b), only b has the include_others flag set. | |||
""" | |||
# (wildcard on state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.Create: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None}, | |||
include_others=True, | |||
), | |||
StateFilter(types=frozendict(), include_others=False), | |||
) | |||
# (wildcard on state keys) - (specific state keys) | |||
# This one is an over-approximation because we can't represent | |||
# 'all state keys except a few named examples' | |||
self.assert_difference( | |||
StateFilter.freeze({EventTypes.Member: None}, include_others=False), | |||
StateFilter.freeze( | |||
{EventTypes.Member: {"@wombat:spqr"}}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze({EventTypes.Member: None}, include_others=False), | |||
) | |||
# (wildcard on state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=True, | |||
), | |||
StateFilter( | |||
types=frozendict(), | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (specific state keys) | |||
# This one is an over-approximation because we can't represent | |||
# 'all state keys except a few named examples' | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr"}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@spqr:spqr"}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
def test_state_filter_difference_simple_cases(self) -> None: | |||
""" | |||
Tests some very simple cases of the StateFilter approx_difference, | |||
that are not explicitly tested by the more in-depth tests. | |||
""" | |||
self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none()) | |||
self.assert_difference( | |||
StateFilter.all(), | |||
StateFilter.none(), | |||
StateFilter.all(), | |||
) | |||
class StateFilterTestCase(TestCase): | |||
def test_return_expanded(self) -> None: | |||
""" | |||
Tests the behaviour of the return_expanded() function that expands | |||
StateFilters to include more state types (for the sake of cache hit rate). | |||
""" | |||
self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) | |||
self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) | |||
# Concrete-only state filters stay the same | |||
# (Case: mixed filter) | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:test", "@alicia:test"}, | |||
"some.other.state.type": {""}, | |||
}, | |||
include_others=False, | |||
).return_expanded(), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:test", "@alicia:test"}, | |||
"some.other.state.type": {""}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
# Concrete-only state filters stay the same | |||
# (Case: non-member-only filter) | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{"some.other.state.type": {""}}, include_others=False | |||
).return_expanded(), | |||
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), | |||
) | |||
# Concrete-only state filters stay the same | |||
# (Case: member-only filter) | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:test", "@alicia:test"}, | |||
}, | |||
include_others=False, | |||
).return_expanded(), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:test", "@alicia:test"}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
# Wildcard member-only state filters stay the same | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
).return_expanded(), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
) | |||
# If there is a wildcard in the non-member portion of the filter, | |||
# it's expanded to include ALL non-member events. | |||
# (Case: mixed filter) | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:test", "@alicia:test"}, | |||
"some.other.state.type": None, | |||
}, | |||
include_others=False, | |||
).return_expanded(), | |||
StateFilter.freeze( | |||
{EventTypes.Member: {"@wombat:test", "@alicia:test"}}, | |||
include_others=True, | |||
), | |||
) | |||
# If there is a wildcard in the non-member portion of the filter, | |||
# it's expanded to include ALL non-member events. | |||
# (Case: non-member-only filter) | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{ | |||
"some.other.state.type": None, | |||
}, | |||
include_others=False, | |||
).return_expanded(), | |||
StateFilter.freeze({EventTypes.Member: set()}, include_others=True), | |||
) | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{ | |||
"some.other.state.type": None, | |||
"yet.another.state.type": {"wombat"}, | |||
}, | |||
include_others=False, | |||
).return_expanded(), | |||
StateFilter.freeze({EventTypes.Member: set()}, include_others=True), | |||
) |
@@ -0,0 +1,627 @@ | |||
from frozendict import frozendict | |||
from synapse.api.constants import EventTypes | |||
from synapse.types.state import StateFilter | |||
from tests.unittest import TestCase | |||
class StateFilterDifferenceTestCase(TestCase): | |||
def assert_difference( | |||
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter | |||
) -> None: | |||
self.assertEqual( | |||
minuend.approx_difference(subtrahend), | |||
expected, | |||
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", | |||
) | |||
def test_state_filter_difference_no_include_other_minus_no_include_other( | |||
self, | |||
) -> None: | |||
""" | |||
Tests the StateFilter.approx_difference method | |||
where, in a.approx_difference(b), both a and b do not have the | |||
include_others flag set. | |||
""" | |||
# (wildcard on state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.Create: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze({EventTypes.Create: None}, include_others=False), | |||
) | |||
# (wildcard on state keys) - (specific state keys) | |||
# This one is an over-approximation because we can't represent | |||
# 'all state keys except a few named examples' | |||
self.assert_difference( | |||
StateFilter.freeze({EventTypes.Member: None}, include_others=False), | |||
StateFilter.freeze( | |||
{EventTypes.Member: {"@wombat:spqr"}}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze({EventTypes.Member: None}, include_others=False), | |||
) | |||
# (wildcard on state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.CanonicalAlias: {""}}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (specific state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr"}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: | |||
""" | |||
Tests the StateFilter.approx_difference method | |||
where, in a.approx_difference(b), only a has the include_others flag set. | |||
""" | |||
# (wildcard on state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.Create: None}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Create: None, | |||
EventTypes.Member: set(), | |||
EventTypes.CanonicalAlias: set(), | |||
}, | |||
include_others=True, | |||
), | |||
) | |||
# (wildcard on state keys) - (specific state keys) | |||
# This one is an over-approximation because we can't represent | |||
# 'all state keys except a few named examples' | |||
# This also shows that the resultant state filter is normalised. | |||
self.assert_difference( | |||
StateFilter.freeze({EventTypes.Member: None}, include_others=True), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr"}, | |||
EventTypes.Create: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter(types=frozendict(), include_others=True), | |||
) | |||
# (wildcard on state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter( | |||
types=frozendict(), | |||
include_others=True, | |||
), | |||
) | |||
# (specific state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.CanonicalAlias: {""}, | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=True, | |||
), | |||
) | |||
# (specific state keys) - (specific state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr"}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
) | |||
# (specific state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
) | |||
def test_state_filter_difference_include_other_minus_include_other(self) -> None: | |||
""" | |||
Tests the StateFilter.approx_difference method | |||
where, in a.approx_difference(b), both a and b have the include_others | |||
flag set. | |||
""" | |||
# (wildcard on state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.Create: None}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None}, | |||
include_others=True, | |||
), | |||
StateFilter(types=frozendict(), include_others=False), | |||
) | |||
# (wildcard on state keys) - (specific state keys) | |||
# This one is an over-approximation because we can't represent | |||
# 'all state keys except a few named examples' | |||
self.assert_difference( | |||
StateFilter.freeze({EventTypes.Member: None}, include_others=True), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None}, | |||
include_others=False, | |||
), | |||
) | |||
# (wildcard on state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=True, | |||
), | |||
StateFilter( | |||
types=frozendict(), | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (specific state keys) | |||
# This one is an over-approximation because we can't represent | |||
# 'all state keys except a few named examples' | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
EventTypes.Create: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr"}, | |||
EventTypes.Create: set(), | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@spqr:spqr"}, | |||
EventTypes.Create: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: | |||
""" | |||
Tests the StateFilter.approx_difference method | |||
where, in a.approx_difference(b), only b has the include_others flag set. | |||
""" | |||
# (wildcard on state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.Create: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None}, | |||
include_others=True, | |||
), | |||
StateFilter(types=frozendict(), include_others=False), | |||
) | |||
# (wildcard on state keys) - (specific state keys) | |||
# This one is an over-approximation because we can't represent | |||
# 'all state keys except a few named examples' | |||
self.assert_difference( | |||
StateFilter.freeze({EventTypes.Member: None}, include_others=False), | |||
StateFilter.freeze( | |||
{EventTypes.Member: {"@wombat:spqr"}}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze({EventTypes.Member: None}, include_others=False), | |||
) | |||
# (wildcard on state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (wildcard on state keys): | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=True, | |||
), | |||
StateFilter( | |||
types=frozendict(), | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (specific state keys) | |||
# This one is an over-approximation because we can't represent | |||
# 'all state keys except a few named examples' | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr"}, | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@spqr:spqr"}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
# (specific state keys) - (no state keys) | |||
self.assert_difference( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
EventTypes.CanonicalAlias: {""}, | |||
}, | |||
include_others=False, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: set(), | |||
}, | |||
include_others=True, | |||
), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
def test_state_filter_difference_simple_cases(self) -> None: | |||
""" | |||
Tests some very simple cases of the StateFilter approx_difference, | |||
that are not explicitly tested by the more in-depth tests. | |||
""" | |||
self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none()) | |||
self.assert_difference( | |||
StateFilter.all(), | |||
StateFilter.none(), | |||
StateFilter.all(), | |||
) | |||
class StateFilterTestCase(TestCase): | |||
def test_return_expanded(self) -> None: | |||
""" | |||
Tests the behaviour of the return_expanded() function that expands | |||
StateFilters to include more state types (for the sake of cache hit rate). | |||
""" | |||
self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) | |||
self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) | |||
# Concrete-only state filters stay the same | |||
# (Case: mixed filter) | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:test", "@alicia:test"}, | |||
"some.other.state.type": {""}, | |||
}, | |||
include_others=False, | |||
).return_expanded(), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:test", "@alicia:test"}, | |||
"some.other.state.type": {""}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
# Concrete-only state filters stay the same | |||
# (Case: non-member-only filter) | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{"some.other.state.type": {""}}, include_others=False | |||
).return_expanded(), | |||
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), | |||
) | |||
# Concrete-only state filters stay the same | |||
# (Case: member-only filter) | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:test", "@alicia:test"}, | |||
}, | |||
include_others=False, | |||
).return_expanded(), | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:test", "@alicia:test"}, | |||
}, | |||
include_others=False, | |||
), | |||
) | |||
# Wildcard member-only state filters stay the same | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
).return_expanded(), | |||
StateFilter.freeze( | |||
{EventTypes.Member: None}, | |||
include_others=False, | |||
), | |||
) | |||
# If there is a wildcard in the non-member portion of the filter, | |||
# it's expanded to include ALL non-member events. | |||
# (Case: mixed filter) | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{ | |||
EventTypes.Member: {"@wombat:test", "@alicia:test"}, | |||
"some.other.state.type": None, | |||
}, | |||
include_others=False, | |||
).return_expanded(), | |||
StateFilter.freeze( | |||
{EventTypes.Member: {"@wombat:test", "@alicia:test"}}, | |||
include_others=True, | |||
), | |||
) | |||
# If there is a wildcard in the non-member portion of the filter, | |||
# it's expanded to include ALL non-member events. | |||
# (Case: non-member-only filter) | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{ | |||
"some.other.state.type": None, | |||
}, | |||
include_others=False, | |||
).return_expanded(), | |||
StateFilter.freeze({EventTypes.Member: set()}, include_others=True), | |||
) | |||
self.assertEqual( | |||
StateFilter.freeze( | |||
{ | |||
"some.other.state.type": None, | |||
"yet.another.state.type": {"wombat"}, | |||
}, | |||
include_others=False, | |||
).return_expanded(), | |||
StateFilter.freeze({EventTypes.Member: set()}, include_others=True), | |||
) |