Browse Source

Require types in tests.storage. (#14646)

Adds missing type hints to `tests.storage` package
and does not allow untyped definitions.
tags/v1.74.0rc1
Patrick Cloke 1 year ago
committed by GitHub
parent
commit
3ac412b4e2
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 489 additions and 341 deletions
  1. +1
    -0
      changelog.d/14646.misc
  2. +4
    -10
      mypy.ini
  3. +1
    -1
      synapse/storage/databases/main/end_to_end_keys.py
  4. +7
    -3
      tests/storage/databases/main/test_deviceinbox.py
  5. +14
    -13
      tests/storage/databases/main/test_events_worker.py
  6. +10
    -8
      tests/storage/databases/main/test_lock.py
  7. +4
    -4
      tests/storage/databases/main/test_receipts.py
  8. +7
    -3
      tests/storage/databases/main/test_room.py
  9. +1
    -1
      tests/storage/test__base.py
  10. +8
    -4
      tests/storage/test_account_data.py
  11. +14
    -8
      tests/storage/test_appservice.py
  12. +17
    -13
      tests/storage/test_base.py
  13. +23
    -14
      tests/storage/test_cleanup_extrems.py
  14. +30
    -28
      tests/storage/test_client_ips.py
  15. +1
    -1
      tests/storage/test_database.py
  16. +25
    -10
      tests/storage/test_devices.py
  17. +8
    -4
      tests/storage/test_directory.py
  18. +6
    -2
      tests/storage/test_e2e_room_keys.py
  19. +10
    -5
      tests/storage/test_end_to_end_keys.py
  20. +17
    -12
      tests/storage/test_event_chain.py
  21. +35
    -36
      tests/storage/test_event_federation.py
  22. +1
    -1
      tests/storage/test_event_metrics.py
  23. +26
    -13
      tests/storage/test_events.py
  24. +6
    -3
      tests/storage/test_keys.py
  25. +15
    -15
      tests/storage/test_monthly_active_users.py
  26. +10
    -5
      tests/storage/test_purge.py
  27. +9
    -3
      tests/storage/test_receipts.py
  28. +69
    -56
      tests/storage/test_redaction.py
  29. +10
    -5
      tests/storage/test_rollback_worker.py
  30. +16
    -8
      tests/storage/test_room.py
  31. +5
    -5
      tests/storage/test_room_search.py
  32. +30
    -16
      tests/storage/test_state.py
  33. +12
    -6
      tests/storage/test_stream.py
  34. +12
    -6
      tests/storage/test_transactions.py
  35. +10
    -4
      tests/storage/test_txn_limit.py
  36. +15
    -15
      tests/storage/util/test_partial_state_events_tracker.py

+ 1
- 0
changelog.d/14646.misc View File

@@ -0,0 +1 @@
Add missing type hints.

+ 4
- 10
mypy.ini View File

@@ -88,6 +88,9 @@ disallow_untyped_defs = False
[mypy-tests.*]
disallow_untyped_defs = False

[mypy-tests.handlers.test_sso]
disallow_untyped_defs = True

[mypy-tests.handlers.test_user_directory]
disallow_untyped_defs = True

@@ -103,16 +106,7 @@ disallow_untyped_defs = True
[mypy-tests.state.test_profile]
disallow_untyped_defs = True

[mypy-tests.storage.test_id_generators]
disallow_untyped_defs = True

[mypy-tests.storage.test_profile]
disallow_untyped_defs = True

[mypy-tests.handlers.test_sso]
disallow_untyped_defs = True

[mypy-tests.storage.test_user_directory]
[mypy-tests.storage.*]
disallow_untyped_defs = True

[mypy-tests.rest.*]


+ 1
- 1
synapse/storage/databases/main/end_to_end_keys.py View File

@@ -140,7 +140,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cancellable
async def get_e2e_device_keys_for_cs_api(
self,
query_list: List[Tuple[str, Optional[str]]],
query_list: Collection[Tuple[str, Optional[str]]],
include_displaynames: bool = True,
) -> Dict[str, Dict[str, JsonDict]]:
"""Fetch a list of device keys, formatted suitably for the C/S API.


+ 7
- 3
tests/storage/databases/main/test_deviceinbox.py View File

@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.test.proto_helpers import MemoryReactor

from synapse.rest import admin
from synapse.rest.client import devices
from synapse.server import HomeServer
from synapse.util import Clock

from tests.unittest import HomeserverTestCase

@@ -25,11 +29,11 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
devices.register_servlets,
]

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")

def test_background_remove_deleted_devices_from_device_inbox(self):
def test_background_remove_deleted_devices_from_device_inbox(self) -> None:
"""Test that the background task to delete old device_inboxes works properly."""

# create a valid device
@@ -89,7 +93,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.assertEqual(1, len(res))
self.assertEqual(res[0], "cur_device")

def test_background_remove_hidden_devices_from_device_inbox(self):
def test_background_remove_hidden_devices_from_device_inbox(self) -> None:
"""Test that the background task to delete hidden devices
from device_inboxes works properly."""



+ 14
- 13
tests/storage/databases/main/test_events_worker.py View File

@@ -45,7 +45,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
self.store: EventsWorkerStore = hs.get_datastores().main

@@ -68,7 +68,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):

self.event_ids.append(event.event_id)

def test_simple(self):
def test_simple(self) -> None:
with LoggingContext(name="test") as ctx:
res = self.get_success(
self.store.have_seen_events(
@@ -90,7 +90,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, {self.event_ids[0]})
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)

def test_persisting_event_invalidates_cache(self):
def test_persisting_event_invalidates_cache(self) -> None:
"""
Test to make sure that the `have_seen_event` cache
is invalidated after we persist an event and returns
@@ -138,7 +138,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# That should result in a single db query to lookup
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)

def test_invalidate_cache_by_room_id(self):
def test_invalidate_cache_by_room_id(self) -> None:
"""
Test to make sure that all events associated with the given `(room_id,)`
are invalidated in the `have_seen_event` cache.
@@ -175,7 +175,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main

self.user = self.register_user("user", "pass")
@@ -189,7 +189,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# Reset the event cache so the tests start with it empty
self.get_success(self.store._get_event_cache.clear())

def test_simple(self):
def test_simple(self) -> None:
"""Test that we cache events that we pull from the DB."""

with LoggingContext("test") as ctx:
@@ -198,7 +198,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)

def test_event_ref(self):
def test_event_ref(self) -> None:
"""Test that we reuse events that are still in memory but have fallen
out of the cache, rather than requesting them from the DB.
"""
@@ -223,7 +223,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)

def test_dedupe(self):
def test_dedupe(self) -> None:
"""Test that if we request the same event multiple times we only pull it
out once.
"""
@@ -241,7 +241,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
class DatabaseOutageTestCase(unittest.HomeserverTestCase):
"""Test event fetching during a database outage."""

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main

self.room_id = f"!room:{hs.hostname}"
@@ -377,7 +377,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main

self.user = self.register_user("user", "pass")
@@ -412,7 +412,8 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
unblock: "Deferred[None]" = Deferred()
original_runWithConnection = self.store.db_pool.runWithConnection

async def runWithConnection(*args, **kwargs):
# Don't bother with the types here, we just pass into the original function.
async def runWithConnection(*args, **kwargs): # type: ignore[no-untyped-def]
await unblock
return await original_runWithConnection(*args, **kwargs)

@@ -441,7 +442,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)

def test_first_get_event_cancelled(self):
def test_first_get_event_cancelled(self) -> None:
"""Test cancellation of the first `get_event` call sharing a database fetch.

The first `get_event` call is the one which initiates the fetch. We expect the
@@ -467,7 +468,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
# The second `get_event` call should complete successfully.
self.get_success(get_event2)

def test_second_get_event_cancelled(self):
def test_second_get_event_cancelled(self) -> None:
"""Test cancellation of the second `get_event` call sharing a database fetch."""
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
# Cancel the second `get_event` call.


+ 10
- 8
tests/storage/databases/main/test_lock.py View File

@@ -15,18 +15,20 @@
from twisted.internet import defer, reactor
from twisted.internet.base import ReactorBase
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
from synapse.util import Clock

from tests import unittest


class LockTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main

def test_acquire_contention(self):
def test_acquire_contention(self) -> None:
# Track the number of tasks holding the lock.
# Should be at most 1.
in_lock = 0
@@ -34,7 +36,7 @@ class LockTestCase(unittest.HomeserverTestCase):

release_lock: "Deferred[None]" = Deferred()

async def task():
async def task() -> None:
nonlocal in_lock
nonlocal max_in_lock

@@ -76,7 +78,7 @@ class LockTestCase(unittest.HomeserverTestCase):
# At most one task should have held the lock at a time.
self.assertEqual(max_in_lock, 1)

def test_simple_lock(self):
def test_simple_lock(self) -> None:
"""Test that we can take out a lock and that while we hold it nobody
else can take it out.
"""
@@ -103,7 +105,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.get_success(lock3.__aenter__())
self.get_success(lock3.__aexit__(None, None, None))

def test_maintain_lock(self):
def test_maintain_lock(self) -> None:
"""Test that we don't time out locks while they're still active"""

lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@@ -119,7 +121,7 @@ class LockTestCase(unittest.HomeserverTestCase):

self.get_success(lock.__aexit__(None, None, None))

def test_timeout_lock(self):
def test_timeout_lock(self) -> None:
"""Test that we time out locks if they're not updated for ages"""

lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@@ -139,7 +141,7 @@ class LockTestCase(unittest.HomeserverTestCase):

self.assertFalse(self.get_success(lock.is_still_valid()))

def test_drop(self):
def test_drop(self) -> None:
"""Test that dropping the context manager means we stop renewing the lock"""

lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@@ -153,7 +155,7 @@ class LockTestCase(unittest.HomeserverTestCase):
lock2 = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock2)

def test_shutdown(self):
def test_shutdown(self) -> None:
"""Test that shutting down Synapse releases the locks"""
# Acquire two locks
lock = self.get_success(self.store.try_acquire_lock("name", "key1"))


+ 4
- 4
tests/storage/databases/main/test_receipts.py View File

@@ -33,7 +33,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
login.register_servlets,
]

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
@@ -47,7 +47,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
table: str,
receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]],
expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]],
):
) -> None:
"""Test that the background update to uniqueify non-thread receipts in
the given receipts table works properly.

@@ -154,7 +154,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
f"Background update did not remove all duplicate receipts from {table}",
)

def test_background_receipts_linearized_unique_index(self):
def test_background_receipts_linearized_unique_index(self) -> None:
"""Test that the background update to uniqueify non-thread receipts in
`receipts_linearized` works properly.
"""
@@ -177,7 +177,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
},
)

def test_background_receipts_graph_unique_index(self):
def test_background_receipts_graph_unique_index(self) -> None:
"""Test that the background update to uniqueify non-thread receipts in
`receipts_graph` works properly.
"""


+ 7
- 3
tests/storage/databases/main/test_room.py View File

@@ -14,10 +14,14 @@

import json

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import RoomTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main.room import _BackgroundUpdates
from synapse.util import Clock

from tests.unittest import HomeserverTestCase

@@ -30,7 +34,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
login.register_servlets,
]

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
@@ -40,7 +44,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):

return room_id

def test_background_populate_rooms_creator_column(self):
def test_background_populate_rooms_creator_column(self) -> None:
"""Test that the background update to populate the rooms creator column
works properly.
"""
@@ -95,7 +99,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(room_creator_after, self.user_id)

def test_background_add_room_type_column(self):
def test_background_add_room_type_column(self) -> None:
"""Test that the background update to populate the `room_type` column in
`room_stats_state` works properly.
"""


+ 1
- 1
tests/storage/test__base.py View File

@@ -106,7 +106,7 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
{(1, "user1", "hello"), (2, "user2", "bleb")},
)

def test_simple_update_many(self):
def test_simple_update_many(self) -> None:
"""
simple_update_many performs many updates at once.
"""


+ 8
- 4
tests/storage/test_account_data.py View File

@@ -14,13 +14,17 @@

from typing import Iterable, Optional, Set

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import AccountDataTypes
from synapse.server import HomeServer
from synapse.util import Clock

from tests import unittest


class IgnoredUsersTestCase(unittest.HomeserverTestCase):
def prepare(self, hs, reactor, clock):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.user = "@user:test"

@@ -55,7 +59,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
expected_ignored_user_ids,
)

def test_ignoring_users(self):
def test_ignoring_users(self) -> None:
"""Basic adding/removing of users from the ignore list."""
self._update_ignore_list("@other:test", "@another:remote")
self.assert_ignored(self.user, {"@other:test", "@another:remote"})
@@ -82,7 +86,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# Check the removed user.
self.assert_ignorers("@another:remote", {self.user})

def test_caching(self):
def test_caching(self) -> None:
"""Ensure that caching works properly between different users."""
# The first user ignores a user.
self._update_ignore_list("@other:test")
@@ -99,7 +103,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", {"@second:test"})

def test_invalid_data(self):
def test_invalid_data(self) -> None:
"""Invalid data ends up clearing out the ignored users list."""
# Add some data and ensure it is there.
self._update_ignore_list("@other:test")


+ 14
- 8
tests/storage/test_appservice.py View File

@@ -26,7 +26,7 @@ from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.config._base import ConfigError
from synapse.events import EventBase
from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection, make_conn
from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
@@ -39,7 +39,7 @@ from tests.test_utils import make_awaitable


class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
def setUp(self):
def setUp(self) -> None:
super(ApplicationServiceStoreTestCase, self).setUp()

self.as_yaml_files: List[str] = []
@@ -73,7 +73,9 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):

super(ApplicationServiceStoreTestCase, self).tearDown()

def _add_appservice(self, as_token, id, url, hs_token, sender) -> None:
def _add_appservice(
self, as_token: str, id: str, url: str, hs_token: str, sender: str
) -> None:
as_yaml = {
"url": url,
"as_token": as_token,
@@ -135,7 +137,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
database, make_conn(db_config, self.engine, "test"), self.hs
)

def _add_service(self, url, as_token, id) -> None:
def _add_service(self, url: str, as_token: str, id: str) -> None:
as_yaml = {
"url": url,
"as_token": as_token,
@@ -149,7 +151,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)

def _set_state(self, id: str, state: ApplicationServiceState):
def _set_state(self, id: str, state: ApplicationServiceState) -> defer.Deferred:
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_state(as_id, state) VALUES(?,?)"
@@ -157,7 +159,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
(id, state.value),
)

def _insert_txn(self, as_id, txn_id, events):
def _insert_txn(
self, as_id: str, txn_id: int, events: List[Mock]
) -> "defer.Deferred[None]":
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
@@ -448,12 +452,14 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):

# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
def __init__(self, database: DatabasePool, db_conn, hs) -> None:
def __init__(
self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: HomeServer
) -> None:
super().__init__(database, db_conn, hs)


class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase):
def _write_config(self, suffix, **kwargs) -> str:
def _write_config(self, suffix: str, **kwargs: str) -> str:
vals = {
"id": "id" + suffix,
"url": "url" + suffix,


+ 17
- 13
tests/storage/test_base.py View File

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import OrderedDict
from typing import Generator
from unittest.mock import Mock

from twisted.internet import defer
@@ -30,7 +30,7 @@ from tests.utils import default_config
class SQLBaseStoreTestCase(unittest.TestCase):
"""Test the "simple" SQL generating methods in SQLBaseStore."""

def setUp(self):
def setUp(self) -> None:
self.db_pool = Mock(spec=["runInteraction"])
self.mock_txn = Mock()
self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
@@ -38,12 +38,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_conn.rollback.return_value = None
# Our fake runInteraction just runs synchronously inline

def runInteraction(func, *args, **kwargs):
def runInteraction(func, *args, **kwargs) -> defer.Deferred: # type: ignore[no-untyped-def]
return defer.succeed(func(self.mock_txn, *args, **kwargs))

self.db_pool.runInteraction = runInteraction

def runWithConnection(func, *args, **kwargs):
def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def]
return defer.succeed(func(self.mock_conn, *args, **kwargs))

self.db_pool.runWithConnection = runWithConnection
@@ -62,7 +62,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type]

@defer.inlineCallbacks
def test_insert_1col(self):
def test_insert_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1

yield defer.ensureDeferred(
@@ -76,7 +76,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)

@defer.inlineCallbacks
def test_insert_3cols(self):
def test_insert_3cols(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1

yield defer.ensureDeferred(
@@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)

@defer.inlineCallbacks
def test_select_one_1col(self):
def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))

@@ -108,7 +108,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)

@defer.inlineCallbacks
def test_select_one_3col(self):
def test_select_one_3col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)

@@ -126,7 +126,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)

@defer.inlineCallbacks
def test_select_one_missing(self):
def test_select_one_missing(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None

@@ -142,7 +144,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertFalse(ret)

@defer.inlineCallbacks
def test_select_list(self):
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 3
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
@@ -159,7 +161,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)

@defer.inlineCallbacks
def test_update_one_1col(self):
def test_update_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1

yield defer.ensureDeferred(
@@ -176,7 +178,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)

@defer.inlineCallbacks
def test_update_one_4cols(self):
def test_update_one_4cols(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1

yield defer.ensureDeferred(
@@ -193,7 +197,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)

@defer.inlineCallbacks
def test_delete_one(self):
def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1

yield defer.ensureDeferred(


+ 23
- 14
tests/storage/test_cleanup_extrems.py View File

@@ -15,11 +15,16 @@
import os.path
from unittest.mock import Mock, patch

from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage import prepare_database
from synapse.storage.types import Cursor
from synapse.types import UserID, create_requester
from synapse.util import Clock

from tests.unittest import HomeserverTestCase

@@ -29,7 +34,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
Test the background update to clean forward extremities table.
"""

def prepare(self, reactor, clock, homeserver):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler()

@@ -39,7 +46,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]

def run_background_update(self):
def run_background_update(self) -> None:
"""Re run the background update to clean up the extremities."""
# Make sure we don't clash with in progress updates.
self.assertTrue(
@@ -54,7 +61,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
"delete_forward_extremities.sql",
)

def run_delta_file(txn):
def run_delta_file(txn: Cursor) -> None:
prepare_database.executescript(txn, schema_path)

self.get_success(
@@ -84,7 +91,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
(room_id,)
)

def test_soft_failed_extremities_handled_correctly(self):
def test_soft_failed_extremities_handled_correctly(self) -> None:
"""Test that extremities are correctly calculated in the presence of
soft failed events.

@@ -114,7 +121,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):

self.assertEqual(latest_event_ids, [event_id_4])

def test_basic_cleanup(self):
def test_basic_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
soft failed events.

@@ -149,7 +156,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(latest_event_ids, [event_id_b])

def test_chain_of_fail_cleanup(self):
def test_chain_of_fail_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
soft failed events.

@@ -187,7 +194,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(latest_event_ids, [event_id_b])

def test_forked_graph_cleanup(self):
def test_forked_graph_cleanup(self) -> None:
r"""Test that extremities are correctly calculated in the presence of
soft failed events.

@@ -252,12 +259,14 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
room.register_servlets,
]

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["cleanup_extremities_with_dummy_events"] = True
return self.setup_test_homeserver(config=config)

def prepare(self, reactor, clock, homeserver):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler()
self.event_creator_handler = homeserver.get_event_creation_handler()
@@ -273,7 +282,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION

def test_send_dummy_event(self):
def test_send_dummy_event(self) -> None:
self._create_extremity_rich_graph()

# Pump the reactor repeatedly so that the background updates have a
@@ -286,7 +295,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))

@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
def test_send_dummy_events_when_insufficient_power(self):
def test_send_dummy_events_when_insufficient_power(self) -> None:
self._create_extremity_rich_graph()
# Criple power levels
self.helper.send_state(
@@ -317,7 +326,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))

@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
def test_expiry_logic(self):
def test_expiry_logic(self) -> None:
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
expires old entries correctly.
"""
@@ -357,7 +366,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
0,
)

def _create_extremity_rich_graph(self):
def _create_extremity_rich_graph(self) -> None:
"""Helper method to create bushy graph on demand"""

event_id_start = self.create_and_send_event(self.room_id, self.user)
@@ -372,7 +381,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
)
self.assertEqual(len(latest_event_ids), 50)

def _enable_consent_checking(self):
def _enable_consent_checking(self) -> None:
"""Helper method to enable consent checking"""
self.event_creator._block_events_without_consent_error = "No consent from user"
consent_uri_builder = Mock()


+ 30
- 28
tests/storage/test_client_ips.py View File

@@ -13,15 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict
from unittest.mock import Mock

from parameterized import parameterized

from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.types import UserID
from synapse.util import Clock

from tests import unittest
from tests.server import make_request
@@ -30,14 +35,10 @@ from tests.unittest import override_config


class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
return hs

def prepare(self, hs, reactor, clock):
self.store = self.hs.get_datastores().main
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main

def test_insert_new_client_ip(self):
def test_insert_new_client_ip(self) -> None:
self.reactor.advance(12345678)

user_id = "@user:id"
@@ -76,7 +77,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)

def test_insert_new_client_ip_none_device_id(self):
def test_insert_new_client_ip_none_device_id(self) -> None:
"""
An insert with a device ID of NULL will not create a new entry, but
update an existing entry in the user_ips table.
@@ -148,7 +149,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)

@parameterized.expand([(False,), (True,)])
def test_get_last_client_ip_by_device(self, after_persisting: bool):
def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None:
"""Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
self.reactor.advance(12345678)

@@ -213,7 +214,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
},
)

def test_get_last_client_ip_by_device_combined_data(self):
def test_get_last_client_ip_by_device_combined_data(self) -> None:
"""Test that `get_last_client_ip_by_device` combines persisted and unpersisted
data together correctly
"""
@@ -312,7 +313,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)

@parameterized.expand([(False,), (True,)])
def test_get_user_ip_and_agents(self, after_persisting: bool):
def test_get_user_ip_and_agents(self, after_persisting: bool) -> None:
"""Test `get_user_ip_and_agents` for persisted and unpersisted data"""
self.reactor.advance(12345678)

@@ -352,7 +353,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
],
)

def test_get_user_ip_and_agents_combined_data(self):
def test_get_user_ip_and_agents_combined_data(self) -> None:
"""Test that `get_user_ip_and_agents` combines persisted and unpersisted data
together correctly
"""
@@ -429,7 +430,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)

@override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
def test_disabled_monthly_active_user(self):
def test_disabled_monthly_active_user(self) -> None:
user_id = "@user:server"
self.get_success(
self.store.insert_client_ip(
@@ -440,7 +441,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertFalse(active)

@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_full(self):
def test_adding_monthly_active_user_when_full(self) -> None:
lots_of_users = 100
user_id = "@user:server"

@@ -456,7 +457,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertFalse(active)

@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_space(self):
def test_adding_monthly_active_user_when_space(self) -> None:
user_id = "@user:server"
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
@@ -473,7 +474,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertTrue(active)

@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_updating_monthly_active_user_when_space(self):
def test_updating_monthly_active_user_when_space(self) -> None:
user_id = "@user:server"
self.get_success(self.store.register_user(user_id=user_id, password_hash=None))

@@ -491,7 +492,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)

def test_devices_last_seen_bg_update(self):
def test_devices_last_seen_bg_update(self) -> None:
# First make sure we have completed all updates.
self.wait_for_background_updates()

@@ -576,7 +577,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)

def test_old_user_ips_pruned(self):
def test_old_user_ips_pruned(self) -> None:
# First make sure we have completed all updates.
self.wait_for_background_updates()

@@ -639,11 +640,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(result, [])

# But we should still get the correct values for the device
result = self.get_success(
result2 = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id)
)

r = result[(user_id, device_id)]
r = result2[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
@@ -663,15 +664,11 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]

def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
return hs

def prepare(self, hs, reactor, clock):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.user_id = self.register_user("bob", "abc123", True)

def test_request_with_xforwarded(self):
def test_request_with_xforwarded(self) -> None:
"""
The IP in X-Forwarded-For is entered into the client IPs table.
"""
@@ -681,14 +678,19 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
{"request": XForwardedForRequest},
)

def test_request_from_getPeer(self):
def test_request_from_getPeer(self) -> None:
"""
The IP returned by getPeer is entered into the client IPs table, if
there's no X-Forwarded-For header.
"""
self._runtest({}, "127.0.0.1", {})

def _runtest(self, headers, expected_ip, make_request_args):
def _runtest(
self,
headers: Dict[bytes, bytes],
expected_ip: str,
make_request_args: Dict[str, Any],
) -> None:
device_id = "bleb"

access_token = self.login("bob", "abc123", device_id=device_id)


+ 1
- 1
tests/storage/test_database.py View File

@@ -31,7 +31,7 @@ from tests import unittest


class TupleComparisonClauseTestCase(unittest.TestCase):
def test_native_tuple_comparison(self):
def test_native_tuple_comparison(self) -> None:
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2])


+ 25
- 10
tests/storage/test_devices.py View File

@@ -12,17 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Collection, List, Tuple

from twisted.test.proto_helpers import MemoryReactor

import synapse.api.errors
from synapse.api.constants import EduTypes
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock

from tests.unittest import HomeserverTestCase


class DeviceStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main

def add_device_change(self, user_id, device_ids, host):
def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
"""Add a device list change for the given device to
`device_lists_outbound_pokes` table.
"""
@@ -44,12 +51,13 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
)

def test_store_new_device(self):
def test_store_new_device(self) -> None:
self.get_success(
self.store.store_device("user_id", "device_id", "display_name")
)

res = self.get_success(self.store.get_device("user_id", "device_id"))
assert res is not None
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -59,7 +67,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
res,
)

def test_get_devices_by_user(self):
def test_get_devices_by_user(self) -> None:
self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
@@ -89,7 +97,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
res["device2"],
)

def test_count_devices_by_users(self):
def test_count_devices_by_users(self) -> None:
self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
@@ -114,7 +122,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
self.assertEqual(3, res)

def test_get_device_updates_by_remote(self):
def test_get_device_updates_by_remote(self) -> None:
device_ids = ["device_id1", "device_id2"]

# Add two device updates with sequential `stream_id`s
@@ -128,7 +136,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)

def test_get_device_updates_by_remote_can_limit_properly(self):
def test_get_device_updates_by_remote_can_limit_properly(self) -> None:
"""
Tests that `get_device_updates_by_remote` returns an appropriate
stream_id to resume fetching from (without skipping any results).
@@ -280,7 +288,11 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
self.assertEqual(device_updates, [])

def _check_devices_in_updates(self, expected_device_ids, device_updates):
def _check_devices_in_updates(
self,
expected_device_ids: Collection[str],
device_updates: List[Tuple[str, JsonDict]],
) -> None:
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))

@@ -289,17 +301,19 @@ class DeviceStoreTestCase(HomeserverTestCase):
}
self.assertEqual(received_device_ids, set(expected_device_ids))

def test_update_device(self):
def test_update_device(self) -> None:
self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1")
)

res = self.get_success(self.store.get_device("user_id", "device_id"))
assert res is not None
self.assertEqual("display_name 1", res["display_name"])

# do a no-op first
self.get_success(self.store.update_device("user_id", "device_id"))
res = self.get_success(self.store.get_device("user_id", "device_id"))
assert res is not None
self.assertEqual("display_name 1", res["display_name"])

# do the update
@@ -311,9 +325,10 @@ class DeviceStoreTestCase(HomeserverTestCase):

# check it worked
res = self.get_success(self.store.get_device("user_id", "device_id"))
assert res is not None
self.assertEqual("display_name 2", res["display_name"])

def test_update_unknown_device(self):
def test_update_unknown_device(self) -> None:
exc = self.get_failure(
self.store.update_device(
"user_id", "unknown_device_id", new_display_name="display_name 2"


+ 8
- 4
tests/storage/test_directory.py View File

@@ -12,19 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.types import RoomAlias, RoomID
from synapse.util import Clock

from tests.unittest import HomeserverTestCase


class DirectoryStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main

self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")

def test_room_to_alias(self):
def test_room_to_alias(self) -> None:
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
@@ -36,7 +40,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
(self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
)

def test_alias_to_room(self):
def test_alias_to_room(self) -> None:
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
@@ -48,7 +52,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
(self.get_success(self.store.get_association_from_room_alias(self.alias))),
)

def test_delete_alias(self):
def test_delete_alias(self) -> None:
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]


+ 6
- 2
tests/storage/test_e2e_room_keys.py View File

@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.storage.databases.main.e2e_room_keys import RoomKey
from synapse.util import Clock

from tests import unittest

@@ -26,12 +30,12 @@ room_key: RoomKey = {


class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.store = hs.get_datastores().main
return hs

def test_room_keys_version_delete(self):
def test_room_keys_version_delete(self) -> None:
# test that deleting a room key backup deletes the keys
version1 = self.get_success(
self.store.create_e2e_room_keys_version(


+ 10
- 5
tests/storage/test_end_to_end_keys.py View File

@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.util import Clock

from tests.unittest import HomeserverTestCase


class EndToEndKeyStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main

def test_key_without_device_name(self):
def test_key_without_device_name(self) -> None:
now = 1470174257070
json = {"key": "value"}

@@ -35,7 +40,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev)

def test_reupload_key(self):
def test_reupload_key(self) -> None:
now = 1470174257070
json = {"key": "value"}

@@ -53,7 +58,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
)
self.assertFalse(changed)

def test_get_key_with_device_name(self):
def test_get_key_with_device_name(self) -> None:
now = 1470174257070
json = {"key": "value"}

@@ -70,7 +75,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
)

def test_multiple_devices(self):
def test_multiple_devices(self) -> None:
now = 1470174257070

self.get_success(self.store.store_device("user1", "device1", None))


+ 17
- 12
tests/storage/test_event_chain.py View File

@@ -14,6 +14,7 @@

from typing import Dict, List, Set, Tuple

from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest

from synapse.api.constants import EventTypes
@@ -22,18 +23,22 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.events import _LinkMap
from synapse.storage.types import Cursor
from synapse.types import create_requester
from synapse.util import Clock

from tests.unittest import HomeserverTestCase


class EventChainStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self._next_stream_ordering = 1

def test_simple(self):
def test_simple(self) -> None:
"""Test that the example in `docs/auth_chain_difference_algorithm.md`
works.
"""
@@ -232,7 +237,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
),
)

def test_out_of_order_events(self):
def test_out_of_order_events(self) -> None:
"""Test that we handle persisting events that we don't have the full
auth chain for yet (which should only happen for out of band memberships).
"""
@@ -378,7 +383,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
def persist(
self,
events: List[EventBase],
):
) -> None:
"""Persist the given events and check that the links generated match
those given.
"""
@@ -389,7 +394,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
e.internal_metadata.stream_ordering = self._next_stream_ordering
self._next_stream_ordering += 1

def _persist(txn):
def _persist(txn: LoggingTransaction) -> None:
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(
@@ -456,7 +461,7 @@ class EventChainStoreTestCase(HomeserverTestCase):


class LinkMapTestCase(unittest.TestCase):
def test_simple(self):
def test_simple(self) -> None:
"""Basic tests for the LinkMap."""
link_map = _LinkMap()

@@ -492,7 +497,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
login.register_servlets,
]

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
@@ -559,7 +564,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):

# Delete the chain cover info.

def _delete_tables(txn):
def _delete_tables(txn: Cursor) -> None:
txn.execute("DELETE FROM event_auth_chains")
txn.execute("DELETE FROM event_auth_chain_links")

@@ -567,7 +572,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):

return room_id, [state1, state2]

def test_background_update_single_room(self):
def test_background_update_single_room(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
@@ -602,7 +607,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)

def test_background_update_multiple_rooms(self):
def test_background_update_multiple_rooms(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
@@ -640,7 +645,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)

def test_background_update_single_large_room(self):
def test_background_update_single_large_room(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
@@ -693,7 +698,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)

def test_background_update_multiple_large_room(self):
def test_background_update_multiple_large_room(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""


+ 35
- 36
tests/storage/test_event_federation.py View File

@@ -13,7 +13,7 @@
# limitations under the License.

import datetime
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple, Union, cast

import attr
from parameterized import parameterized
@@ -26,11 +26,12 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersion,
)
from synapse.events import _EventInternalMetadata
from synapse.events import EventBase, _EventInternalMetadata
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import Clock, json_encoder

@@ -54,11 +55,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main

def test_get_prev_events_for_room(self):
def test_get_prev_events_for_room(self) -> None:
room_id = "@ROOM:local"

# add a bunch of events and hashes to act as forward extremities
def insert_event(txn, i):
def insert_event(txn: Cursor, i: int) -> None:
event_id = "$event_%i:local" % i

txn.execute(
@@ -90,12 +91,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for i in range(0, 10):
self.assertEqual("$event_%i:local" % (19 - i), r[i])

def test_get_rooms_with_many_extremities(self):
def test_get_rooms_with_many_extremities(self) -> None:
room1 = "#room1"
room2 = "#room2"
room3 = "#room3"

def insert_event(txn, i, room_id):
def insert_event(txn: Cursor, i: int, room_id: str) -> None:
event_id = "$event_%i:local" % i
txn.execute(
(
@@ -155,7 +156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# | |
# K J

auth_graph = {
auth_graph: Dict[str, List[str]] = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
@@ -185,7 +186,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):

# Mark the room as maybe having a cover index.

def store_room(txn):
def store_room(txn: LoggingTransaction) -> None:
self.store.db_pool.simple_insert_txn(
txn,
"rooms",
@@ -203,7 +204,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.

def insert_event(txn):
def insert_event(txn: LoggingTransaction) -> None:
stream_ordering = 0

for event_id in auth_graph:
@@ -228,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
[
FakeEvent(event_id, room_id, auth_graph[event_id])
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
],
)
@@ -243,7 +244,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return room_id

@parameterized.expand([(True,), (False,)])
def test_auth_chain_ids(self, use_chain_cover_index: bool):
def test_auth_chain_ids(self, use_chain_cover_index: bool) -> None:
room_id = self._setup_auth_chain(use_chain_cover_index)

# a and b have the same auth chain.
@@ -308,7 +309,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertCountEqual(auth_chain_ids, ["i", "j"])

@parameterized.expand([(True,), (False,)])
def test_auth_difference(self, use_chain_cover_index: bool):
def test_auth_difference(self, use_chain_cover_index: bool) -> None:
room_id = self._setup_auth_chain(use_chain_cover_index)

# Now actually test that various combinations give the right result:
@@ -353,7 +354,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.assertSetEqual(difference, set())

def test_auth_difference_partial_cover(self):
def test_auth_difference_partial_cover(self) -> None:
"""Test that we correctly handle rooms where not all events have a chain
cover calculated. This can happen in some obscure edge cases, including
during the background update that calculates the chain cover for old
@@ -377,7 +378,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# | |
# K J

auth_graph = {
auth_graph: Dict[str, List[str]] = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
@@ -408,7 +409,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.

def insert_event(txn):
def insert_event(txn: LoggingTransaction) -> None:
# First insert the room and mark it as having a chain cover.
self.store.db_pool.simple_insert_txn(
txn,
@@ -447,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
[
FakeEvent(event_id, room_id, auth_graph[event_id])
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
if event_id != "b"
],
@@ -465,7 +466,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):

self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
[FakeEvent("b", room_id, auth_graph["b"])],
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
)

self.store.db_pool.simple_update_txn(
@@ -527,7 +528,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
@parameterized.expand(
[(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()]
)
def test_prune_inbound_federation_queue(self, room_version: RoomVersion):
def test_prune_inbound_federation_queue(self, room_version: RoomVersion) -> None:
"""Test that pruning of inbound federation queues work"""

room_id = "some_room_id"
@@ -686,7 +687,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):

stream_ordering += 1

def populate_db(txn: LoggingTransaction):
def populate_db(txn: LoggingTransaction) -> None:
# Insert the room to satisfy the foreign key constraint of
# `event_failed_pull_attempts`
self.store.db_pool.simple_insert_txn(
@@ -760,7 +761,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):

return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)

def test_get_backfill_points_in_room(self):
def test_get_backfill_points_in_room(self) -> None:
"""
Test to make sure only backfill points that are older and come before
the `current_depth` are returned.
@@ -787,7 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):

def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
self,
):
) -> None:
"""
Test to make sure that events we have attempted to backfill (and within
backoff timeout duration) do not show up as an event to backfill again.
@@ -824,7 +825,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):

def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
self,
):
) -> None:
"""
Test to make sure after we fake attempt to backfill event "b3" many times,
we can see retry and see the "b3" again after the backoff timeout duration
@@ -941,7 +942,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"5": 7,
}

def populate_db(txn: LoggingTransaction):
def populate_db(txn: LoggingTransaction) -> None:
# Insert the room to satisfy the foreign key constraint of
# `event_failed_pull_attempts`
self.store.db_pool.simple_insert_txn(
@@ -996,7 +997,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):

return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)

def test_get_insertion_event_backward_extremities_in_room(self):
def test_get_insertion_event_backward_extremities_in_room(self) -> None:
"""
Test to make sure only insertion event backward extremities that are
older and come before the `current_depth` are returned.
@@ -1027,7 +1028,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):

def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
self,
):
) -> None:
"""
Test to make sure that insertion events we have attempted to backfill
(and within backoff timeout duration) do not show up as an event to
@@ -1060,7 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):

def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
self,
):
) -> None:
"""
Test to make sure after we fake attempt to backfill event
"insertion_eventA" many times, we can see retry and see the
@@ -1130,9 +1131,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertEqual(backfill_event_ids, ["insertion_eventA"])

def test_get_event_ids_to_not_pull_from_backoff(
self,
):
def test_get_event_ids_to_not_pull_from_backoff(self) -> None:
"""
Test to make sure only event IDs we should backoff from are returned.
"""
@@ -1157,7 +1156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):

def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
self,
):
) -> None:
"""
Test to make sure no event IDs are returned after the backoff duration has
elapsed.
@@ -1187,19 +1186,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(event_ids_to_backoff, [])


@attr.s
@attr.s(auto_attribs=True)
class FakeEvent:
event_id = attr.ib()
room_id = attr.ib()
auth_events = attr.ib()
event_id: str
room_id: str
auth_events: List[str]

type = "foo"
state_key = "foo"

internal_metadata = _EventInternalMetadata({})

def auth_event_ids(self):
def auth_event_ids(self) -> List[str]:
return self.auth_events

def is_state(self):
def is_state(self) -> bool:
return True

+ 1
- 1
tests/storage/test_event_metrics.py View File

@@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase


class ExtremStatisticsTestCase(HomeserverTestCase):
def test_exposed_to_prometheus(self):
def test_exposed_to_prometheus(self) -> None:
"""
Forward extremity counts are exposed via Prometheus.
"""


+ 26
- 13
tests/storage/test_events.py View File

@@ -12,12 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import StateMap
from synapse.util import Clock

from tests.unittest import HomeserverTestCase

@@ -29,7 +36,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
login.register_servlets,
]

def prepare(self, reactor, clock, homeserver):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence
self._state_storage_controller = self.hs.get_storage_controllers().state
@@ -67,7 +76,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check that the current extremities is the remote event.
self.assert_extremities([self.remote_event_1.event_id])

def persist_event(self, event, state=None):
def persist_event(
self, event: EventBase, state: Optional[StateMap[str]] = None
) -> None:
"""Persist the event, with optional state"""
context = self.get_success(
self.state.compute_event_context(
@@ -78,14 +89,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
self.get_success(self._persistence.persist_event(event, context))

def assert_extremities(self, expected_extremities):
def assert_extremities(self, expected_extremities: List[str]) -> None:
"""Assert the current extremities for the room"""
extremities = self.get_success(
self.store.get_prev_events_for_room(self.room_id)
)
self.assertCountEqual(extremities, expected_extremities)

def test_prune_gap(self):
def test_prune_gap(self) -> None:
"""Test that we drop extremities after a gap when we see an event from
the same domain.
"""
@@ -117,7 +128,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])

def test_do_not_prune_gap_if_state_different(self):
def test_do_not_prune_gap_if_state_different(self) -> None:
"""Test that we don't prune extremities after a gap if the resolved
state is different.
"""
@@ -161,7 +172,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check that we haven't dropped the old extremity.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])

def test_prune_gap_if_old(self):
def test_prune_gap_if_old(self) -> None:
"""Test that we drop extremities after a gap when the previous extremity
is "old"
"""
@@ -197,7 +208,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])

def test_do_not_prune_gap_if_other_server(self):
def test_do_not_prune_gap_if_other_server(self) -> None:
"""Test that we do not drop extremities after a gap when we see an event
from a different domain.
"""
@@ -229,7 +240,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])

def test_prune_gap_if_dummy_remote(self):
def test_prune_gap_if_dummy_remote(self) -> None:
"""Test that we drop extremities after a gap when the previous extremity
is a local dummy event and only points to remote events.
"""
@@ -271,7 +282,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])

def test_prune_gap_if_dummy_local(self):
def test_prune_gap_if_dummy_local(self) -> None:
"""Test that we don't drop extremities after a gap when the previous
extremity is a local dummy event and points to local events.
"""
@@ -315,7 +326,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id])

def test_do_not_prune_gap_if_not_dummy(self):
def test_do_not_prune_gap_if_not_dummy(self) -> None:
"""Test that we do not drop extremities after a gap when the previous extremity
is not a dummy event.
"""
@@ -359,12 +370,14 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
login.register_servlets,
]

def prepare(self, reactor, clock, homeserver):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main

def test_remote_user_rooms_cache_invalidated(self):
def test_remote_user_rooms_cache_invalidated(self) -> None:
"""Test that if the server leaves a room the `get_rooms_for_user` cache
is invalidated for remote users.
"""
@@ -411,7 +424,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
self.assertEqual(set(rooms), set())

def test_room_remote_user_cache_invalidated(self):
def test_room_remote_user_cache_invalidated(self) -> None:
"""Test that if the server leaves a room the `get_users_in_room` cache
is invalidated for remote users.
"""


+ 6
- 3
tests/storage/test_keys.py View File

@@ -13,6 +13,7 @@
# limitations under the License.

import signedjson.key
import signedjson.types
import unpaddedbase64

from twisted.internet.defer import Deferred
@@ -22,7 +23,9 @@ from synapse.storage.keys import FetchKeyResult
import tests.unittest


def decode_verify_key_base64(key_id: str, key_base64: str):
def decode_verify_key_base64(
key_id: str, key_base64: str
) -> signedjson.types.VerifyKey:
key_bytes = unpaddedbase64.decode_base64(key_base64)
return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)

@@ -36,7 +39,7 @@ KEY_2 = decode_verify_key_base64(


class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_server_verify_keys(self):
def test_get_server_verify_keys(self) -> None:
store = self.hs.get_datastores().main

key_id_1 = "ed25519:key1"
@@ -71,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])

def test_cache(self):
def test_cache(self) -> None:
"""Check that updates correctly invalidate the cache."""

store = self.hs.get_datastores().main


+ 15
- 15
tests/storage/test_monthly_active_users.py View File

@@ -53,7 +53,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.reactor.advance(FORTY_DAYS)

@override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
def test_initialise_reserved_users(self):
def test_initialise_reserved_users(self) -> None:
threepids = self.hs.config.server.mau_limits_reserved_threepids

# register three users, of which two have reserved 3pids, and a third
@@ -133,7 +133,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
active_count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(active_count, 3)

def test_can_insert_and_count_mau(self):
def test_can_insert_and_count_mau(self) -> None:
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 0)

@@ -143,7 +143,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 1)

def test_appservice_user_not_counted_in_mau(self):
def test_appservice_user_not_counted_in_mau(self) -> None:
self.get_success(
self.store.register_user(
user_id="@appservice_user:server", appservice_id="wibble"
@@ -158,7 +158,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 0)

def test_user_last_seen_monthly_active(self):
def test_user_last_seen_monthly_active(self) -> None:
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id3 = "@user3:server"
@@ -177,7 +177,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertIsNone(result)

@override_config({"max_mau_value": 5})
def test_reap_monthly_active_users(self):
def test_reap_monthly_active_users(self) -> None:
initial_users = 10
for i in range(initial_users):
self.get_success(
@@ -204,7 +204,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# Note that below says mau_limit (no s), this is the name of the config
# value, although it gets stored on the config object as mau_limits.
@override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
def test_reap_monthly_active_users_reserved_users(self):
def test_reap_monthly_active_users_reserved_users(self) -> None:
"""Tests that reaping correctly handles reaping where reserved users are
present"""
threepids = self.hs.config.server.mau_limits_reserved_threepids
@@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, self.hs.config.server.max_mau_value)

def test_populate_monthly_users_is_guest(self):
def test_populate_monthly_users_is_guest(self) -> None:
# Test that guest users are not added to mau list
user_id = "@user_id:host"

@@ -260,7 +260,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):

self.store.upsert_monthly_active_user.assert_not_called()

def test_populate_monthly_users_should_update(self):
def test_populate_monthly_users_should_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]

self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
@@ -273,7 +273,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):

self.store.upsert_monthly_active_user.assert_called_once()

def test_populate_monthly_users_should_not_update(self):
def test_populate_monthly_users_should_not_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]

self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
@@ -286,7 +286,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):

self.store.upsert_monthly_active_user.assert_not_called()

def test_get_reserved_real_user_account(self):
def test_get_reserved_real_user_account(self) -> None:
# Test no reserved users, or reserved threepids
users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), 0)
@@ -326,7 +326,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), len(threepids))

def test_support_user_not_add_to_mau_limits(self):
def test_support_user_not_add_to_mau_limits(self) -> None:
support_user_id = "@support:test"

count = self.get_success(self.store.get_monthly_active_count())
@@ -347,7 +347,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config(
{"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
)
def test_track_monthly_users_without_cap(self):
def test_track_monthly_users_without_cap(self) -> None:
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(0, count)

@@ -358,14 +358,14 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(2, count)

@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
def test_no_users_when_not_tracking(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]

self.get_success(self.store.populate_monthly_active_users("@user:sever"))

self.store.upsert_monthly_active_user.assert_not_called()

def test_get_monthly_active_count_by_service(self):
def test_get_monthly_active_count_by_service(self) -> None:
appservice1_user1 = "@appservice1_user1:example.com"
appservice1_user2 = "@appservice1_user2:example.com"

@@ -413,7 +413,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(result[service2], 1)
self.assertEqual(result[native], 1)

def test_get_monthly_active_users_by_service(self):
def test_get_monthly_active_users_by_service(self) -> None:
# (No users, no filtering) -> empty result
result = self.get_success(self.store.get_monthly_active_users_by_service())



+ 10
- 5
tests/storage/test_purge.py View File

@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.errors import NotFoundError, SynapseError
from synapse.rest.client import room
from synapse.server import HomeServer
from synapse.util import Clock

from tests.unittest import HomeserverTestCase

@@ -23,17 +27,17 @@ class PurgeTests(HomeserverTestCase):
user_id = "@red:server"
servlets = [room.register_servlets]

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
return hs

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)

self.store = hs.get_datastores().main
self._storage_controllers = self.hs.get_storage_controllers()

def test_purge_history(self):
def test_purge_history(self) -> None:
"""
Purging a room history will delete everything before the topological point.
"""
@@ -63,7 +67,7 @@ class PurgeTests(HomeserverTestCase):
self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
self.get_success(self.store.get_event(last["event_id"]))

def test_purge_history_wont_delete_extrems(self):
def test_purge_history_wont_delete_extrems(self) -> None:
"""
Purging a room history will delete everything before the topological point.
"""
@@ -77,6 +81,7 @@ class PurgeTests(HomeserverTestCase):
token = self.get_success(
self.store.get_topological_token_for_event(last["event_id"])
)
assert token.topological is not None
event = f"t{token.topological + 1}-{token.stream + 1}"

# Purge everything before this topological token
@@ -94,7 +99,7 @@ class PurgeTests(HomeserverTestCase):
self.get_success(self.store.get_event(third["event_id"]))
self.get_success(self.store.get_event(last["event_id"]))

def test_purge_room(self):
def test_purge_room(self) -> None:
"""
Purging a room will delete everything about it.
"""


+ 9
- 3
tests/storage/test_receipts.py View File

@@ -14,8 +14,12 @@

from typing import Collection, Optional

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import ReceiptTypes
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util import Clock

from tests.test_utils.event_injection import create_event
from tests.unittest import HomeserverTestCase
@@ -25,7 +29,9 @@ OUR_USER_ID = "@our:test"


class ReceiptTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver) -> None:
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
super().prepare(reactor, clock, homeserver)

self.store = homeserver.get_datastores().main
@@ -135,11 +141,11 @@ class ReceiptTestCase(HomeserverTestCase):
)
self.assertEqual(res, {})

res = self.get_last_unthreaded_receipt(
res2 = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)

self.assertEqual(res, None)
self.assertIsNone(res2)

def test_get_receipts_for_user(self) -> None:
# Send some events into the first room


+ 69
- 56
tests/storage/test_redaction.py View File

@@ -11,27 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from typing import List, Optional, cast

from canonicaljson import json

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
from synapse.events import EventBase, _EventInternalMetadata
from synapse.events.builder import EventBuilder
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomID, UserID
from synapse.util import Clock

from tests import unittest
from tests.utils import create_room


class RedactionTestCase(unittest.HomeserverTestCase):
def default_config(self):
def default_config(self) -> JsonDict:
config = super().default_config()
config["redaction_retention_period"] = "30d"
return config

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self._storage = hs.get_storage_controllers()
storage = hs.get_storage_controllers()
assert storage.persistence is not None
self._persistence = storage.persistence
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()

@@ -46,14 +54,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):

self.depth = 1

def inject_room_member(
def inject_room_member( # type: ignore[override]
self,
room,
user,
membership,
replaces_state=None,
extra_content: Optional[dict] = None,
):
room: RoomID,
user: UserID,
membership: str,
extra_content: Optional[JsonDict] = None,
) -> EventBase:
content = {"membership": membership}
content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
@@ -71,11 +78,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)

self.get_success(self._storage.persistence.persist_event(event, context))
self.get_success(self._persistence.persist_event(event, context))

return event

def inject_message(self, room, user, body):
def inject_message(self, room: RoomID, user: UserID, body: str) -> EventBase:
self.depth += 1

builder = self.event_builder_factory.for_room_version(
@@ -93,11 +100,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)

self.get_success(self._storage.persistence.persist_event(event, context))
self.get_success(self._persistence.persist_event(event, context))

return event

def inject_redaction(self, room, event_id, user, reason):
def inject_redaction(
self, room: RoomID, event_id: str, user: UserID, reason: str
) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -114,11 +123,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)

self.get_success(self._storage.persistence.persist_event(event, context))
self.get_success(self._persistence.persist_event(event, context))

return event

def test_redact(self):
def test_redact(self) -> None:
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)

msg_event = self.inject_message(self.room1, self.u_alice, "t")
@@ -165,7 +174,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
event.unsigned["redacted_because"],
)

def test_redact_join(self):
def test_redact_join(self) -> None:
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)

msg_event = self.inject_room_member(
@@ -213,12 +222,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
event.unsigned["redacted_because"],
)

def test_circular_redaction(self):
def test_circular_redaction(self) -> None:
redaction_event_id1 = "$redaction1_id:test"
redaction_event_id2 = "$redaction2_id:test"

class EventIdManglingBuilder:
def __init__(self, base_builder, event_id):
def __init__(self, base_builder: EventBuilder, event_id: str):
self._base_builder = base_builder
self._event_id = event_id

@@ -227,67 +236,73 @@ class RedactionTestCase(unittest.HomeserverTestCase):
prev_event_ids: List[str],
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
):
) -> EventBase:
built_event = await self._base_builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
)

built_event._event_id = self._event_id
built_event._event_id = self._event_id # type: ignore[attr-defined]
built_event._dict["event_id"] = self._event_id
assert built_event.event_id == self._event_id

return built_event

@property
def room_id(self):
def room_id(self) -> str:
return self._base_builder.room_id

@property
def type(self):
def type(self) -> str:
return self._base_builder.type

@property
def internal_metadata(self):
def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata

event_1, context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
EventIdManglingBuilder(
self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Redaction,
"sender": self.u_alice.to_string(),
"room_id": self.room1.to_string(),
"content": {"reason": "test"},
"redacts": redaction_event_id2,
},
cast(
EventBuilder,
EventIdManglingBuilder(
self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Redaction,
"sender": self.u_alice.to_string(),
"room_id": self.room1.to_string(),
"content": {"reason": "test"},
"redacts": redaction_event_id2,
},
),
redaction_event_id1,
),
redaction_event_id1,
)
)
)

self.get_success(self._storage.persistence.persist_event(event_1, context_1))
self.get_success(self._persistence.persist_event(event_1, context_1))

event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
EventIdManglingBuilder(
self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Redaction,
"sender": self.u_alice.to_string(),
"room_id": self.room1.to_string(),
"content": {"reason": "test"},
"redacts": redaction_event_id1,
},
cast(
EventBuilder,
EventIdManglingBuilder(
self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Redaction,
"sender": self.u_alice.to_string(),
"room_id": self.room1.to_string(),
"content": {"reason": "test"},
"redacts": redaction_event_id1,
},
),
redaction_event_id2,
),
redaction_event_id2,
)
)
)
self.get_success(self._storage.persistence.persist_event(event_2, context_2))
self.get_success(self._persistence.persist_event(event_2, context_2))

# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -298,7 +313,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
fetched.unsigned["redacted_because"].event_id, redaction_event_id2
)

def test_redact_censor(self):
def test_redact_censor(self) -> None:
"""Test that a redacted event gets censored in the DB after a month"""

self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -364,7 +379,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):

self.assert_dict({"content": {}}, json.loads(event_json))

def test_redact_redaction(self):
def test_redact_redaction(self) -> None:
"""Tests that we can redact a redaction and can fetch it again."""

self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -391,7 +406,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.store.get_event(first_redact_event.event_id, allow_none=True)
)

def test_store_redacted_redaction(self):
def test_store_redacted_redaction(self) -> None:
"""Tests that we can store a redacted redaction."""

self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -410,9 +425,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)

self.get_success(
self._storage.persistence.persist_event(redaction_event, context)
)
self.get_success(self._persistence.persist_event(redaction_event, context))

# Now lets jump to the future where we have censored the redaction event
# in the DB.


+ 10
- 5
tests/storage/test_rollback_worker.py View File

@@ -14,10 +14,15 @@
from typing import List
from unittest import mock

from twisted.test.proto_helpers import MemoryReactor

from synapse.app.generic_worker import GenericWorkerServer
from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database
from synapse.storage.schema import SCHEMA_VERSION
from synapse.types import JsonDict
from synapse.util import Clock

from tests.unittest import HomeserverTestCase

@@ -39,13 +44,13 @@ def fake_listdir(filepath: str) -> List[str]:


class WorkerSchemaTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(
federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs

def default_config(self):
def default_config(self) -> JsonDict:
conf = super().default_config()

# Mark this as a worker app.
@@ -53,7 +58,7 @@ class WorkerSchemaTests(HomeserverTestCase):

return conf

def test_rolling_back(self):
def test_rolling_back(self) -> None:
"""Test that workers can start if the DB is a newer schema version"""

db_pool = self.hs.get_datastores().main.db_pool
@@ -70,7 +75,7 @@ class WorkerSchemaTests(HomeserverTestCase):

prepare_database(db_conn, db_pool.engine, self.hs.config)

def test_not_upgraded_old_schema_version(self):
def test_not_upgraded_old_schema_version(self) -> None:
"""Test that workers don't start if the DB has an older schema version"""
db_pool = self.hs.get_datastores().main.db_pool
db_conn = LoggingDatabaseConnection(
@@ -87,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase):
with self.assertRaises(PrepareDatabaseException):
prepare_database(db_conn, db_pool.engine, self.hs.config)

def test_not_upgraded_current_schema_version_with_outstanding_deltas(self):
def test_not_upgraded_current_schema_version_with_outstanding_deltas(self) -> None:
"""
Test that workers don't start if the DB is on the current schema version,
but there are still outstanding delta migrations to run.


+ 16
- 8
tests/storage/test_room.py View File

@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.room_versions import RoomVersions
from synapse.server import HomeServer
from synapse.types import RoomAlias, RoomID, UserID
from synapse.util import Clock

from tests.unittest import HomeserverTestCase


class RoomStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
self.store = hs.get_datastores().main
@@ -37,30 +41,34 @@ class RoomStoreTestCase(HomeserverTestCase):
)
)

def test_get_room(self):
def test_get_room(self) -> None:
res = self.get_success(self.store.get_room(self.room.to_string()))
assert res is not None
self.assertDictContainsSubset(
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"is_public": True,
},
(self.get_success(self.store.get_room(self.room.to_string()))),
res,
)

def test_get_room_unknown_room(self):
def test_get_room_unknown_room(self) -> None:
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))

def test_get_room_with_stats(self):
def test_get_room_with_stats(self) -> None:
res = self.get_success(self.store.get_room_with_stats(self.room.to_string()))
assert res is not None
self.assertDictContainsSubset(
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"public": True,
},
(self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
res,
)

def test_get_room_with_stats_unknown_room(self):
def test_get_room_with_stats_unknown_room(self) -> None:
self.assertIsNone(
(self.get_success(self.store.get_room_with_stats("!uknown:test"))),
self.get_success(self.store.get_room_with_stats("!uknown:test"))
)

+ 5
- 5
tests/storage/test_room_search.py View File

@@ -39,7 +39,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
room.register_servlets,
]

def test_null_byte(self):
def test_null_byte(self) -> None:
"""
Postgres/SQLite don't like null bytes going into the search tables. Internally
we replace those with a space.
@@ -86,7 +86,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
if isinstance(store.database_engine, PostgresEngine):
self.assertIn("alice", result.get("highlights"))

def test_non_string(self):
def test_non_string(self) -> None:
"""Test that non-string `value`s are not inserted into `event_search`.

This is particularly important when using sqlite, since a sqlite column can hold
@@ -157,7 +157,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
self.assertEqual(f.value.code, 404)

@skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite")
def test_sqlite_non_string_deletion_background_update(self):
def test_sqlite_non_string_deletion_background_update(self) -> None:
"""Test the background update to delete bad rows from `event_search`."""
store = self.hs.get_datastores().main

@@ -350,7 +350,7 @@ class MessageSearchTest(HomeserverTestCase):
"results array length should match count",
)

def test_postgres_web_search_for_phrase(self):
def test_postgres_web_search_for_phrase(self) -> None:
"""
Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery.
This test is skipped unless the postgres instance supports websearch_to_tsquery.
@@ -364,7 +364,7 @@ class MessageSearchTest(HomeserverTestCase):

self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES)

def test_sqlite_search(self):
def test_sqlite_search(self) -> None:
"""
Test sqlite searching for phrases.
"""


+ 30
- 16
tests/storage/test_state.py View File

@@ -16,10 +16,15 @@ import logging

from frozendict import frozendict

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.server import HomeServer
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
from synapse.types import JsonDict, RoomID, StateMap, UserID
from synapse.util import Clock

from tests.unittest import HomeserverTestCase, TestCase

@@ -27,7 +32,7 @@ logger = logging.getLogger(__name__)


class StateStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state
@@ -48,7 +53,9 @@ class StateStoreTestCase(HomeserverTestCase):
)
)

def inject_state_event(self, room, sender, typ, state_key, content):
def inject_state_event(
self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -64,24 +71,29 @@ class StateStoreTestCase(HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)

assert self.storage.persistence is not None
self.get_success(self.storage.persistence.persist_event(event, context))

return event

def assertStateMapEqual(self, s1, s2):
def assertStateMapEqual(
self, s1: StateMap[EventBase], s2: StateMap[EventBase]
) -> None:
for t in s1:
# just compare event IDs for simplicity
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))

def test_get_state_groups_ids(self):
def test_get_state_groups_ids(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)

state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
self.storage.state.get_state_groups_ids(
self.room.to_string(), [e2.event_id]
)
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
@@ -90,21 +102,21 @@ class StateStoreTestCase(HomeserverTestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
)

def test_get_state_groups(self):
def test_get_state_groups(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)

state_group_map = self.get_success(
self.storage.state.get_state_groups(self.room, [e2.event_id])
self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]

self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})

def test_get_state_for_event(self):
def test_get_state_for_event(self) -> None:
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@@ -487,14 +499,16 @@ class StateStoreTestCase(HomeserverTestCase):
class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
):
) -> None:
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)

def test_state_filter_difference_no_include_other_minus_no_include_other(self):
def test_state_filter_difference_no_include_other_minus_no_include_other(
self,
) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
@@ -610,7 +624,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)

def test_state_filter_difference_include_other_minus_no_include_other(self):
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
@@ -739,7 +753,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)

def test_state_filter_difference_include_other_minus_include_other(self):
def test_state_filter_difference_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others
@@ -864,7 +878,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)

def test_state_filter_difference_no_include_other_minus_include_other(self):
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set.
@@ -979,7 +993,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)

def test_state_filter_difference_simple_cases(self):
def test_state_filter_difference_simple_cases(self) -> None:
"""
Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests.
@@ -995,7 +1009,7 @@ class StateFilterDifferenceTestCase(TestCase):


class StateFilterTestCase(TestCase):
def test_return_expanded(self):
def test_return_expanded(self) -> None:
"""
Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate).


+ 12
- 6
tests/storage/test_stream.py View File

@@ -14,11 +14,15 @@

from typing import List

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.filtering import Filter
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock

from tests.unittest import HomeserverTestCase

@@ -37,12 +41,14 @@ class PaginationTestCase(HomeserverTestCase):
login.register_servlets,
]

def default_config(self):
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = {"msc3874_enabled": True}
return config

def prepare(self, reactor, clock, homeserver):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
@@ -130,7 +136,7 @@ class PaginationTestCase(HomeserverTestCase):

return [ev.event_id for ev in events]

def test_filter_relation_senders(self):
def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to.
filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
@@ -146,7 +152,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter)
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])

def test_filter_relation_type(self):
def test_filter_relation_type(self) -> None:
# Messages which have annotations.
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
@@ -167,7 +173,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter)
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])

def test_filter_relation_senders_and_type(self):
def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to.
filter = {
"related_by_senders": [self.second_user_id],
@@ -176,7 +182,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter)
self.assertEqual(chunk, [self.event_id_1])

def test_duplicate_relation(self):
def test_duplicate_relation(self) -> None:
"""An event should only be returned once if there are multiple relations to it."""
self.helper.send_event(
room_id=self.room_id,


+ 12
- 6
tests/storage/test_transactions.py View File

@@ -12,17 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.storage.databases.main.transactions import DestinationRetryTimings
from synapse.util import Clock
from synapse.util.retryutils import MAX_RETRY_INTERVAL

from tests.unittest import HomeserverTestCase


class TransactionStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.store = homeserver.get_datastores().main

def test_get_set_transactions(self):
def test_get_set_transactions(self) -> None:
"""Tests that we can successfully get a non-existent entry for
destination retries, as well as testing tht we can set and get
correctly.
@@ -44,18 +50,18 @@ class TransactionStoreTestCase(HomeserverTestCase):
r,
)

def test_initial_set_transactions(self):
def test_initial_set_transactions(self) -> None:
"""Tests that we can successfully set the destination retries (there
was a bug around invalidating the cache that broke this)
"""
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
self.get_success(d)

def test_large_destination_retry(self):
def test_large_destination_retry(self) -> None:
d = self.store.set_destination_retry_timings(
"example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
)
self.get_success(d)

d = self.store.get_destination_retry_timings("example.com")
self.get_success(d)
d2 = self.store.get_destination_retry_timings("example.com")
self.get_success(d2)

+ 10
- 4
tests/storage/test_txn_limit.py View File

@@ -12,21 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.storage.types import Cursor
from synapse.util import Clock

from tests import unittest


class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
"""Test SQL transaction limit doesn't break transactions."""

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(db_txn_limit=1000)

def test_config(self):
def test_config(self) -> None:
db_config = self.hs.config.database.get_single_database()
self.assertEqual(db_config.config["txn_limit"], 1000)

def test_select(self):
def do_select(txn):
def test_select(self) -> None:
def do_select(txn: Cursor) -> None:
txn.execute("SELECT 1")

db_pool = self.hs.get_datastores().databases[0]


+ 15
- 15
tests/storage/util/test_partial_state_events_tracker.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from typing import Collection, Dict
from unittest import mock

from twisted.internet.defer import CancelledError, ensureDeferred
@@ -31,7 +31,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
# the results to be returned by the mocked get_partial_state_events
self._events_dict: Dict[str, bool] = {}

async def get_partial_state_events(events):
async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]:
return {e: self._events_dict[e] for e in events}

self.mock_store = mock.Mock(spec_set=["get_partial_state_events"])
@@ -39,7 +39,7 @@ class PartialStateEventsTrackerTestCase(TestCase):

self.tracker = PartialStateEventsTracker(self.mock_store)

def test_does_not_block_for_full_state_events(self):
def test_does_not_block_for_full_state_events(self) -> None:
self._events_dict = {"event1": False, "event2": False}

self.successResultOf(
@@ -50,7 +50,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
["event1", "event2"]
)

def test_blocks_for_partial_state_events(self):
def test_blocks_for_partial_state_events(self) -> None:
self._events_dict = {"event1": True, "event2": False}

d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
@@ -62,12 +62,12 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d)

def test_un_partial_state_race(self):
def test_un_partial_state_race(self) -> None:
# if the event is un-partial-stated between the initial check and the
# registration of the listener, it should not block.
self._events_dict = {"event1": True, "event2": False}

async def get_partial_state_events(events):
async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]:
res = {e: self._events_dict[e] for e in events}
# change the result for next time
self._events_dict = {"event1": False, "event2": False}
@@ -79,19 +79,19 @@ class PartialStateEventsTrackerTestCase(TestCase):
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)

def test_un_partial_state_during_get_partial_state_events(self):
def test_un_partial_state_during_get_partial_state_events(self) -> None:
# we should correctly handle a call to notify_un_partial_stated during the
# second call to get_partial_state_events.

self._events_dict = {"event1": True, "event2": False}

async def get_partial_state_events1(events):
async def get_partial_state_events1(events: Collection[str]) -> Dict[str, bool]:
self.mock_store.get_partial_state_events.side_effect = (
get_partial_state_events2
)
return {e: self._events_dict[e] for e in events}

async def get_partial_state_events2(events):
async def get_partial_state_events2(events: Collection[str]) -> Dict[str, bool]:
self.tracker.notify_un_partial_stated("event1")
self._events_dict["event1"] = False
return {e: self._events_dict[e] for e in events}
@@ -102,7 +102,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)

def test_cancellation(self):
def test_cancellation(self) -> None:
self._events_dict = {"event1": True, "event2": False}

d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
@@ -127,12 +127,12 @@ class PartialCurrentStateTrackerTestCase(TestCase):

self.tracker = PartialCurrentStateTracker(self.mock_store)

def test_does_not_block_for_full_state_rooms(self):
def test_does_not_block_for_full_state_rooms(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(False)

self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))

def test_blocks_for_partial_room_state(self):
def test_blocks_for_partial_room_state(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)

d = ensureDeferred(self.tracker.await_full_state("room_id"))
@@ -144,10 +144,10 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("room_id")
self.successResultOf(d)

def test_un_partial_state_race(self):
def test_un_partial_state_race(self) -> None:
# We should correctly handle race between awaiting the state and us
# un-partialling the state
async def is_partial_state_room(events):
async def is_partial_state_room(room_id: str) -> bool:
self.tracker.notify_un_partial_stated("room_id")
return True

@@ -155,7 +155,7 @@ class PartialCurrentStateTrackerTestCase(TestCase):

self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))

def test_cancellation(self):
def test_cancellation(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)

d1 = ensureDeferred(self.tracker.await_full_state("room_id"))


Loading…
Cancel
Save