Adds missing type hints to `tests.storage` package and does not allow untyped definitions.tags/v1.74.0rc1
@@ -0,0 +1 @@ | |||
Add missing type hints. |
@@ -88,6 +88,9 @@ disallow_untyped_defs = False | |||
[mypy-tests.*] | |||
disallow_untyped_defs = False | |||
[mypy-tests.handlers.test_sso] | |||
disallow_untyped_defs = True | |||
[mypy-tests.handlers.test_user_directory] | |||
disallow_untyped_defs = True | |||
@@ -103,16 +106,7 @@ disallow_untyped_defs = True | |||
[mypy-tests.state.test_profile] | |||
disallow_untyped_defs = True | |||
[mypy-tests.storage.test_id_generators] | |||
disallow_untyped_defs = True | |||
[mypy-tests.storage.test_profile] | |||
disallow_untyped_defs = True | |||
[mypy-tests.handlers.test_sso] | |||
disallow_untyped_defs = True | |||
[mypy-tests.storage.test_user_directory] | |||
[mypy-tests.storage.*] | |||
disallow_untyped_defs = True | |||
[mypy-tests.rest.*] | |||
@@ -140,7 +140,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
@cancellable | |||
async def get_e2e_device_keys_for_cs_api( | |||
self, | |||
query_list: List[Tuple[str, Optional[str]]], | |||
query_list: Collection[Tuple[str, Optional[str]]], | |||
include_displaynames: bool = True, | |||
) -> Dict[str, Dict[str, JsonDict]]: | |||
"""Fetch a list of device keys, formatted suitably for the C/S API. | |||
@@ -12,8 +12,12 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.rest import admin | |||
from synapse.rest.client import devices | |||
from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
@@ -25,11 +29,11 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
devices.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
self.user_id = self.register_user("foo", "pass") | |||
def test_background_remove_deleted_devices_from_device_inbox(self): | |||
def test_background_remove_deleted_devices_from_device_inbox(self) -> None: | |||
"""Test that the background task to delete old device_inboxes works properly.""" | |||
# create a valid device | |||
@@ -89,7 +93,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
self.assertEqual(1, len(res)) | |||
self.assertEqual(res[0], "cur_device") | |||
def test_background_remove_hidden_devices_from_device_inbox(self): | |||
def test_background_remove_hidden_devices_from_device_inbox(self) -> None: | |||
"""Test that the background task to delete hidden devices | |||
from device_inboxes works properly.""" | |||
@@ -45,7 +45,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.hs = hs | |||
self.store: EventsWorkerStore = hs.get_datastores().main | |||
@@ -68,7 +68,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): | |||
self.event_ids.append(event.event_id) | |||
def test_simple(self): | |||
def test_simple(self) -> None: | |||
with LoggingContext(name="test") as ctx: | |||
res = self.get_success( | |||
self.store.have_seen_events( | |||
@@ -90,7 +90,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(res, {self.event_ids[0]}) | |||
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) | |||
def test_persisting_event_invalidates_cache(self): | |||
def test_persisting_event_invalidates_cache(self) -> None: | |||
""" | |||
Test to make sure that the `have_seen_event` cache | |||
is invalidated after we persist an event and returns | |||
@@ -138,7 +138,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): | |||
# That should result in a single db query to lookup | |||
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) | |||
def test_invalidate_cache_by_room_id(self): | |||
def test_invalidate_cache_by_room_id(self) -> None: | |||
""" | |||
Test to make sure that all events associated with the given `(room_id,)` | |||
are invalidated in the `have_seen_event` cache. | |||
@@ -175,7 +175,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store: EventsWorkerStore = hs.get_datastores().main | |||
self.user = self.register_user("user", "pass") | |||
@@ -189,7 +189,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): | |||
# Reset the event cache so the tests start with it empty | |||
self.get_success(self.store._get_event_cache.clear()) | |||
def test_simple(self): | |||
def test_simple(self) -> None: | |||
"""Test that we cache events that we pull from the DB.""" | |||
with LoggingContext("test") as ctx: | |||
@@ -198,7 +198,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): | |||
# We should have fetched the event from the DB | |||
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) | |||
def test_event_ref(self): | |||
def test_event_ref(self) -> None: | |||
"""Test that we reuse events that are still in memory but have fallen | |||
out of the cache, rather than requesting them from the DB. | |||
""" | |||
@@ -223,7 +223,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): | |||
# from the DB | |||
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0) | |||
def test_dedupe(self): | |||
def test_dedupe(self) -> None: | |||
"""Test that if we request the same event multiple times we only pull it | |||
out once. | |||
""" | |||
@@ -241,7 +241,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): | |||
class DatabaseOutageTestCase(unittest.HomeserverTestCase): | |||
"""Test event fetching during a database outage.""" | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store: EventsWorkerStore = hs.get_datastores().main | |||
self.room_id = f"!room:{hs.hostname}" | |||
@@ -377,7 +377,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store: EventsWorkerStore = hs.get_datastores().main | |||
self.user = self.register_user("user", "pass") | |||
@@ -412,7 +412,8 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): | |||
unblock: "Deferred[None]" = Deferred() | |||
original_runWithConnection = self.store.db_pool.runWithConnection | |||
async def runWithConnection(*args, **kwargs): | |||
# Don't bother with the types here, we just pass into the original function. | |||
async def runWithConnection(*args, **kwargs): # type: ignore[no-untyped-def] | |||
await unblock | |||
return await original_runWithConnection(*args, **kwargs) | |||
@@ -441,7 +442,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1) | |||
self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0) | |||
def test_first_get_event_cancelled(self): | |||
def test_first_get_event_cancelled(self) -> None: | |||
"""Test cancellation of the first `get_event` call sharing a database fetch. | |||
The first `get_event` call is the one which initiates the fetch. We expect the | |||
@@ -467,7 +468,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): | |||
# The second `get_event` call should complete successfully. | |||
self.get_success(get_event2) | |||
def test_second_get_event_cancelled(self): | |||
def test_second_get_event_cancelled(self) -> None: | |||
"""Test cancellation of the second `get_event` call sharing a database fetch.""" | |||
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2): | |||
# Cancel the second `get_event` call. | |||
@@ -15,18 +15,20 @@ | |||
from twisted.internet import defer, reactor | |||
from twisted.internet.base import ReactorBase | |||
from twisted.internet.defer import Deferred | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.server import HomeServer | |||
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS | |||
from synapse.util import Clock | |||
from tests import unittest | |||
class LockTestCase(unittest.HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs: HomeServer): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
def test_acquire_contention(self): | |||
def test_acquire_contention(self) -> None: | |||
# Track the number of tasks holding the lock. | |||
# Should be at most 1. | |||
in_lock = 0 | |||
@@ -34,7 +36,7 @@ class LockTestCase(unittest.HomeserverTestCase): | |||
release_lock: "Deferred[None]" = Deferred() | |||
async def task(): | |||
async def task() -> None: | |||
nonlocal in_lock | |||
nonlocal max_in_lock | |||
@@ -76,7 +78,7 @@ class LockTestCase(unittest.HomeserverTestCase): | |||
# At most one task should have held the lock at a time. | |||
self.assertEqual(max_in_lock, 1) | |||
def test_simple_lock(self): | |||
def test_simple_lock(self) -> None: | |||
"""Test that we can take out a lock and that while we hold it nobody | |||
else can take it out. | |||
""" | |||
@@ -103,7 +105,7 @@ class LockTestCase(unittest.HomeserverTestCase): | |||
self.get_success(lock3.__aenter__()) | |||
self.get_success(lock3.__aexit__(None, None, None)) | |||
def test_maintain_lock(self): | |||
def test_maintain_lock(self) -> None: | |||
"""Test that we don't time out locks while they're still active""" | |||
lock = self.get_success(self.store.try_acquire_lock("name", "key")) | |||
@@ -119,7 +121,7 @@ class LockTestCase(unittest.HomeserverTestCase): | |||
self.get_success(lock.__aexit__(None, None, None)) | |||
def test_timeout_lock(self): | |||
def test_timeout_lock(self) -> None: | |||
"""Test that we time out locks if they're not updated for ages""" | |||
lock = self.get_success(self.store.try_acquire_lock("name", "key")) | |||
@@ -139,7 +141,7 @@ class LockTestCase(unittest.HomeserverTestCase): | |||
self.assertFalse(self.get_success(lock.is_still_valid())) | |||
def test_drop(self): | |||
def test_drop(self) -> None: | |||
"""Test that dropping the context manager means we stop renewing the lock""" | |||
lock = self.get_success(self.store.try_acquire_lock("name", "key")) | |||
@@ -153,7 +155,7 @@ class LockTestCase(unittest.HomeserverTestCase): | |||
lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) | |||
self.assertIsNotNone(lock2) | |||
def test_shutdown(self): | |||
def test_shutdown(self) -> None: | |||
"""Test that shutting down Synapse releases the locks""" | |||
# Acquire two locks | |||
lock = self.get_success(self.store.try_acquire_lock("name", "key1")) | |||
@@ -33,7 +33,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
self.user_id = self.register_user("foo", "pass") | |||
self.token = self.login("foo", "pass") | |||
@@ -47,7 +47,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
table: str, | |||
receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]], | |||
expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]], | |||
): | |||
) -> None: | |||
"""Test that the background update to uniqueify non-thread receipts in | |||
the given receipts table works properly. | |||
@@ -154,7 +154,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
f"Background update did not remove all duplicate receipts from {table}", | |||
) | |||
def test_background_receipts_linearized_unique_index(self): | |||
def test_background_receipts_linearized_unique_index(self) -> None: | |||
"""Test that the background update to uniqueify non-thread receipts in | |||
`receipts_linearized` works properly. | |||
""" | |||
@@ -177,7 +177,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
}, | |||
) | |||
def test_background_receipts_graph_unique_index(self): | |||
def test_background_receipts_graph_unique_index(self) -> None: | |||
"""Test that the background update to uniqueify non-thread receipts in | |||
`receipts_graph` works properly. | |||
""" | |||
@@ -14,10 +14,14 @@ | |||
import json | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import RoomTypes | |||
from synapse.rest import admin | |||
from synapse.rest.client import login, room | |||
from synapse.server import HomeServer | |||
from synapse.storage.databases.main.room import _BackgroundUpdates | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
@@ -30,7 +34,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
self.user_id = self.register_user("foo", "pass") | |||
self.token = self.login("foo", "pass") | |||
@@ -40,7 +44,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
return room_id | |||
def test_background_populate_rooms_creator_column(self): | |||
def test_background_populate_rooms_creator_column(self) -> None: | |||
"""Test that the background update to populate the rooms creator column | |||
works properly. | |||
""" | |||
@@ -95,7 +99,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
) | |||
self.assertEqual(room_creator_after, self.user_id) | |||
def test_background_add_room_type_column(self): | |||
def test_background_add_room_type_column(self) -> None: | |||
"""Test that the background update to populate the `room_type` column in | |||
`room_stats_state` works properly. | |||
""" | |||
@@ -106,7 +106,7 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase): | |||
{(1, "user1", "hello"), (2, "user2", "bleb")}, | |||
) | |||
def test_simple_update_many(self): | |||
def test_simple_update_many(self) -> None: | |||
""" | |||
simple_update_many performs many updates at once. | |||
""" | |||
@@ -14,13 +14,17 @@ | |||
from typing import Iterable, Optional, Set | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import AccountDataTypes | |||
from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
from tests import unittest | |||
class IgnoredUsersTestCase(unittest.HomeserverTestCase): | |||
def prepare(self, hs, reactor, clock): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = self.hs.get_datastores().main | |||
self.user = "@user:test" | |||
@@ -55,7 +59,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): | |||
expected_ignored_user_ids, | |||
) | |||
def test_ignoring_users(self): | |||
def test_ignoring_users(self) -> None: | |||
"""Basic adding/removing of users from the ignore list.""" | |||
self._update_ignore_list("@other:test", "@another:remote") | |||
self.assert_ignored(self.user, {"@other:test", "@another:remote"}) | |||
@@ -82,7 +86,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): | |||
# Check the removed user. | |||
self.assert_ignorers("@another:remote", {self.user}) | |||
def test_caching(self): | |||
def test_caching(self) -> None: | |||
"""Ensure that caching works properly between different users.""" | |||
# The first user ignores a user. | |||
self._update_ignore_list("@other:test") | |||
@@ -99,7 +103,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): | |||
self.assert_ignored(self.user, set()) | |||
self.assert_ignorers("@other:test", {"@second:test"}) | |||
def test_invalid_data(self): | |||
def test_invalid_data(self) -> None: | |||
"""Invalid data ends up clearing out the ignored users list.""" | |||
# Add some data and ensure it is there. | |||
self._update_ignore_list("@other:test") | |||
@@ -26,7 +26,7 @@ from synapse.appservice import ApplicationService, ApplicationServiceState | |||
from synapse.config._base import ConfigError | |||
from synapse.events import EventBase | |||
from synapse.server import HomeServer | |||
from synapse.storage.database import DatabasePool, make_conn | |||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection, make_conn | |||
from synapse.storage.databases.main.appservice import ( | |||
ApplicationServiceStore, | |||
ApplicationServiceTransactionStore, | |||
@@ -39,7 +39,7 @@ from tests.test_utils import make_awaitable | |||
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): | |||
def setUp(self): | |||
def setUp(self) -> None: | |||
super(ApplicationServiceStoreTestCase, self).setUp() | |||
self.as_yaml_files: List[str] = [] | |||
@@ -73,7 +73,9 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): | |||
super(ApplicationServiceStoreTestCase, self).tearDown() | |||
def _add_appservice(self, as_token, id, url, hs_token, sender) -> None: | |||
def _add_appservice( | |||
self, as_token: str, id: str, url: str, hs_token: str, sender: str | |||
) -> None: | |||
as_yaml = { | |||
"url": url, | |||
"as_token": as_token, | |||
@@ -135,7 +137,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): | |||
database, make_conn(db_config, self.engine, "test"), self.hs | |||
) | |||
def _add_service(self, url, as_token, id) -> None: | |||
def _add_service(self, url: str, as_token: str, id: str) -> None: | |||
as_yaml = { | |||
"url": url, | |||
"as_token": as_token, | |||
@@ -149,7 +151,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): | |||
outfile.write(yaml.dump(as_yaml)) | |||
self.as_yaml_files.append(as_token) | |||
def _set_state(self, id: str, state: ApplicationServiceState): | |||
def _set_state(self, id: str, state: ApplicationServiceState) -> defer.Deferred: | |||
return self.db_pool.runOperation( | |||
self.engine.convert_param_style( | |||
"INSERT INTO application_services_state(as_id, state) VALUES(?,?)" | |||
@@ -157,7 +159,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): | |||
(id, state.value), | |||
) | |||
def _insert_txn(self, as_id, txn_id, events): | |||
def _insert_txn( | |||
self, as_id: str, txn_id: int, events: List[Mock] | |||
) -> "defer.Deferred[None]": | |||
return self.db_pool.runOperation( | |||
self.engine.convert_param_style( | |||
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) " | |||
@@ -448,12 +452,14 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): | |||
# required for ApplicationServiceTransactionStoreTestCase tests | |||
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs) -> None: | |||
def __init__( | |||
self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: HomeServer | |||
) -> None: | |||
super().__init__(database, db_conn, hs) | |||
class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): | |||
def _write_config(self, suffix, **kwargs) -> str: | |||
def _write_config(self, suffix: str, **kwargs: str) -> str: | |||
vals = { | |||
"id": "id" + suffix, | |||
"url": "url" + suffix, | |||
@@ -12,8 +12,8 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from collections import OrderedDict | |||
from typing import Generator | |||
from unittest.mock import Mock | |||
from twisted.internet import defer | |||
@@ -30,7 +30,7 @@ from tests.utils import default_config | |||
class SQLBaseStoreTestCase(unittest.TestCase): | |||
"""Test the "simple" SQL generating methods in SQLBaseStore.""" | |||
def setUp(self): | |||
def setUp(self) -> None: | |||
self.db_pool = Mock(spec=["runInteraction"]) | |||
self.mock_txn = Mock() | |||
self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"]) | |||
@@ -38,12 +38,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
self.mock_conn.rollback.return_value = None | |||
# Our fake runInteraction just runs synchronously inline | |||
def runInteraction(func, *args, **kwargs): | |||
def runInteraction(func, *args, **kwargs) -> defer.Deferred: # type: ignore[no-untyped-def] | |||
return defer.succeed(func(self.mock_txn, *args, **kwargs)) | |||
self.db_pool.runInteraction = runInteraction | |||
def runWithConnection(func, *args, **kwargs): | |||
def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def] | |||
return defer.succeed(func(self.mock_conn, *args, **kwargs)) | |||
self.db_pool.runWithConnection = runWithConnection | |||
@@ -62,7 +62,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type] | |||
@defer.inlineCallbacks | |||
def test_insert_1col(self): | |||
def test_insert_1col(self) -> Generator["defer.Deferred[object]", object, None]: | |||
self.mock_txn.rowcount = 1 | |||
yield defer.ensureDeferred( | |||
@@ -76,7 +76,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
) | |||
@defer.inlineCallbacks | |||
def test_insert_3cols(self): | |||
def test_insert_3cols(self) -> Generator["defer.Deferred[object]", object, None]: | |||
self.mock_txn.rowcount = 1 | |||
yield defer.ensureDeferred( | |||
@@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
) | |||
@defer.inlineCallbacks | |||
def test_select_one_1col(self): | |||
def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]: | |||
self.mock_txn.rowcount = 1 | |||
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) | |||
@@ -108,7 +108,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
) | |||
@defer.inlineCallbacks | |||
def test_select_one_3col(self): | |||
def test_select_one_3col(self) -> Generator["defer.Deferred[object]", object, None]: | |||
self.mock_txn.rowcount = 1 | |||
self.mock_txn.fetchone.return_value = (1, 2, 3) | |||
@@ -126,7 +126,9 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
) | |||
@defer.inlineCallbacks | |||
def test_select_one_missing(self): | |||
def test_select_one_missing( | |||
self, | |||
) -> Generator["defer.Deferred[object]", object, None]: | |||
self.mock_txn.rowcount = 0 | |||
self.mock_txn.fetchone.return_value = None | |||
@@ -142,7 +144,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
self.assertFalse(ret) | |||
@defer.inlineCallbacks | |||
def test_select_list(self): | |||
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: | |||
self.mock_txn.rowcount = 3 | |||
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) | |||
self.mock_txn.description = (("colA", None, None, None, None, None, None),) | |||
@@ -159,7 +161,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
) | |||
@defer.inlineCallbacks | |||
def test_update_one_1col(self): | |||
def test_update_one_1col(self) -> Generator["defer.Deferred[object]", object, None]: | |||
self.mock_txn.rowcount = 1 | |||
yield defer.ensureDeferred( | |||
@@ -176,7 +178,9 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
) | |||
@defer.inlineCallbacks | |||
def test_update_one_4cols(self): | |||
def test_update_one_4cols( | |||
self, | |||
) -> Generator["defer.Deferred[object]", object, None]: | |||
self.mock_txn.rowcount = 1 | |||
yield defer.ensureDeferred( | |||
@@ -193,7 +197,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
) | |||
@defer.inlineCallbacks | |||
def test_delete_one(self): | |||
def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]: | |||
self.mock_txn.rowcount = 1 | |||
yield defer.ensureDeferred( | |||
@@ -15,11 +15,16 @@ | |||
import os.path | |||
from unittest.mock import Mock, patch | |||
from twisted.test.proto_helpers import MemoryReactor | |||
import synapse.rest.admin | |||
from synapse.api.constants import EventTypes | |||
from synapse.rest.client import login, room | |||
from synapse.server import HomeServer | |||
from synapse.storage import prepare_database | |||
from synapse.storage.types import Cursor | |||
from synapse.types import UserID, create_requester | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
@@ -29,7 +34,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
Test the background update to clean forward extremities table. | |||
""" | |||
def prepare(self, reactor, clock, homeserver): | |||
def prepare( | |||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | |||
) -> None: | |||
self.store = homeserver.get_datastores().main | |||
self.room_creator = homeserver.get_room_creation_handler() | |||
@@ -39,7 +46,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) | |||
self.room_id = info["room_id"] | |||
def run_background_update(self): | |||
def run_background_update(self) -> None: | |||
"""Re run the background update to clean up the extremities.""" | |||
# Make sure we don't clash with in progress updates. | |||
self.assertTrue( | |||
@@ -54,7 +61,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
"delete_forward_extremities.sql", | |||
) | |||
def run_delta_file(txn): | |||
def run_delta_file(txn: Cursor) -> None: | |||
prepare_database.executescript(txn, schema_path) | |||
self.get_success( | |||
@@ -84,7 +91,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
(room_id,) | |||
) | |||
def test_soft_failed_extremities_handled_correctly(self): | |||
def test_soft_failed_extremities_handled_correctly(self) -> None: | |||
"""Test that extremities are correctly calculated in the presence of | |||
soft failed events. | |||
@@ -114,7 +121,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
self.assertEqual(latest_event_ids, [event_id_4]) | |||
def test_basic_cleanup(self): | |||
def test_basic_cleanup(self) -> None: | |||
"""Test that extremities are correctly calculated in the presence of | |||
soft failed events. | |||
@@ -149,7 +156,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
) | |||
self.assertEqual(latest_event_ids, [event_id_b]) | |||
def test_chain_of_fail_cleanup(self): | |||
def test_chain_of_fail_cleanup(self) -> None: | |||
"""Test that extremities are correctly calculated in the presence of | |||
soft failed events. | |||
@@ -187,7 +194,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
) | |||
self.assertEqual(latest_event_ids, [event_id_b]) | |||
def test_forked_graph_cleanup(self): | |||
def test_forked_graph_cleanup(self) -> None: | |||
r"""Test that extremities are correctly calculated in the presence of | |||
soft failed events. | |||
@@ -252,12 +259,14 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): | |||
room.register_servlets, | |||
] | |||
def make_homeserver(self, reactor, clock): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
config = self.default_config() | |||
config["cleanup_extremities_with_dummy_events"] = True | |||
return self.setup_test_homeserver(config=config) | |||
def prepare(self, reactor, clock, homeserver): | |||
def prepare( | |||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | |||
) -> None: | |||
self.store = homeserver.get_datastores().main | |||
self.room_creator = homeserver.get_room_creation_handler() | |||
self.event_creator_handler = homeserver.get_event_creation_handler() | |||
@@ -273,7 +282,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): | |||
self.event_creator = homeserver.get_event_creation_handler() | |||
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION | |||
def test_send_dummy_event(self): | |||
def test_send_dummy_event(self) -> None: | |||
self._create_extremity_rich_graph() | |||
# Pump the reactor repeatedly so that the background updates have a | |||
@@ -286,7 +295,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): | |||
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) | |||
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) | |||
def test_send_dummy_events_when_insufficient_power(self): | |||
def test_send_dummy_events_when_insufficient_power(self) -> None: | |||
self._create_extremity_rich_graph() | |||
# Criple power levels | |||
self.helper.send_state( | |||
@@ -317,7 +326,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): | |||
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) | |||
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250) | |||
def test_expiry_logic(self): | |||
def test_expiry_logic(self) -> None: | |||
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() | |||
expires old entries correctly. | |||
""" | |||
@@ -357,7 +366,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): | |||
0, | |||
) | |||
def _create_extremity_rich_graph(self): | |||
def _create_extremity_rich_graph(self) -> None: | |||
"""Helper method to create bushy graph on demand""" | |||
event_id_start = self.create_and_send_event(self.room_id, self.user) | |||
@@ -372,7 +381,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): | |||
) | |||
self.assertEqual(len(latest_event_ids), 50) | |||
def _enable_consent_checking(self): | |||
def _enable_consent_checking(self) -> None: | |||
"""Helper method to enable consent checking""" | |||
self.event_creator._block_events_without_consent_error = "No consent from user" | |||
consent_uri_builder = Mock() | |||
@@ -13,15 +13,20 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Dict | |||
from unittest.mock import Mock | |||
from parameterized import parameterized | |||
from twisted.test.proto_helpers import MemoryReactor | |||
import synapse.rest.admin | |||
from synapse.http.site import XForwardedForRequest | |||
from synapse.rest.client import login | |||
from synapse.server import HomeServer | |||
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY | |||
from synapse.types import UserID | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.server import make_request | |||
@@ -30,14 +35,10 @@ from tests.unittest import override_config | |||
class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
def make_homeserver(self, reactor, clock): | |||
hs = self.setup_test_homeserver() | |||
return hs | |||
def prepare(self, hs, reactor, clock): | |||
self.store = self.hs.get_datastores().main | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
def test_insert_new_client_ip(self): | |||
def test_insert_new_client_ip(self) -> None: | |||
self.reactor.advance(12345678) | |||
user_id = "@user:id" | |||
@@ -76,7 +77,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
r, | |||
) | |||
def test_insert_new_client_ip_none_device_id(self): | |||
def test_insert_new_client_ip_none_device_id(self) -> None: | |||
""" | |||
An insert with a device ID of NULL will not create a new entry, but | |||
update an existing entry in the user_ips table. | |||
@@ -148,7 +149,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
) | |||
@parameterized.expand([(False,), (True,)]) | |||
def test_get_last_client_ip_by_device(self, after_persisting: bool): | |||
def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None: | |||
"""Test `get_last_client_ip_by_device` for persisted and unpersisted data""" | |||
self.reactor.advance(12345678) | |||
@@ -213,7 +214,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
def test_get_last_client_ip_by_device_combined_data(self): | |||
def test_get_last_client_ip_by_device_combined_data(self) -> None: | |||
"""Test that `get_last_client_ip_by_device` combines persisted and unpersisted | |||
data together correctly | |||
""" | |||
@@ -312,7 +313,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
) | |||
@parameterized.expand([(False,), (True,)]) | |||
def test_get_user_ip_and_agents(self, after_persisting: bool): | |||
def test_get_user_ip_and_agents(self, after_persisting: bool) -> None: | |||
"""Test `get_user_ip_and_agents` for persisted and unpersisted data""" | |||
self.reactor.advance(12345678) | |||
@@ -352,7 +353,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
], | |||
) | |||
def test_get_user_ip_and_agents_combined_data(self): | |||
def test_get_user_ip_and_agents_combined_data(self) -> None: | |||
"""Test that `get_user_ip_and_agents` combines persisted and unpersisted data | |||
together correctly | |||
""" | |||
@@ -429,7 +430,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
) | |||
@override_config({"limit_usage_by_mau": False, "max_mau_value": 50}) | |||
def test_disabled_monthly_active_user(self): | |||
def test_disabled_monthly_active_user(self) -> None: | |||
user_id = "@user:server" | |||
self.get_success( | |||
self.store.insert_client_ip( | |||
@@ -440,7 +441,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
self.assertFalse(active) | |||
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) | |||
def test_adding_monthly_active_user_when_full(self): | |||
def test_adding_monthly_active_user_when_full(self) -> None: | |||
lots_of_users = 100 | |||
user_id = "@user:server" | |||
@@ -456,7 +457,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
self.assertFalse(active) | |||
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) | |||
def test_adding_monthly_active_user_when_space(self): | |||
def test_adding_monthly_active_user_when_space(self) -> None: | |||
user_id = "@user:server" | |||
active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) | |||
self.assertFalse(active) | |||
@@ -473,7 +474,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
self.assertTrue(active) | |||
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) | |||
def test_updating_monthly_active_user_when_space(self): | |||
def test_updating_monthly_active_user_when_space(self) -> None: | |||
user_id = "@user:server" | |||
self.get_success(self.store.register_user(user_id=user_id, password_hash=None)) | |||
@@ -491,7 +492,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) | |||
self.assertTrue(active) | |||
def test_devices_last_seen_bg_update(self): | |||
def test_devices_last_seen_bg_update(self) -> None: | |||
# First make sure we have completed all updates. | |||
self.wait_for_background_updates() | |||
@@ -576,7 +577,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
r, | |||
) | |||
def test_old_user_ips_pruned(self): | |||
def test_old_user_ips_pruned(self) -> None: | |||
# First make sure we have completed all updates. | |||
self.wait_for_background_updates() | |||
@@ -639,11 +640,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(result, []) | |||
# But we should still get the correct values for the device | |||
result = self.get_success( | |||
result2 = self.get_success( | |||
self.store.get_last_client_ip_by_device(user_id, device_id) | |||
) | |||
r = result[(user_id, device_id)] | |||
r = result2[(user_id, device_id)] | |||
self.assertDictContainsSubset( | |||
{ | |||
"user_id": user_id, | |||
@@ -663,15 +664,11 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def make_homeserver(self, reactor, clock): | |||
hs = self.setup_test_homeserver() | |||
return hs | |||
def prepare(self, hs, reactor, clock): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = self.hs.get_datastores().main | |||
self.user_id = self.register_user("bob", "abc123", True) | |||
def test_request_with_xforwarded(self): | |||
def test_request_with_xforwarded(self) -> None: | |||
""" | |||
The IP in X-Forwarded-For is entered into the client IPs table. | |||
""" | |||
@@ -681,14 +678,19 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): | |||
{"request": XForwardedForRequest}, | |||
) | |||
def test_request_from_getPeer(self): | |||
def test_request_from_getPeer(self) -> None: | |||
""" | |||
The IP returned by getPeer is entered into the client IPs table, if | |||
there's no X-Forwarded-For header. | |||
""" | |||
self._runtest({}, "127.0.0.1", {}) | |||
def _runtest(self, headers, expected_ip, make_request_args): | |||
def _runtest( | |||
self, | |||
headers: Dict[bytes, bytes], | |||
expected_ip: str, | |||
make_request_args: Dict[str, Any], | |||
) -> None: | |||
device_id = "bleb" | |||
access_token = self.login("bob", "abc123", device_id=device_id) | |||
@@ -31,7 +31,7 @@ from tests import unittest | |||
class TupleComparisonClauseTestCase(unittest.TestCase): | |||
def test_native_tuple_comparison(self): | |||
def test_native_tuple_comparison(self) -> None: | |||
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) | |||
self.assertEqual(clause, "(a,b) > (?,?)") | |||
self.assertEqual(args, [1, 2]) | |||
@@ -12,17 +12,24 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Collection, List, Tuple | |||
from twisted.test.proto_helpers import MemoryReactor | |||
import synapse.api.errors | |||
from synapse.api.constants import EduTypes | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
class DeviceStoreTestCase(HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
def add_device_change(self, user_id, device_ids, host): | |||
def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None: | |||
"""Add a device list change for the given device to | |||
`device_lists_outbound_pokes` table. | |||
""" | |||
@@ -44,12 +51,13 @@ class DeviceStoreTestCase(HomeserverTestCase): | |||
) | |||
) | |||
def test_store_new_device(self): | |||
def test_store_new_device(self) -> None: | |||
self.get_success( | |||
self.store.store_device("user_id", "device_id", "display_name") | |||
) | |||
res = self.get_success(self.store.get_device("user_id", "device_id")) | |||
assert res is not None | |||
self.assertDictContainsSubset( | |||
{ | |||
"user_id": "user_id", | |||
@@ -59,7 +67,7 @@ class DeviceStoreTestCase(HomeserverTestCase): | |||
res, | |||
) | |||
def test_get_devices_by_user(self): | |||
def test_get_devices_by_user(self) -> None: | |||
self.get_success( | |||
self.store.store_device("user_id", "device1", "display_name 1") | |||
) | |||
@@ -89,7 +97,7 @@ class DeviceStoreTestCase(HomeserverTestCase): | |||
res["device2"], | |||
) | |||
def test_count_devices_by_users(self): | |||
def test_count_devices_by_users(self) -> None: | |||
self.get_success( | |||
self.store.store_device("user_id", "device1", "display_name 1") | |||
) | |||
@@ -114,7 +122,7 @@ class DeviceStoreTestCase(HomeserverTestCase): | |||
) | |||
self.assertEqual(3, res) | |||
def test_get_device_updates_by_remote(self): | |||
def test_get_device_updates_by_remote(self) -> None: | |||
device_ids = ["device_id1", "device_id2"] | |||
# Add two device updates with sequential `stream_id`s | |||
@@ -128,7 +136,7 @@ class DeviceStoreTestCase(HomeserverTestCase): | |||
# Check original device_ids are contained within these updates | |||
self._check_devices_in_updates(device_ids, device_updates) | |||
def test_get_device_updates_by_remote_can_limit_properly(self): | |||
def test_get_device_updates_by_remote_can_limit_properly(self) -> None: | |||
""" | |||
Tests that `get_device_updates_by_remote` returns an appropriate | |||
stream_id to resume fetching from (without skipping any results). | |||
@@ -280,7 +288,11 @@ class DeviceStoreTestCase(HomeserverTestCase): | |||
) | |||
self.assertEqual(device_updates, []) | |||
def _check_devices_in_updates(self, expected_device_ids, device_updates): | |||
def _check_devices_in_updates( | |||
self, | |||
expected_device_ids: Collection[str], | |||
device_updates: List[Tuple[str, JsonDict]], | |||
) -> None: | |||
"""Check that an specific device ids exist in a list of device update EDUs""" | |||
self.assertEqual(len(device_updates), len(expected_device_ids)) | |||
@@ -289,17 +301,19 @@ class DeviceStoreTestCase(HomeserverTestCase): | |||
} | |||
self.assertEqual(received_device_ids, set(expected_device_ids)) | |||
def test_update_device(self): | |||
def test_update_device(self) -> None: | |||
self.get_success( | |||
self.store.store_device("user_id", "device_id", "display_name 1") | |||
) | |||
res = self.get_success(self.store.get_device("user_id", "device_id")) | |||
assert res is not None | |||
self.assertEqual("display_name 1", res["display_name"]) | |||
# do a no-op first | |||
self.get_success(self.store.update_device("user_id", "device_id")) | |||
res = self.get_success(self.store.get_device("user_id", "device_id")) | |||
assert res is not None | |||
self.assertEqual("display_name 1", res["display_name"]) | |||
# do the update | |||
@@ -311,9 +325,10 @@ class DeviceStoreTestCase(HomeserverTestCase): | |||
# check it worked | |||
res = self.get_success(self.store.get_device("user_id", "device_id")) | |||
assert res is not None | |||
self.assertEqual("display_name 2", res["display_name"]) | |||
def test_update_unknown_device(self): | |||
def test_update_unknown_device(self) -> None: | |||
exc = self.get_failure( | |||
self.store.update_device( | |||
"user_id", "unknown_device_id", new_display_name="display_name 2" | |||
@@ -12,19 +12,23 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.server import HomeServer | |||
from synapse.types import RoomAlias, RoomID | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
class DirectoryStoreTestCase(HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
self.room = RoomID.from_string("!abcde:test") | |||
self.alias = RoomAlias.from_string("#my-room:test") | |||
def test_room_to_alias(self): | |||
def test_room_to_alias(self) -> None: | |||
self.get_success( | |||
self.store.create_room_alias_association( | |||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] | |||
@@ -36,7 +40,7 @@ class DirectoryStoreTestCase(HomeserverTestCase): | |||
(self.get_success(self.store.get_aliases_for_room(self.room.to_string()))), | |||
) | |||
def test_alias_to_room(self): | |||
def test_alias_to_room(self) -> None: | |||
self.get_success( | |||
self.store.create_room_alias_association( | |||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] | |||
@@ -48,7 +52,7 @@ class DirectoryStoreTestCase(HomeserverTestCase): | |||
(self.get_success(self.store.get_association_from_room_alias(self.alias))), | |||
) | |||
def test_delete_alias(self): | |||
def test_delete_alias(self) -> None: | |||
self.get_success( | |||
self.store.create_room_alias_association( | |||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] | |||
@@ -12,7 +12,11 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.server import HomeServer | |||
from synapse.storage.databases.main.e2e_room_keys import RoomKey | |||
from synapse.util import Clock | |||
from tests import unittest | |||
@@ -26,12 +30,12 @@ room_key: RoomKey = { | |||
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
def make_homeserver(self, reactor, clock): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
hs = self.setup_test_homeserver("server", federation_http_client=None) | |||
self.store = hs.get_datastores().main | |||
return hs | |||
def test_room_keys_version_delete(self): | |||
def test_room_keys_version_delete(self) -> None: | |||
# test that deleting a room key backup deletes the keys | |||
version1 = self.get_success( | |||
self.store.create_e2e_room_keys_version( | |||
@@ -12,14 +12,19 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
class EndToEndKeyStoreTestCase(HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
def test_key_without_device_name(self): | |||
def test_key_without_device_name(self) -> None: | |||
now = 1470174257070 | |||
json = {"key": "value"} | |||
@@ -35,7 +40,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase): | |||
dev = res["user"]["device"] | |||
self.assertDictContainsSubset(json, dev) | |||
def test_reupload_key(self): | |||
def test_reupload_key(self) -> None: | |||
now = 1470174257070 | |||
json = {"key": "value"} | |||
@@ -53,7 +58,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase): | |||
) | |||
self.assertFalse(changed) | |||
def test_get_key_with_device_name(self): | |||
def test_get_key_with_device_name(self) -> None: | |||
now = 1470174257070 | |||
json = {"key": "value"} | |||
@@ -70,7 +75,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase): | |||
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev | |||
) | |||
def test_multiple_devices(self): | |||
def test_multiple_devices(self) -> None: | |||
now = 1470174257070 | |||
self.get_success(self.store.store_device("user1", "device1", None)) | |||
@@ -14,6 +14,7 @@ | |||
from typing import Dict, List, Set, Tuple | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from twisted.trial import unittest | |||
from synapse.api.constants import EventTypes | |||
@@ -22,18 +23,22 @@ from synapse.events import EventBase | |||
from synapse.events.snapshot import EventContext | |||
from synapse.rest import admin | |||
from synapse.rest.client import login, room | |||
from synapse.server import HomeServer | |||
from synapse.storage.database import LoggingTransaction | |||
from synapse.storage.databases.main.events import _LinkMap | |||
from synapse.storage.types import Cursor | |||
from synapse.types import create_requester | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
class EventChainStoreTestCase(HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
self._next_stream_ordering = 1 | |||
def test_simple(self): | |||
def test_simple(self) -> None: | |||
"""Test that the example in `docs/auth_chain_difference_algorithm.md` | |||
works. | |||
""" | |||
@@ -232,7 +237,7 @@ class EventChainStoreTestCase(HomeserverTestCase): | |||
), | |||
) | |||
def test_out_of_order_events(self): | |||
def test_out_of_order_events(self) -> None: | |||
"""Test that we handle persisting events that we don't have the full | |||
auth chain for yet (which should only happen for out of band memberships). | |||
""" | |||
@@ -378,7 +383,7 @@ class EventChainStoreTestCase(HomeserverTestCase): | |||
def persist( | |||
self, | |||
events: List[EventBase], | |||
): | |||
) -> None: | |||
"""Persist the given events and check that the links generated match | |||
those given. | |||
""" | |||
@@ -389,7 +394,7 @@ class EventChainStoreTestCase(HomeserverTestCase): | |||
e.internal_metadata.stream_ordering = self._next_stream_ordering | |||
self._next_stream_ordering += 1 | |||
def _persist(txn): | |||
def _persist(txn: LoggingTransaction) -> None: | |||
# We need to persist the events to the events and state_events | |||
# tables. | |||
persist_events_store._store_event_txn( | |||
@@ -456,7 +461,7 @@ class EventChainStoreTestCase(HomeserverTestCase): | |||
class LinkMapTestCase(unittest.TestCase): | |||
def test_simple(self): | |||
def test_simple(self) -> None: | |||
"""Basic tests for the LinkMap.""" | |||
link_map = _LinkMap() | |||
@@ -492,7 +497,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
self.user_id = self.register_user("foo", "pass") | |||
self.token = self.login("foo", "pass") | |||
@@ -559,7 +564,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): | |||
# Delete the chain cover info. | |||
def _delete_tables(txn): | |||
def _delete_tables(txn: Cursor) -> None: | |||
txn.execute("DELETE FROM event_auth_chains") | |||
txn.execute("DELETE FROM event_auth_chain_links") | |||
@@ -567,7 +572,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): | |||
return room_id, [state1, state2] | |||
def test_background_update_single_room(self): | |||
def test_background_update_single_room(self) -> None: | |||
"""Test that the background update to calculate auth chains for historic | |||
rooms works correctly. | |||
""" | |||
@@ -602,7 +607,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): | |||
) | |||
) | |||
def test_background_update_multiple_rooms(self): | |||
def test_background_update_multiple_rooms(self) -> None: | |||
"""Test that the background update to calculate auth chains for historic | |||
rooms works correctly. | |||
""" | |||
@@ -640,7 +645,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): | |||
) | |||
) | |||
def test_background_update_single_large_room(self): | |||
def test_background_update_single_large_room(self) -> None: | |||
"""Test that the background update to calculate auth chains for historic | |||
rooms works correctly. | |||
""" | |||
@@ -693,7 +698,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): | |||
) | |||
) | |||
def test_background_update_multiple_large_room(self): | |||
def test_background_update_multiple_large_room(self) -> None: | |||
"""Test that the background update to calculate auth chains for historic | |||
rooms works correctly. | |||
""" | |||
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
import datetime | |||
from typing import Dict, List, Tuple, Union | |||
from typing import Dict, List, Tuple, Union, cast | |||
import attr | |||
from parameterized import parameterized | |||
@@ -26,11 +26,12 @@ from synapse.api.room_versions import ( | |||
EventFormatVersions, | |||
RoomVersion, | |||
) | |||
from synapse.events import _EventInternalMetadata | |||
from synapse.events import EventBase, _EventInternalMetadata | |||
from synapse.rest import admin | |||
from synapse.rest.client import login, room | |||
from synapse.server import HomeServer | |||
from synapse.storage.database import LoggingTransaction | |||
from synapse.storage.types import Cursor | |||
from synapse.types import JsonDict | |||
from synapse.util import Clock, json_encoder | |||
@@ -54,11 +55,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
def test_get_prev_events_for_room(self): | |||
def test_get_prev_events_for_room(self) -> None: | |||
room_id = "@ROOM:local" | |||
# add a bunch of events and hashes to act as forward extremities | |||
def insert_event(txn, i): | |||
def insert_event(txn: Cursor, i: int) -> None: | |||
event_id = "$event_%i:local" % i | |||
txn.execute( | |||
@@ -90,12 +91,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
for i in range(0, 10): | |||
self.assertEqual("$event_%i:local" % (19 - i), r[i]) | |||
def test_get_rooms_with_many_extremities(self): | |||
def test_get_rooms_with_many_extremities(self) -> None: | |||
room1 = "#room1" | |||
room2 = "#room2" | |||
room3 = "#room3" | |||
def insert_event(txn, i, room_id): | |||
def insert_event(txn: Cursor, i: int, room_id: str) -> None: | |||
event_id = "$event_%i:local" % i | |||
txn.execute( | |||
( | |||
@@ -155,7 +156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
# | | | |||
# K J | |||
auth_graph = { | |||
auth_graph: Dict[str, List[str]] = { | |||
"a": ["e"], | |||
"b": ["e"], | |||
"c": ["g", "i"], | |||
@@ -185,7 +186,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
# Mark the room as maybe having a cover index. | |||
def store_room(txn): | |||
def store_room(txn: LoggingTransaction) -> None: | |||
self.store.db_pool.simple_insert_txn( | |||
txn, | |||
"rooms", | |||
@@ -203,7 +204,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
# We rudely fiddle with the appropriate tables directly, as that's much | |||
# easier than constructing events properly. | |||
def insert_event(txn): | |||
def insert_event(txn: LoggingTransaction) -> None: | |||
stream_ordering = 0 | |||
for event_id in auth_graph: | |||
@@ -228,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
self.hs.datastores.persist_events._persist_event_auth_chain_txn( | |||
txn, | |||
[ | |||
FakeEvent(event_id, room_id, auth_graph[event_id]) | |||
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) | |||
for event_id in auth_graph | |||
], | |||
) | |||
@@ -243,7 +244,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
return room_id | |||
@parameterized.expand([(True,), (False,)]) | |||
def test_auth_chain_ids(self, use_chain_cover_index: bool): | |||
def test_auth_chain_ids(self, use_chain_cover_index: bool) -> None: | |||
room_id = self._setup_auth_chain(use_chain_cover_index) | |||
# a and b have the same auth chain. | |||
@@ -308,7 +309,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
self.assertCountEqual(auth_chain_ids, ["i", "j"]) | |||
@parameterized.expand([(True,), (False,)]) | |||
def test_auth_difference(self, use_chain_cover_index: bool): | |||
def test_auth_difference(self, use_chain_cover_index: bool) -> None: | |||
room_id = self._setup_auth_chain(use_chain_cover_index) | |||
# Now actually test that various combinations give the right result: | |||
@@ -353,7 +354,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
) | |||
self.assertSetEqual(difference, set()) | |||
def test_auth_difference_partial_cover(self): | |||
def test_auth_difference_partial_cover(self) -> None: | |||
"""Test that we correctly handle rooms where not all events have a chain | |||
cover calculated. This can happen in some obscure edge cases, including | |||
during the background update that calculates the chain cover for old | |||
@@ -377,7 +378,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
# | | | |||
# K J | |||
auth_graph = { | |||
auth_graph: Dict[str, List[str]] = { | |||
"a": ["e"], | |||
"b": ["e"], | |||
"c": ["g", "i"], | |||
@@ -408,7 +409,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
# We rudely fiddle with the appropriate tables directly, as that's much | |||
# easier than constructing events properly. | |||
def insert_event(txn): | |||
def insert_event(txn: LoggingTransaction) -> None: | |||
# First insert the room and mark it as having a chain cover. | |||
self.store.db_pool.simple_insert_txn( | |||
txn, | |||
@@ -447,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
self.hs.datastores.persist_events._persist_event_auth_chain_txn( | |||
txn, | |||
[ | |||
FakeEvent(event_id, room_id, auth_graph[event_id]) | |||
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) | |||
for event_id in auth_graph | |||
if event_id != "b" | |||
], | |||
@@ -465,7 +466,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
self.hs.datastores.persist_events._persist_event_auth_chain_txn( | |||
txn, | |||
[FakeEvent("b", room_id, auth_graph["b"])], | |||
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))], | |||
) | |||
self.store.db_pool.simple_update_txn( | |||
@@ -527,7 +528,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
@parameterized.expand( | |||
[(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()] | |||
) | |||
def test_prune_inbound_federation_queue(self, room_version: RoomVersion): | |||
def test_prune_inbound_federation_queue(self, room_version: RoomVersion) -> None: | |||
"""Test that pruning of inbound federation queues work""" | |||
room_id = "some_room_id" | |||
@@ -686,7 +687,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
stream_ordering += 1 | |||
def populate_db(txn: LoggingTransaction): | |||
def populate_db(txn: LoggingTransaction) -> None: | |||
# Insert the room to satisfy the foreign key constraint of | |||
# `event_failed_pull_attempts` | |||
self.store.db_pool.simple_insert_txn( | |||
@@ -760,7 +761,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) | |||
def test_get_backfill_points_in_room(self): | |||
def test_get_backfill_points_in_room(self) -> None: | |||
""" | |||
Test to make sure only backfill points that are older and come before | |||
the `current_depth` are returned. | |||
@@ -787,7 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
def test_get_backfill_points_in_room_excludes_events_we_have_attempted( | |||
self, | |||
): | |||
) -> None: | |||
""" | |||
Test to make sure that events we have attempted to backfill (and within | |||
backoff timeout duration) do not show up as an event to backfill again. | |||
@@ -824,7 +825,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration( | |||
self, | |||
): | |||
) -> None: | |||
""" | |||
Test to make sure after we fake attempt to backfill event "b3" many times, | |||
we can see retry and see the "b3" again after the backoff timeout duration | |||
@@ -941,7 +942,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
"5": 7, | |||
} | |||
def populate_db(txn: LoggingTransaction): | |||
def populate_db(txn: LoggingTransaction) -> None: | |||
# Insert the room to satisfy the foreign key constraint of | |||
# `event_failed_pull_attempts` | |||
self.store.db_pool.simple_insert_txn( | |||
@@ -996,7 +997,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) | |||
def test_get_insertion_event_backward_extremities_in_room(self): | |||
def test_get_insertion_event_backward_extremities_in_room(self) -> None: | |||
""" | |||
Test to make sure only insertion event backward extremities that are | |||
older and come before the `current_depth` are returned. | |||
@@ -1027,7 +1028,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted( | |||
self, | |||
): | |||
) -> None: | |||
""" | |||
Test to make sure that insertion events we have attempted to backfill | |||
(and within backoff timeout duration) do not show up as an event to | |||
@@ -1060,7 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration( | |||
self, | |||
): | |||
) -> None: | |||
""" | |||
Test to make sure after we fake attempt to backfill event | |||
"insertion_eventA" many times, we can see retry and see the | |||
@@ -1130,9 +1131,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] | |||
self.assertEqual(backfill_event_ids, ["insertion_eventA"]) | |||
def test_get_event_ids_to_not_pull_from_backoff( | |||
self, | |||
): | |||
def test_get_event_ids_to_not_pull_from_backoff(self) -> None: | |||
""" | |||
Test to make sure only event IDs we should backoff from are returned. | |||
""" | |||
@@ -1157,7 +1156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration( | |||
self, | |||
): | |||
) -> None: | |||
""" | |||
Test to make sure no event IDs are returned after the backoff duration has | |||
elapsed. | |||
@@ -1187,19 +1186,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
self.assertEqual(event_ids_to_backoff, []) | |||
@attr.s | |||
@attr.s(auto_attribs=True) | |||
class FakeEvent: | |||
event_id = attr.ib() | |||
room_id = attr.ib() | |||
auth_events = attr.ib() | |||
event_id: str | |||
room_id: str | |||
auth_events: List[str] | |||
type = "foo" | |||
state_key = "foo" | |||
internal_metadata = _EventInternalMetadata({}) | |||
def auth_event_ids(self): | |||
def auth_event_ids(self) -> List[str]: | |||
return self.auth_events | |||
def is_state(self): | |||
def is_state(self) -> bool: | |||
return True |
@@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase | |||
class ExtremStatisticsTestCase(HomeserverTestCase): | |||
def test_exposed_to_prometheus(self): | |||
def test_exposed_to_prometheus(self) -> None: | |||
""" | |||
Forward extremity counts are exposed via Prometheus. | |||
""" | |||
@@ -12,12 +12,19 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import List, Optional | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import EventTypes, Membership | |||
from synapse.api.room_versions import RoomVersions | |||
from synapse.events import EventBase | |||
from synapse.federation.federation_base import event_from_pdu_json | |||
from synapse.rest import admin | |||
from synapse.rest.client import login, room | |||
from synapse.server import HomeServer | |||
from synapse.types import StateMap | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
@@ -29,7 +36,9 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, homeserver): | |||
def prepare( | |||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | |||
) -> None: | |||
self.state = self.hs.get_state_handler() | |||
self._persistence = self.hs.get_storage_controllers().persistence | |||
self._state_storage_controller = self.hs.get_storage_controllers().state | |||
@@ -67,7 +76,9 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||
# Check that the current extremities is the remote event. | |||
self.assert_extremities([self.remote_event_1.event_id]) | |||
def persist_event(self, event, state=None): | |||
def persist_event( | |||
self, event: EventBase, state: Optional[StateMap[str]] = None | |||
) -> None: | |||
"""Persist the event, with optional state""" | |||
context = self.get_success( | |||
self.state.compute_event_context( | |||
@@ -78,14 +89,14 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||
) | |||
self.get_success(self._persistence.persist_event(event, context)) | |||
def assert_extremities(self, expected_extremities): | |||
def assert_extremities(self, expected_extremities: List[str]) -> None: | |||
"""Assert the current extremities for the room""" | |||
extremities = self.get_success( | |||
self.store.get_prev_events_for_room(self.room_id) | |||
) | |||
self.assertCountEqual(extremities, expected_extremities) | |||
def test_prune_gap(self): | |||
def test_prune_gap(self) -> None: | |||
"""Test that we drop extremities after a gap when we see an event from | |||
the same domain. | |||
""" | |||
@@ -117,7 +128,7 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||
# Check the new extremity is just the new remote event. | |||
self.assert_extremities([remote_event_2.event_id]) | |||
def test_do_not_prune_gap_if_state_different(self): | |||
def test_do_not_prune_gap_if_state_different(self) -> None: | |||
"""Test that we don't prune extremities after a gap if the resolved | |||
state is different. | |||
""" | |||
@@ -161,7 +172,7 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||
# Check that we haven't dropped the old extremity. | |||
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) | |||
def test_prune_gap_if_old(self): | |||
def test_prune_gap_if_old(self) -> None: | |||
"""Test that we drop extremities after a gap when the previous extremity | |||
is "old" | |||
""" | |||
@@ -197,7 +208,7 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||
# Check the new extremity is just the new remote event. | |||
self.assert_extremities([remote_event_2.event_id]) | |||
def test_do_not_prune_gap_if_other_server(self): | |||
def test_do_not_prune_gap_if_other_server(self) -> None: | |||
"""Test that we do not drop extremities after a gap when we see an event | |||
from a different domain. | |||
""" | |||
@@ -229,7 +240,7 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||
# Check the new extremity is just the new remote event. | |||
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) | |||
def test_prune_gap_if_dummy_remote(self): | |||
def test_prune_gap_if_dummy_remote(self) -> None: | |||
"""Test that we drop extremities after a gap when the previous extremity | |||
is a local dummy event and only points to remote events. | |||
""" | |||
@@ -271,7 +282,7 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||
# Check the new extremity is just the new remote event. | |||
self.assert_extremities([remote_event_2.event_id]) | |||
def test_prune_gap_if_dummy_local(self): | |||
def test_prune_gap_if_dummy_local(self) -> None: | |||
"""Test that we don't drop extremities after a gap when the previous | |||
extremity is a local dummy event and points to local events. | |||
""" | |||
@@ -315,7 +326,7 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||
# Check the new extremity is just the new remote event. | |||
self.assert_extremities([remote_event_2.event_id, local_message_event_id]) | |||
def test_do_not_prune_gap_if_not_dummy(self): | |||
def test_do_not_prune_gap_if_not_dummy(self) -> None: | |||
"""Test that we do not drop extremities after a gap when the previous extremity | |||
is not a dummy event. | |||
""" | |||
@@ -359,12 +370,14 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, homeserver): | |||
def prepare( | |||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | |||
) -> None: | |||
self.state = self.hs.get_state_handler() | |||
self._persistence = self.hs.get_storage_controllers().persistence | |||
self.store = self.hs.get_datastores().main | |||
def test_remote_user_rooms_cache_invalidated(self): | |||
def test_remote_user_rooms_cache_invalidated(self) -> None: | |||
"""Test that if the server leaves a room the `get_rooms_for_user` cache | |||
is invalidated for remote users. | |||
""" | |||
@@ -411,7 +424,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): | |||
rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) | |||
self.assertEqual(set(rooms), set()) | |||
def test_room_remote_user_cache_invalidated(self): | |||
def test_room_remote_user_cache_invalidated(self) -> None: | |||
"""Test that if the server leaves a room the `get_users_in_room` cache | |||
is invalidated for remote users. | |||
""" | |||
@@ -13,6 +13,7 @@ | |||
# limitations under the License. | |||
import signedjson.key | |||
import signedjson.types | |||
import unpaddedbase64 | |||
from twisted.internet.defer import Deferred | |||
@@ -22,7 +23,9 @@ from synapse.storage.keys import FetchKeyResult | |||
import tests.unittest | |||
def decode_verify_key_base64(key_id: str, key_base64: str): | |||
def decode_verify_key_base64( | |||
key_id: str, key_base64: str | |||
) -> signedjson.types.VerifyKey: | |||
key_bytes = unpaddedbase64.decode_base64(key_base64) | |||
return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) | |||
@@ -36,7 +39,7 @@ KEY_2 = decode_verify_key_base64( | |||
class KeyStoreTestCase(tests.unittest.HomeserverTestCase): | |||
def test_get_server_verify_keys(self): | |||
def test_get_server_verify_keys(self) -> None: | |||
store = self.hs.get_datastores().main | |||
key_id_1 = "ed25519:key1" | |||
@@ -71,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): | |||
# non-existent result gives None | |||
self.assertIsNone(res[("server1", "ed25519:key3")]) | |||
def test_cache(self): | |||
def test_cache(self) -> None: | |||
"""Check that updates correctly invalidate the cache.""" | |||
store = self.hs.get_datastores().main | |||
@@ -53,7 +53,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
self.reactor.advance(FORTY_DAYS) | |||
@override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) | |||
def test_initialise_reserved_users(self): | |||
def test_initialise_reserved_users(self) -> None: | |||
threepids = self.hs.config.server.mau_limits_reserved_threepids | |||
# register three users, of which two have reserved 3pids, and a third | |||
@@ -133,7 +133,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
active_count = self.get_success(self.store.get_monthly_active_count()) | |||
self.assertEqual(active_count, 3) | |||
def test_can_insert_and_count_mau(self): | |||
def test_can_insert_and_count_mau(self) -> None: | |||
count = self.get_success(self.store.get_monthly_active_count()) | |||
self.assertEqual(count, 0) | |||
@@ -143,7 +143,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
count = self.get_success(self.store.get_monthly_active_count()) | |||
self.assertEqual(count, 1) | |||
def test_appservice_user_not_counted_in_mau(self): | |||
def test_appservice_user_not_counted_in_mau(self) -> None: | |||
self.get_success( | |||
self.store.register_user( | |||
user_id="@appservice_user:server", appservice_id="wibble" | |||
@@ -158,7 +158,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
count = self.get_success(self.store.get_monthly_active_count()) | |||
self.assertEqual(count, 0) | |||
def test_user_last_seen_monthly_active(self): | |||
def test_user_last_seen_monthly_active(self) -> None: | |||
user_id1 = "@user1:server" | |||
user_id2 = "@user2:server" | |||
user_id3 = "@user3:server" | |||
@@ -177,7 +177,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
self.assertIsNone(result) | |||
@override_config({"max_mau_value": 5}) | |||
def test_reap_monthly_active_users(self): | |||
def test_reap_monthly_active_users(self) -> None: | |||
initial_users = 10 | |||
for i in range(initial_users): | |||
self.get_success( | |||
@@ -204,7 +204,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
# Note that below says mau_limit (no s), this is the name of the config | |||
# value, although it gets stored on the config object as mau_limits. | |||
@override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)}) | |||
def test_reap_monthly_active_users_reserved_users(self): | |||
def test_reap_monthly_active_users_reserved_users(self) -> None: | |||
"""Tests that reaping correctly handles reaping where reserved users are | |||
present""" | |||
threepids = self.hs.config.server.mau_limits_reserved_threepids | |||
@@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
count = self.get_success(self.store.get_monthly_active_count()) | |||
self.assertEqual(count, self.hs.config.server.max_mau_value) | |||
def test_populate_monthly_users_is_guest(self): | |||
def test_populate_monthly_users_is_guest(self) -> None: | |||
# Test that guest users are not added to mau list | |||
user_id = "@user_id:host" | |||
@@ -260,7 +260,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
self.store.upsert_monthly_active_user.assert_not_called() | |||
def test_populate_monthly_users_should_update(self): | |||
def test_populate_monthly_users_should_update(self) -> None: | |||
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] | |||
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] | |||
@@ -273,7 +273,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
self.store.upsert_monthly_active_user.assert_called_once() | |||
def test_populate_monthly_users_should_not_update(self): | |||
def test_populate_monthly_users_should_not_update(self) -> None: | |||
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] | |||
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] | |||
@@ -286,7 +286,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
self.store.upsert_monthly_active_user.assert_not_called() | |||
def test_get_reserved_real_user_account(self): | |||
def test_get_reserved_real_user_account(self) -> None: | |||
# Test no reserved users, or reserved threepids | |||
users = self.get_success(self.store.get_registered_reserved_users()) | |||
self.assertEqual(len(users), 0) | |||
@@ -326,7 +326,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
users = self.get_success(self.store.get_registered_reserved_users()) | |||
self.assertEqual(len(users), len(threepids)) | |||
def test_support_user_not_add_to_mau_limits(self): | |||
def test_support_user_not_add_to_mau_limits(self) -> None: | |||
support_user_id = "@support:test" | |||
count = self.get_success(self.store.get_monthly_active_count()) | |||
@@ -347,7 +347,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
@override_config( | |||
{"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1} | |||
) | |||
def test_track_monthly_users_without_cap(self): | |||
def test_track_monthly_users_without_cap(self) -> None: | |||
count = self.get_success(self.store.get_monthly_active_count()) | |||
self.assertEqual(0, count) | |||
@@ -358,14 +358,14 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(2, count) | |||
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) | |||
def test_no_users_when_not_tracking(self): | |||
def test_no_users_when_not_tracking(self) -> None: | |||
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] | |||
self.get_success(self.store.populate_monthly_active_users("@user:sever")) | |||
self.store.upsert_monthly_active_user.assert_not_called() | |||
def test_get_monthly_active_count_by_service(self): | |||
def test_get_monthly_active_count_by_service(self) -> None: | |||
appservice1_user1 = "@appservice1_user1:example.com" | |||
appservice1_user2 = "@appservice1_user2:example.com" | |||
@@ -413,7 +413,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(result[service2], 1) | |||
self.assertEqual(result[native], 1) | |||
def test_get_monthly_active_users_by_service(self): | |||
def test_get_monthly_active_users_by_service(self) -> None: | |||
# (No users, no filtering) -> empty result | |||
result = self.get_success(self.store.get_monthly_active_users_by_service()) | |||
@@ -12,8 +12,12 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.errors import NotFoundError, SynapseError | |||
from synapse.rest.client import room | |||
from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
@@ -23,17 +27,17 @@ class PurgeTests(HomeserverTestCase): | |||
user_id = "@red:server" | |||
servlets = [room.register_servlets] | |||
def make_homeserver(self, reactor, clock): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
hs = self.setup_test_homeserver("server", federation_http_client=None) | |||
return hs | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.room_id = self.helper.create_room_as(self.user_id) | |||
self.store = hs.get_datastores().main | |||
self._storage_controllers = self.hs.get_storage_controllers() | |||
def test_purge_history(self): | |||
def test_purge_history(self) -> None: | |||
""" | |||
Purging a room history will delete everything before the topological point. | |||
""" | |||
@@ -63,7 +67,7 @@ class PurgeTests(HomeserverTestCase): | |||
self.get_failure(self.store.get_event(third["event_id"]), NotFoundError) | |||
self.get_success(self.store.get_event(last["event_id"])) | |||
def test_purge_history_wont_delete_extrems(self): | |||
def test_purge_history_wont_delete_extrems(self) -> None: | |||
""" | |||
Purging a room history will delete everything before the topological point. | |||
""" | |||
@@ -77,6 +81,7 @@ class PurgeTests(HomeserverTestCase): | |||
token = self.get_success( | |||
self.store.get_topological_token_for_event(last["event_id"]) | |||
) | |||
assert token.topological is not None | |||
event = f"t{token.topological + 1}-{token.stream + 1}" | |||
# Purge everything before this topological token | |||
@@ -94,7 +99,7 @@ class PurgeTests(HomeserverTestCase): | |||
self.get_success(self.store.get_event(third["event_id"])) | |||
self.get_success(self.store.get_event(last["event_id"])) | |||
def test_purge_room(self): | |||
def test_purge_room(self) -> None: | |||
""" | |||
Purging a room will delete everything about it. | |||
""" | |||
@@ -14,8 +14,12 @@ | |||
from typing import Collection, Optional | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import ReceiptTypes | |||
from synapse.server import HomeServer | |||
from synapse.types import UserID, create_requester | |||
from synapse.util import Clock | |||
from tests.test_utils.event_injection import create_event | |||
from tests.unittest import HomeserverTestCase | |||
@@ -25,7 +29,9 @@ OUR_USER_ID = "@our:test" | |||
class ReceiptTestCase(HomeserverTestCase): | |||
def prepare(self, reactor, clock, homeserver) -> None: | |||
def prepare( | |||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | |||
) -> None: | |||
super().prepare(reactor, clock, homeserver) | |||
self.store = homeserver.get_datastores().main | |||
@@ -135,11 +141,11 @@ class ReceiptTestCase(HomeserverTestCase): | |||
) | |||
self.assertEqual(res, {}) | |||
res = self.get_last_unthreaded_receipt( | |||
res2 = self.get_last_unthreaded_receipt( | |||
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] | |||
) | |||
self.assertEqual(res, None) | |||
self.assertIsNone(res2) | |||
def test_get_receipts_for_user(self) -> None: | |||
# Send some events into the first room | |||
@@ -11,27 +11,35 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import List, Optional | |||
from typing import List, Optional, cast | |||
from canonicaljson import json | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import EventTypes, Membership | |||
from synapse.api.room_versions import RoomVersions | |||
from synapse.types import RoomID, UserID | |||
from synapse.events import EventBase, _EventInternalMetadata | |||
from synapse.events.builder import EventBuilder | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict, RoomID, UserID | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.utils import create_room | |||
class RedactionTestCase(unittest.HomeserverTestCase): | |||
def default_config(self): | |||
def default_config(self) -> JsonDict: | |||
config = super().default_config() | |||
config["redaction_retention_period"] = "30d" | |||
return config | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
self._storage = hs.get_storage_controllers() | |||
storage = hs.get_storage_controllers() | |||
assert storage.persistence is not None | |||
self._persistence = storage.persistence | |||
self.event_builder_factory = hs.get_event_builder_factory() | |||
self.event_creation_handler = hs.get_event_creation_handler() | |||
@@ -46,14 +54,13 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
self.depth = 1 | |||
def inject_room_member( | |||
def inject_room_member( # type: ignore[override] | |||
self, | |||
room, | |||
user, | |||
membership, | |||
replaces_state=None, | |||
extra_content: Optional[dict] = None, | |||
): | |||
room: RoomID, | |||
user: UserID, | |||
membership: str, | |||
extra_content: Optional[JsonDict] = None, | |||
) -> EventBase: | |||
content = {"membership": membership} | |||
content.update(extra_content or {}) | |||
builder = self.event_builder_factory.for_room_version( | |||
@@ -71,11 +78,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
self.get_success(self._storage.persistence.persist_event(event, context)) | |||
self.get_success(self._persistence.persist_event(event, context)) | |||
return event | |||
def inject_message(self, room, user, body): | |||
def inject_message(self, room: RoomID, user: UserID, body: str) -> EventBase: | |||
self.depth += 1 | |||
builder = self.event_builder_factory.for_room_version( | |||
@@ -93,11 +100,13 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
self.get_success(self._storage.persistence.persist_event(event, context)) | |||
self.get_success(self._persistence.persist_event(event, context)) | |||
return event | |||
def inject_redaction(self, room, event_id, user, reason): | |||
def inject_redaction( | |||
self, room: RoomID, event_id: str, user: UserID, reason: str | |||
) -> EventBase: | |||
builder = self.event_builder_factory.for_room_version( | |||
RoomVersions.V1, | |||
{ | |||
@@ -114,11 +123,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
self.get_success(self._storage.persistence.persist_event(event, context)) | |||
self.get_success(self._persistence.persist_event(event, context)) | |||
return event | |||
def test_redact(self): | |||
def test_redact(self) -> None: | |||
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) | |||
msg_event = self.inject_message(self.room1, self.u_alice, "t") | |||
@@ -165,7 +174,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
event.unsigned["redacted_because"], | |||
) | |||
def test_redact_join(self): | |||
def test_redact_join(self) -> None: | |||
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) | |||
msg_event = self.inject_room_member( | |||
@@ -213,12 +222,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
event.unsigned["redacted_because"], | |||
) | |||
def test_circular_redaction(self): | |||
def test_circular_redaction(self) -> None: | |||
redaction_event_id1 = "$redaction1_id:test" | |||
redaction_event_id2 = "$redaction2_id:test" | |||
class EventIdManglingBuilder: | |||
def __init__(self, base_builder, event_id): | |||
def __init__(self, base_builder: EventBuilder, event_id: str): | |||
self._base_builder = base_builder | |||
self._event_id = event_id | |||
@@ -227,67 +236,73 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
prev_event_ids: List[str], | |||
auth_event_ids: Optional[List[str]], | |||
depth: Optional[int] = None, | |||
): | |||
) -> EventBase: | |||
built_event = await self._base_builder.build( | |||
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids | |||
) | |||
built_event._event_id = self._event_id | |||
built_event._event_id = self._event_id # type: ignore[attr-defined] | |||
built_event._dict["event_id"] = self._event_id | |||
assert built_event.event_id == self._event_id | |||
return built_event | |||
@property | |||
def room_id(self): | |||
def room_id(self) -> str: | |||
return self._base_builder.room_id | |||
@property | |||
def type(self): | |||
def type(self) -> str: | |||
return self._base_builder.type | |||
@property | |||
def internal_metadata(self): | |||
def internal_metadata(self) -> _EventInternalMetadata: | |||
return self._base_builder.internal_metadata | |||
event_1, context_1 = self.get_success( | |||
self.event_creation_handler.create_new_client_event( | |||
EventIdManglingBuilder( | |||
self.event_builder_factory.for_room_version( | |||
RoomVersions.V1, | |||
{ | |||
"type": EventTypes.Redaction, | |||
"sender": self.u_alice.to_string(), | |||
"room_id": self.room1.to_string(), | |||
"content": {"reason": "test"}, | |||
"redacts": redaction_event_id2, | |||
}, | |||
cast( | |||
EventBuilder, | |||
EventIdManglingBuilder( | |||
self.event_builder_factory.for_room_version( | |||
RoomVersions.V1, | |||
{ | |||
"type": EventTypes.Redaction, | |||
"sender": self.u_alice.to_string(), | |||
"room_id": self.room1.to_string(), | |||
"content": {"reason": "test"}, | |||
"redacts": redaction_event_id2, | |||
}, | |||
), | |||
redaction_event_id1, | |||
), | |||
redaction_event_id1, | |||
) | |||
) | |||
) | |||
self.get_success(self._storage.persistence.persist_event(event_1, context_1)) | |||
self.get_success(self._persistence.persist_event(event_1, context_1)) | |||
event_2, context_2 = self.get_success( | |||
self.event_creation_handler.create_new_client_event( | |||
EventIdManglingBuilder( | |||
self.event_builder_factory.for_room_version( | |||
RoomVersions.V1, | |||
{ | |||
"type": EventTypes.Redaction, | |||
"sender": self.u_alice.to_string(), | |||
"room_id": self.room1.to_string(), | |||
"content": {"reason": "test"}, | |||
"redacts": redaction_event_id1, | |||
}, | |||
cast( | |||
EventBuilder, | |||
EventIdManglingBuilder( | |||
self.event_builder_factory.for_room_version( | |||
RoomVersions.V1, | |||
{ | |||
"type": EventTypes.Redaction, | |||
"sender": self.u_alice.to_string(), | |||
"room_id": self.room1.to_string(), | |||
"content": {"reason": "test"}, | |||
"redacts": redaction_event_id1, | |||
}, | |||
), | |||
redaction_event_id2, | |||
), | |||
redaction_event_id2, | |||
) | |||
) | |||
) | |||
self.get_success(self._storage.persistence.persist_event(event_2, context_2)) | |||
self.get_success(self._persistence.persist_event(event_2, context_2)) | |||
# fetch one of the redactions | |||
fetched = self.get_success(self.store.get_event(redaction_event_id1)) | |||
@@ -298,7 +313,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
fetched.unsigned["redacted_because"].event_id, redaction_event_id2 | |||
) | |||
def test_redact_censor(self): | |||
def test_redact_censor(self) -> None: | |||
"""Test that a redacted event gets censored in the DB after a month""" | |||
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) | |||
@@ -364,7 +379,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
self.assert_dict({"content": {}}, json.loads(event_json)) | |||
def test_redact_redaction(self): | |||
def test_redact_redaction(self) -> None: | |||
"""Tests that we can redact a redaction and can fetch it again.""" | |||
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) | |||
@@ -391,7 +406,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
self.store.get_event(first_redact_event.event_id, allow_none=True) | |||
) | |||
def test_store_redacted_redaction(self): | |||
def test_store_redacted_redaction(self) -> None: | |||
"""Tests that we can store a redacted redaction.""" | |||
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) | |||
@@ -410,9 +425,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
self.get_success( | |||
self._storage.persistence.persist_event(redaction_event, context) | |||
) | |||
self.get_success(self._persistence.persist_event(redaction_event, context)) | |||
# Now lets jump to the future where we have censored the redaction event | |||
# in the DB. | |||
@@ -14,10 +14,15 @@ | |||
from typing import List | |||
from unittest import mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.app.generic_worker import GenericWorkerServer | |||
from synapse.server import HomeServer | |||
from synapse.storage.database import LoggingDatabaseConnection | |||
from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database | |||
from synapse.storage.schema import SCHEMA_VERSION | |||
from synapse.types import JsonDict | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
@@ -39,13 +44,13 @@ def fake_listdir(filepath: str) -> List[str]: | |||
class WorkerSchemaTests(HomeserverTestCase): | |||
def make_homeserver(self, reactor, clock): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
hs = self.setup_test_homeserver( | |||
federation_http_client=None, homeserver_to_use=GenericWorkerServer | |||
) | |||
return hs | |||
def default_config(self): | |||
def default_config(self) -> JsonDict: | |||
conf = super().default_config() | |||
# Mark this as a worker app. | |||
@@ -53,7 +58,7 @@ class WorkerSchemaTests(HomeserverTestCase): | |||
return conf | |||
def test_rolling_back(self): | |||
def test_rolling_back(self) -> None: | |||
"""Test that workers can start if the DB is a newer schema version""" | |||
db_pool = self.hs.get_datastores().main.db_pool | |||
@@ -70,7 +75,7 @@ class WorkerSchemaTests(HomeserverTestCase): | |||
prepare_database(db_conn, db_pool.engine, self.hs.config) | |||
def test_not_upgraded_old_schema_version(self): | |||
def test_not_upgraded_old_schema_version(self) -> None: | |||
"""Test that workers don't start if the DB has an older schema version""" | |||
db_pool = self.hs.get_datastores().main.db_pool | |||
db_conn = LoggingDatabaseConnection( | |||
@@ -87,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase): | |||
with self.assertRaises(PrepareDatabaseException): | |||
prepare_database(db_conn, db_pool.engine, self.hs.config) | |||
def test_not_upgraded_current_schema_version_with_outstanding_deltas(self): | |||
def test_not_upgraded_current_schema_version_with_outstanding_deltas(self) -> None: | |||
""" | |||
Test that workers don't start if the DB is on the current schema version, | |||
but there are still outstanding delta migrations to run. | |||
@@ -12,14 +12,18 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.room_versions import RoomVersions | |||
from synapse.server import HomeServer | |||
from synapse.types import RoomAlias, RoomID, UserID | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
class RoomStoreTestCase(HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
# We can't test RoomStore on its own without the DirectoryStore, for | |||
# management of the 'room_aliases' table | |||
self.store = hs.get_datastores().main | |||
@@ -37,30 +41,34 @@ class RoomStoreTestCase(HomeserverTestCase): | |||
) | |||
) | |||
def test_get_room(self): | |||
def test_get_room(self) -> None: | |||
res = self.get_success(self.store.get_room(self.room.to_string())) | |||
assert res is not None | |||
self.assertDictContainsSubset( | |||
{ | |||
"room_id": self.room.to_string(), | |||
"creator": self.u_creator.to_string(), | |||
"is_public": True, | |||
}, | |||
(self.get_success(self.store.get_room(self.room.to_string()))), | |||
res, | |||
) | |||
def test_get_room_unknown_room(self): | |||
def test_get_room_unknown_room(self) -> None: | |||
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) | |||
def test_get_room_with_stats(self): | |||
def test_get_room_with_stats(self) -> None: | |||
res = self.get_success(self.store.get_room_with_stats(self.room.to_string())) | |||
assert res is not None | |||
self.assertDictContainsSubset( | |||
{ | |||
"room_id": self.room.to_string(), | |||
"creator": self.u_creator.to_string(), | |||
"public": True, | |||
}, | |||
(self.get_success(self.store.get_room_with_stats(self.room.to_string()))), | |||
res, | |||
) | |||
def test_get_room_with_stats_unknown_room(self): | |||
def test_get_room_with_stats_unknown_room(self) -> None: | |||
self.assertIsNone( | |||
(self.get_success(self.store.get_room_with_stats("!uknown:test"))), | |||
self.get_success(self.store.get_room_with_stats("!uknown:test")) | |||
) |
@@ -39,7 +39,7 @@ class EventSearchInsertionTest(HomeserverTestCase): | |||
room.register_servlets, | |||
] | |||
def test_null_byte(self): | |||
def test_null_byte(self) -> None: | |||
""" | |||
Postgres/SQLite don't like null bytes going into the search tables. Internally | |||
we replace those with a space. | |||
@@ -86,7 +86,7 @@ class EventSearchInsertionTest(HomeserverTestCase): | |||
if isinstance(store.database_engine, PostgresEngine): | |||
self.assertIn("alice", result.get("highlights")) | |||
def test_non_string(self): | |||
def test_non_string(self) -> None: | |||
"""Test that non-string `value`s are not inserted into `event_search`. | |||
This is particularly important when using sqlite, since a sqlite column can hold | |||
@@ -157,7 +157,7 @@ class EventSearchInsertionTest(HomeserverTestCase): | |||
self.assertEqual(f.value.code, 404) | |||
@skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite") | |||
def test_sqlite_non_string_deletion_background_update(self): | |||
def test_sqlite_non_string_deletion_background_update(self) -> None: | |||
"""Test the background update to delete bad rows from `event_search`.""" | |||
store = self.hs.get_datastores().main | |||
@@ -350,7 +350,7 @@ class MessageSearchTest(HomeserverTestCase): | |||
"results array length should match count", | |||
) | |||
def test_postgres_web_search_for_phrase(self): | |||
def test_postgres_web_search_for_phrase(self) -> None: | |||
""" | |||
Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery. | |||
This test is skipped unless the postgres instance supports websearch_to_tsquery. | |||
@@ -364,7 +364,7 @@ class MessageSearchTest(HomeserverTestCase): | |||
self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES) | |||
def test_sqlite_search(self): | |||
def test_sqlite_search(self) -> None: | |||
""" | |||
Test sqlite searching for phrases. | |||
""" | |||
@@ -16,10 +16,15 @@ import logging | |||
from frozendict import frozendict | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import EventTypes, Membership | |||
from synapse.api.room_versions import RoomVersions | |||
from synapse.events import EventBase | |||
from synapse.server import HomeServer | |||
from synapse.storage.state import StateFilter | |||
from synapse.types import RoomID, UserID | |||
from synapse.types import JsonDict, RoomID, StateMap, UserID | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase, TestCase | |||
@@ -27,7 +32,7 @@ logger = logging.getLogger(__name__) | |||
class StateStoreTestCase(HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage_controllers() | |||
self.state_datastore = self.storage.state.stores.state | |||
@@ -48,7 +53,9 @@ class StateStoreTestCase(HomeserverTestCase): | |||
) | |||
) | |||
def inject_state_event(self, room, sender, typ, state_key, content): | |||
def inject_state_event( | |||
self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict | |||
) -> EventBase: | |||
builder = self.event_builder_factory.for_room_version( | |||
RoomVersions.V1, | |||
{ | |||
@@ -64,24 +71,29 @@ class StateStoreTestCase(HomeserverTestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
assert self.storage.persistence is not None | |||
self.get_success(self.storage.persistence.persist_event(event, context)) | |||
return event | |||
def assertStateMapEqual(self, s1, s2): | |||
def assertStateMapEqual( | |||
self, s1: StateMap[EventBase], s2: StateMap[EventBase] | |||
) -> None: | |||
for t in s1: | |||
# just compare event IDs for simplicity | |||
self.assertEqual(s1[t].event_id, s2[t].event_id) | |||
self.assertEqual(len(s1), len(s2)) | |||
def test_get_state_groups_ids(self): | |||
def test_get_state_groups_ids(self) -> None: | |||
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) | |||
e2 = self.inject_state_event( | |||
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} | |||
) | |||
state_group_map = self.get_success( | |||
self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) | |||
self.storage.state.get_state_groups_ids( | |||
self.room.to_string(), [e2.event_id] | |||
) | |||
) | |||
self.assertEqual(len(state_group_map), 1) | |||
state_map = list(state_group_map.values())[0] | |||
@@ -90,21 +102,21 @@ class StateStoreTestCase(HomeserverTestCase): | |||
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, | |||
) | |||
def test_get_state_groups(self): | |||
def test_get_state_groups(self) -> None: | |||
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) | |||
e2 = self.inject_state_event( | |||
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} | |||
) | |||
state_group_map = self.get_success( | |||
self.storage.state.get_state_groups(self.room, [e2.event_id]) | |||
self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id]) | |||
) | |||
self.assertEqual(len(state_group_map), 1) | |||
state_list = list(state_group_map.values())[0] | |||
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) | |||
def test_get_state_for_event(self): | |||
def test_get_state_for_event(self) -> None: | |||
# this defaults to a linear DAG as each new injection defaults to whatever | |||
# forward extremities are currently in the DB for this room. | |||
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) | |||
@@ -487,14 +499,16 @@ class StateStoreTestCase(HomeserverTestCase): | |||
class StateFilterDifferenceTestCase(TestCase): | |||
def assert_difference( | |||
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter | |||
): | |||
) -> None: | |||
self.assertEqual( | |||
minuend.approx_difference(subtrahend), | |||
expected, | |||
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", | |||
) | |||
def test_state_filter_difference_no_include_other_minus_no_include_other(self): | |||
def test_state_filter_difference_no_include_other_minus_no_include_other( | |||
self, | |||
) -> None: | |||
""" | |||
Tests the StateFilter.approx_difference method | |||
where, in a.approx_difference(b), both a and b do not have the | |||
@@ -610,7 +624,7 @@ class StateFilterDifferenceTestCase(TestCase): | |||
), | |||
) | |||
def test_state_filter_difference_include_other_minus_no_include_other(self): | |||
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: | |||
""" | |||
Tests the StateFilter.approx_difference method | |||
where, in a.approx_difference(b), only a has the include_others flag set. | |||
@@ -739,7 +753,7 @@ class StateFilterDifferenceTestCase(TestCase): | |||
), | |||
) | |||
def test_state_filter_difference_include_other_minus_include_other(self): | |||
def test_state_filter_difference_include_other_minus_include_other(self) -> None: | |||
""" | |||
Tests the StateFilter.approx_difference method | |||
where, in a.approx_difference(b), both a and b have the include_others | |||
@@ -864,7 +878,7 @@ class StateFilterDifferenceTestCase(TestCase): | |||
), | |||
) | |||
def test_state_filter_difference_no_include_other_minus_include_other(self): | |||
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: | |||
""" | |||
Tests the StateFilter.approx_difference method | |||
where, in a.approx_difference(b), only b has the include_others flag set. | |||
@@ -979,7 +993,7 @@ class StateFilterDifferenceTestCase(TestCase): | |||
), | |||
) | |||
def test_state_filter_difference_simple_cases(self): | |||
def test_state_filter_difference_simple_cases(self) -> None: | |||
""" | |||
Tests some very simple cases of the StateFilter approx_difference, | |||
that are not explicitly tested by the more in-depth tests. | |||
@@ -995,7 +1009,7 @@ class StateFilterDifferenceTestCase(TestCase): | |||
class StateFilterTestCase(TestCase): | |||
def test_return_expanded(self): | |||
def test_return_expanded(self) -> None: | |||
""" | |||
Tests the behaviour of the return_expanded() function that expands | |||
StateFilters to include more state types (for the sake of cache hit rate). | |||
@@ -14,11 +14,15 @@ | |||
from typing import List | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import EventTypes, RelationTypes | |||
from synapse.api.filtering import Filter | |||
from synapse.rest import admin | |||
from synapse.rest.client import login, room | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
@@ -37,12 +41,14 @@ class PaginationTestCase(HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def default_config(self): | |||
def default_config(self) -> JsonDict: | |||
config = super().default_config() | |||
config["experimental_features"] = {"msc3874_enabled": True} | |||
return config | |||
def prepare(self, reactor, clock, homeserver): | |||
def prepare( | |||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | |||
) -> None: | |||
self.user_id = self.register_user("test", "test") | |||
self.tok = self.login("test", "test") | |||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) | |||
@@ -130,7 +136,7 @@ class PaginationTestCase(HomeserverTestCase): | |||
return [ev.event_id for ev in events] | |||
def test_filter_relation_senders(self): | |||
def test_filter_relation_senders(self) -> None: | |||
# Messages which second user reacted to. | |||
filter = {"related_by_senders": [self.second_user_id]} | |||
chunk = self._filter_messages(filter) | |||
@@ -146,7 +152,7 @@ class PaginationTestCase(HomeserverTestCase): | |||
chunk = self._filter_messages(filter) | |||
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) | |||
def test_filter_relation_type(self): | |||
def test_filter_relation_type(self) -> None: | |||
# Messages which have annotations. | |||
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} | |||
chunk = self._filter_messages(filter) | |||
@@ -167,7 +173,7 @@ class PaginationTestCase(HomeserverTestCase): | |||
chunk = self._filter_messages(filter) | |||
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) | |||
def test_filter_relation_senders_and_type(self): | |||
def test_filter_relation_senders_and_type(self) -> None: | |||
# Messages which second user reacted to. | |||
filter = { | |||
"related_by_senders": [self.second_user_id], | |||
@@ -176,7 +182,7 @@ class PaginationTestCase(HomeserverTestCase): | |||
chunk = self._filter_messages(filter) | |||
self.assertEqual(chunk, [self.event_id_1]) | |||
def test_duplicate_relation(self): | |||
def test_duplicate_relation(self) -> None: | |||
"""An event should only be returned once if there are multiple relations to it.""" | |||
self.helper.send_event( | |||
room_id=self.room_id, | |||
@@ -12,17 +12,23 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.server import HomeServer | |||
from synapse.storage.databases.main.transactions import DestinationRetryTimings | |||
from synapse.util import Clock | |||
from synapse.util.retryutils import MAX_RETRY_INTERVAL | |||
from tests.unittest import HomeserverTestCase | |||
class TransactionStoreTestCase(HomeserverTestCase): | |||
def prepare(self, reactor, clock, homeserver): | |||
def prepare( | |||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | |||
) -> None: | |||
self.store = homeserver.get_datastores().main | |||
def test_get_set_transactions(self): | |||
def test_get_set_transactions(self) -> None: | |||
"""Tests that we can successfully get a non-existent entry for | |||
destination retries, as well as testing tht we can set and get | |||
correctly. | |||
@@ -44,18 +50,18 @@ class TransactionStoreTestCase(HomeserverTestCase): | |||
r, | |||
) | |||
def test_initial_set_transactions(self): | |||
def test_initial_set_transactions(self) -> None: | |||
"""Tests that we can successfully set the destination retries (there | |||
was a bug around invalidating the cache that broke this) | |||
""" | |||
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) | |||
self.get_success(d) | |||
def test_large_destination_retry(self): | |||
def test_large_destination_retry(self) -> None: | |||
d = self.store.set_destination_retry_timings( | |||
"example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL | |||
) | |||
self.get_success(d) | |||
d = self.store.get_destination_retry_timings("example.com") | |||
self.get_success(d) | |||
d2 = self.store.get_destination_retry_timings("example.com") | |||
self.get_success(d2) |
@@ -12,21 +12,27 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.server import HomeServer | |||
from synapse.storage.types import Cursor | |||
from synapse.util import Clock | |||
from tests import unittest | |||
class SQLTransactionLimitTestCase(unittest.HomeserverTestCase): | |||
"""Test SQL transaction limit doesn't break transactions.""" | |||
def make_homeserver(self, reactor, clock): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
return self.setup_test_homeserver(db_txn_limit=1000) | |||
def test_config(self): | |||
def test_config(self) -> None: | |||
db_config = self.hs.config.database.get_single_database() | |||
self.assertEqual(db_config.config["txn_limit"], 1000) | |||
def test_select(self): | |||
def do_select(txn): | |||
def test_select(self) -> None: | |||
def do_select(txn: Cursor) -> None: | |||
txn.execute("SELECT 1") | |||
db_pool = self.hs.get_datastores().databases[0] | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Dict | |||
from typing import Collection, Dict | |||
from unittest import mock | |||
from twisted.internet.defer import CancelledError, ensureDeferred | |||
@@ -31,7 +31,7 @@ class PartialStateEventsTrackerTestCase(TestCase): | |||
# the results to be returned by the mocked get_partial_state_events | |||
self._events_dict: Dict[str, bool] = {} | |||
async def get_partial_state_events(events): | |||
async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]: | |||
return {e: self._events_dict[e] for e in events} | |||
self.mock_store = mock.Mock(spec_set=["get_partial_state_events"]) | |||
@@ -39,7 +39,7 @@ class PartialStateEventsTrackerTestCase(TestCase): | |||
self.tracker = PartialStateEventsTracker(self.mock_store) | |||
def test_does_not_block_for_full_state_events(self): | |||
def test_does_not_block_for_full_state_events(self) -> None: | |||
self._events_dict = {"event1": False, "event2": False} | |||
self.successResultOf( | |||
@@ -50,7 +50,7 @@ class PartialStateEventsTrackerTestCase(TestCase): | |||
["event1", "event2"] | |||
) | |||
def test_blocks_for_partial_state_events(self): | |||
def test_blocks_for_partial_state_events(self) -> None: | |||
self._events_dict = {"event1": True, "event2": False} | |||
d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) | |||
@@ -62,12 +62,12 @@ class PartialStateEventsTrackerTestCase(TestCase): | |||
self.tracker.notify_un_partial_stated("event1") | |||
self.successResultOf(d) | |||
def test_un_partial_state_race(self): | |||
def test_un_partial_state_race(self) -> None: | |||
# if the event is un-partial-stated between the initial check and the | |||
# registration of the listener, it should not block. | |||
self._events_dict = {"event1": True, "event2": False} | |||
async def get_partial_state_events(events): | |||
async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]: | |||
res = {e: self._events_dict[e] for e in events} | |||
# change the result for next time | |||
self._events_dict = {"event1": False, "event2": False} | |||
@@ -79,19 +79,19 @@ class PartialStateEventsTrackerTestCase(TestCase): | |||
ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) | |||
) | |||
def test_un_partial_state_during_get_partial_state_events(self): | |||
def test_un_partial_state_during_get_partial_state_events(self) -> None: | |||
# we should correctly handle a call to notify_un_partial_stated during the | |||
# second call to get_partial_state_events. | |||
self._events_dict = {"event1": True, "event2": False} | |||
async def get_partial_state_events1(events): | |||
async def get_partial_state_events1(events: Collection[str]) -> Dict[str, bool]: | |||
self.mock_store.get_partial_state_events.side_effect = ( | |||
get_partial_state_events2 | |||
) | |||
return {e: self._events_dict[e] for e in events} | |||
async def get_partial_state_events2(events): | |||
async def get_partial_state_events2(events: Collection[str]) -> Dict[str, bool]: | |||
self.tracker.notify_un_partial_stated("event1") | |||
self._events_dict["event1"] = False | |||
return {e: self._events_dict[e] for e in events} | |||
@@ -102,7 +102,7 @@ class PartialStateEventsTrackerTestCase(TestCase): | |||
ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) | |||
) | |||
def test_cancellation(self): | |||
def test_cancellation(self) -> None: | |||
self._events_dict = {"event1": True, "event2": False} | |||
d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) | |||
@@ -127,12 +127,12 @@ class PartialCurrentStateTrackerTestCase(TestCase): | |||
self.tracker = PartialCurrentStateTracker(self.mock_store) | |||
def test_does_not_block_for_full_state_rooms(self): | |||
def test_does_not_block_for_full_state_rooms(self) -> None: | |||
self.mock_store.is_partial_state_room.return_value = make_awaitable(False) | |||
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) | |||
def test_blocks_for_partial_room_state(self): | |||
def test_blocks_for_partial_room_state(self) -> None: | |||
self.mock_store.is_partial_state_room.return_value = make_awaitable(True) | |||
d = ensureDeferred(self.tracker.await_full_state("room_id")) | |||
@@ -144,10 +144,10 @@ class PartialCurrentStateTrackerTestCase(TestCase): | |||
self.tracker.notify_un_partial_stated("room_id") | |||
self.successResultOf(d) | |||
def test_un_partial_state_race(self): | |||
def test_un_partial_state_race(self) -> None: | |||
# We should correctly handle race between awaiting the state and us | |||
# un-partialling the state | |||
async def is_partial_state_room(events): | |||
async def is_partial_state_room(room_id: str) -> bool: | |||
self.tracker.notify_un_partial_stated("room_id") | |||
return True | |||
@@ -155,7 +155,7 @@ class PartialCurrentStateTrackerTestCase(TestCase): | |||
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) | |||
def test_cancellation(self): | |||
def test_cancellation(self) -> None: | |||
self.mock_store.is_partial_state_room.return_value = make_awaitable(True) | |||
d1 = ensureDeferred(self.tracker.await_full_state("room_id")) | |||