Removes files under tests.util from the ignored by list, then fully types all tests/util/*.py files.tags/v1.74.0rc1
@@ -0,0 +1 @@ | |||
Add missing type hints. |
@@ -59,16 +59,6 @@ exclude = (?x) | |||
|tests/server_notices/test_resource_limits_server_notices.py | |||
|tests/test_state.py | |||
|tests/test_terms_auth.py | |||
|tests/util/test_async_helpers.py | |||
|tests/util/test_batching_queue.py | |||
|tests/util/test_dict_cache.py | |||
|tests/util/test_expiring_cache.py | |||
|tests/util/test_file_consumer.py | |||
|tests/util/test_linearizer.py | |||
|tests/util/test_logcontext.py | |||
|tests/util/test_lrucache.py | |||
|tests/util/test_rwlock.py | |||
|tests/util/test_wheel_timer.py | |||
)$ | |||
[mypy-synapse.federation.transport.client] | |||
@@ -137,6 +127,9 @@ disallow_untyped_defs = True | |||
[mypy-tests.util.caches.test_descriptors] | |||
disallow_untyped_defs = False | |||
[mypy-tests.util.*] | |||
disallow_untyped_defs = True | |||
[mypy-tests.utils] | |||
disallow_untyped_defs = True | |||
@@ -12,6 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import traceback | |||
from typing import Generator, List, NoReturn, Optional | |||
from parameterized import parameterized_class | |||
@@ -41,8 +42,8 @@ from tests.unittest import TestCase | |||
class ObservableDeferredTest(TestCase): | |||
def test_succeed(self): | |||
origin_d = Deferred() | |||
def test_succeed(self) -> None: | |||
origin_d: "Deferred[int]" = Deferred() | |||
observable = ObservableDeferred(origin_d) | |||
observer1 = observable.observe() | |||
@@ -52,16 +53,18 @@ class ObservableDeferredTest(TestCase): | |||
self.assertFalse(observer2.called) | |||
# check the first observer is called first | |||
def check_called_first(res): | |||
def check_called_first(res: int) -> int: | |||
self.assertFalse(observer2.called) | |||
return res | |||
observer1.addBoth(check_called_first) | |||
# store the results | |||
results = [None, None] | |||
results: List[Optional[ObservableDeferred[int]]] = [None, None] | |||
def check_val(res, idx): | |||
def check_val( | |||
res: ObservableDeferred[int], idx: int | |||
) -> ObservableDeferred[int]: | |||
results[idx] = res | |||
return res | |||
@@ -72,8 +75,8 @@ class ObservableDeferredTest(TestCase): | |||
self.assertEqual(results[0], 123, "observer 1 callback result") | |||
self.assertEqual(results[1], 123, "observer 2 callback result") | |||
def test_failure(self): | |||
origin_d = Deferred() | |||
def test_failure(self) -> None: | |||
origin_d: Deferred = Deferred() | |||
observable = ObservableDeferred(origin_d, consumeErrors=True) | |||
observer1 = observable.observe() | |||
@@ -83,16 +86,16 @@ class ObservableDeferredTest(TestCase): | |||
self.assertFalse(observer2.called) | |||
# check the first observer is called first | |||
def check_called_first(res): | |||
def check_called_first(res: int) -> int: | |||
self.assertFalse(observer2.called) | |||
return res | |||
observer1.addBoth(check_called_first) | |||
# store the results | |||
results = [None, None] | |||
results: List[Optional[ObservableDeferred[str]]] = [None, None] | |||
def check_val(res, idx): | |||
def check_val(res: ObservableDeferred[str], idx: int) -> None: | |||
results[idx] = res | |||
return None | |||
@@ -103,10 +106,12 @@ class ObservableDeferredTest(TestCase): | |||
raise Exception("gah!") | |||
except Exception as e: | |||
origin_d.errback(e) | |||
assert results[0] is not None | |||
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result") | |||
assert results[1] is not None | |||
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result") | |||
def test_cancellation(self): | |||
def test_cancellation(self) -> None: | |||
"""Test that cancelling an observer does not affect other observers.""" | |||
origin_d: "Deferred[int]" = Deferred() | |||
observable = ObservableDeferred(origin_d, consumeErrors=True) | |||
@@ -136,37 +141,38 @@ class ObservableDeferredTest(TestCase): | |||
class TimeoutDeferredTest(TestCase): | |||
def setUp(self): | |||
def setUp(self) -> None: | |||
self.clock = Clock() | |||
def test_times_out(self): | |||
def test_times_out(self) -> None: | |||
"""Basic test case that checks that the original deferred is cancelled and that | |||
the timing-out deferred is errbacked | |||
""" | |||
cancelled = [False] | |||
cancelled = False | |||
def canceller(_d): | |||
cancelled[0] = True | |||
def canceller(_d: Deferred) -> None: | |||
nonlocal cancelled | |||
cancelled = True | |||
non_completing_d = Deferred(canceller) | |||
non_completing_d: Deferred = Deferred(canceller) | |||
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) | |||
self.assertNoResult(timing_out_d) | |||
self.assertFalse(cancelled[0], "deferred was cancelled prematurely") | |||
self.assertFalse(cancelled, "deferred was cancelled prematurely") | |||
self.clock.pump((1.0,)) | |||
self.assertTrue(cancelled[0], "deferred was not cancelled by timeout") | |||
self.assertTrue(cancelled, "deferred was not cancelled by timeout") | |||
self.failureResultOf(timing_out_d, defer.TimeoutError) | |||
def test_times_out_when_canceller_throws(self): | |||
def test_times_out_when_canceller_throws(self) -> None: | |||
"""Test that we have successfully worked around | |||
https://twistedmatrix.com/trac/ticket/9534""" | |||
def canceller(_d): | |||
def canceller(_d: Deferred) -> None: | |||
raise Exception("can't cancel this deferred") | |||
non_completing_d = Deferred(canceller) | |||
non_completing_d: Deferred = Deferred(canceller) | |||
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) | |||
self.assertNoResult(timing_out_d) | |||
@@ -175,22 +181,24 @@ class TimeoutDeferredTest(TestCase): | |||
self.failureResultOf(timing_out_d, defer.TimeoutError) | |||
def test_logcontext_is_preserved_on_cancellation(self): | |||
blocking_was_cancelled = [False] | |||
def test_logcontext_is_preserved_on_cancellation(self) -> None: | |||
blocking_was_cancelled = False | |||
@defer.inlineCallbacks | |||
def blocking(): | |||
non_completing_d = Deferred() | |||
def blocking() -> Generator["Deferred[object]", object, None]: | |||
nonlocal blocking_was_cancelled | |||
non_completing_d: Deferred = Deferred() | |||
with PreserveLoggingContext(): | |||
try: | |||
yield non_completing_d | |||
except CancelledError: | |||
blocking_was_cancelled[0] = True | |||
blocking_was_cancelled = True | |||
raise | |||
with LoggingContext("one") as context_one: | |||
# the errbacks should be run in the test logcontext | |||
def errback(res, deferred_name): | |||
def errback(res: Failure, deferred_name: str) -> Failure: | |||
self.assertIs( | |||
current_context(), | |||
context_one, | |||
@@ -209,7 +217,7 @@ class TimeoutDeferredTest(TestCase): | |||
self.clock.pump((1.0,)) | |||
self.assertTrue( | |||
blocking_was_cancelled[0], "non-completing deferred was not cancelled" | |||
blocking_was_cancelled, "non-completing deferred was not cancelled" | |||
) | |||
self.failureResultOf(timing_out_d, defer.TimeoutError) | |||
self.assertIs(current_context(), context_one) | |||
@@ -220,13 +228,13 @@ class _TestException(Exception): | |||
class ConcurrentlyExecuteTest(TestCase): | |||
def test_limits_runners(self): | |||
def test_limits_runners(self) -> None: | |||
"""If we have more tasks than runners, we should get the limit of runners""" | |||
started = 0 | |||
waiters = [] | |||
processed = [] | |||
async def callback(v): | |||
async def callback(v: int) -> None: | |||
# when we first enter, bump the start count | |||
nonlocal started | |||
started += 1 | |||
@@ -235,7 +243,7 @@ class ConcurrentlyExecuteTest(TestCase): | |||
processed.append(v) | |||
# wait for the goahead before returning | |||
d2 = Deferred() | |||
d2: "Deferred[int]" = Deferred() | |||
waiters.append(d2) | |||
await d2 | |||
@@ -265,16 +273,16 @@ class ConcurrentlyExecuteTest(TestCase): | |||
self.assertCountEqual(processed, [1, 2, 3, 4, 5]) | |||
self.successResultOf(d2) | |||
def test_preserves_stacktraces(self): | |||
def test_preserves_stacktraces(self) -> None: | |||
"""Test that the stacktrace from an exception thrown in the callback is preserved""" | |||
d1 = Deferred() | |||
d1: "Deferred[int]" = Deferred() | |||
async def callback(v): | |||
async def callback(v: int) -> None: | |||
# alas, this doesn't work at all without an await here | |||
await d1 | |||
raise _TestException("bah") | |||
async def caller(): | |||
async def caller() -> None: | |||
try: | |||
await concurrently_execute(callback, [1], 2) | |||
except _TestException as e: | |||
@@ -290,17 +298,17 @@ class ConcurrentlyExecuteTest(TestCase): | |||
d1.callback(0) | |||
self.successResultOf(d2) | |||
def test_preserves_stacktraces_on_preformed_failure(self): | |||
def test_preserves_stacktraces_on_preformed_failure(self) -> None: | |||
"""Test that the stacktrace on a Failure returned by the callback is preserved""" | |||
d1 = Deferred() | |||
d1: "Deferred[int]" = Deferred() | |||
f = Failure(_TestException("bah")) | |||
async def callback(v): | |||
async def callback(v: int) -> None: | |||
# alas, this doesn't work at all without an await here | |||
await d1 | |||
await defer.fail(f) | |||
async def caller(): | |||
async def caller() -> None: | |||
try: | |||
await concurrently_execute(callback, [1], 2) | |||
except _TestException as e: | |||
@@ -336,7 +344,7 @@ class CancellationWrapperTests(TestCase): | |||
else: | |||
raise ValueError(f"Unsupported wrapper type: {self.wrapper}") | |||
def test_succeed(self): | |||
def test_succeed(self) -> None: | |||
"""Test that the new `Deferred` receives the result.""" | |||
deferred: "Deferred[str]" = Deferred() | |||
wrapper_deferred = self.wrap_deferred(deferred) | |||
@@ -346,7 +354,7 @@ class CancellationWrapperTests(TestCase): | |||
self.assertTrue(wrapper_deferred.called) | |||
self.assertEqual("success", self.successResultOf(wrapper_deferred)) | |||
def test_failure(self): | |||
def test_failure(self) -> None: | |||
"""Test that the new `Deferred` receives the `Failure`.""" | |||
deferred: "Deferred[str]" = Deferred() | |||
wrapper_deferred = self.wrap_deferred(deferred) | |||
@@ -361,7 +369,7 @@ class CancellationWrapperTests(TestCase): | |||
class StopCancellationTests(TestCase): | |||
"""Tests for the `stop_cancellation` function.""" | |||
def test_cancellation(self): | |||
def test_cancellation(self) -> None: | |||
"""Test that cancellation of the new `Deferred` leaves the original running.""" | |||
deferred: "Deferred[str]" = Deferred() | |||
wrapper_deferred = stop_cancellation(deferred) | |||
@@ -384,7 +392,7 @@ class StopCancellationTests(TestCase): | |||
class DelayCancellationTests(TestCase): | |||
"""Tests for the `delay_cancellation` function.""" | |||
def test_deferred_cancellation(self): | |||
def test_deferred_cancellation(self) -> None: | |||
"""Test that cancellation of the new `Deferred` waits for the original.""" | |||
deferred: "Deferred[str]" = Deferred() | |||
wrapper_deferred = delay_cancellation(deferred) | |||
@@ -405,12 +413,12 @@ class DelayCancellationTests(TestCase): | |||
# Now that the original `Deferred` has failed, we should get a `CancelledError`. | |||
self.failureResultOf(wrapper_deferred, CancelledError) | |||
def test_coroutine_cancellation(self): | |||
def test_coroutine_cancellation(self) -> None: | |||
"""Test that cancellation of the new `Deferred` waits for the original.""" | |||
blocking_deferred: "Deferred[None]" = Deferred() | |||
completion_deferred: "Deferred[None]" = Deferred() | |||
async def task(): | |||
async def task() -> NoReturn: | |||
await blocking_deferred | |||
completion_deferred.callback(None) | |||
# Raise an exception. Twisted should consume it, otherwise unwanted | |||
@@ -434,7 +442,7 @@ class DelayCancellationTests(TestCase): | |||
# Now that the original coroutine has failed, we should get a `CancelledError`. | |||
self.failureResultOf(wrapper_deferred, CancelledError) | |||
def test_suppresses_second_cancellation(self): | |||
def test_suppresses_second_cancellation(self) -> None: | |||
"""Test that a second cancellation is suppressed. | |||
Identical to `test_cancellation` except the new `Deferred` is cancelled twice. | |||
@@ -459,7 +467,7 @@ class DelayCancellationTests(TestCase): | |||
# Now that the original `Deferred` has failed, we should get a `CancelledError`. | |||
self.failureResultOf(wrapper_deferred, CancelledError) | |||
def test_propagates_cancelled_error(self): | |||
def test_propagates_cancelled_error(self) -> None: | |||
"""Test that a `CancelledError` from the original `Deferred` gets propagated.""" | |||
deferred: "Deferred[str]" = Deferred() | |||
wrapper_deferred = delay_cancellation(deferred) | |||
@@ -472,14 +480,14 @@ class DelayCancellationTests(TestCase): | |||
self.assertTrue(wrapper_deferred.called) | |||
self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value) | |||
def test_preserves_logcontext(self): | |||
def test_preserves_logcontext(self) -> None: | |||
"""Test that logging contexts are preserved.""" | |||
blocking_d: "Deferred[None]" = Deferred() | |||
async def inner(): | |||
async def inner() -> None: | |||
await make_deferred_yieldable(blocking_d) | |||
async def outer(): | |||
async def outer() -> None: | |||
with LoggingContext("c") as c: | |||
try: | |||
await delay_cancellation(inner()) | |||
@@ -503,7 +511,7 @@ class DelayCancellationTests(TestCase): | |||
class AwakenableSleeperTests(TestCase): | |||
"Tests AwakenableSleeper" | |||
def test_sleep(self): | |||
def test_sleep(self) -> None: | |||
reactor, _ = get_clock() | |||
sleeper = AwakenableSleeper(reactor) | |||
@@ -518,7 +526,7 @@ class AwakenableSleeperTests(TestCase): | |||
reactor.advance(0.6) | |||
self.assertTrue(d.called) | |||
def test_explicit_wake(self): | |||
def test_explicit_wake(self) -> None: | |||
reactor, _ = get_clock() | |||
sleeper = AwakenableSleeper(reactor) | |||
@@ -535,7 +543,7 @@ class AwakenableSleeperTests(TestCase): | |||
reactor.advance(0.6) | |||
def test_multiple_sleepers_timeout(self): | |||
def test_multiple_sleepers_timeout(self) -> None: | |||
reactor, _ = get_clock() | |||
sleeper = AwakenableSleeper(reactor) | |||
@@ -555,7 +563,7 @@ class AwakenableSleeperTests(TestCase): | |||
reactor.advance(0.6) | |||
self.assertTrue(d2.called) | |||
def test_multiple_sleepers_wake(self): | |||
def test_multiple_sleepers_wake(self) -> None: | |||
reactor, _ = get_clock() | |||
sleeper = AwakenableSleeper(reactor) | |||
@@ -11,6 +11,10 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import List, Tuple | |||
from prometheus_client import Gauge | |||
from twisted.internet import defer | |||
from synapse.logging.context import make_deferred_yieldable | |||
@@ -26,7 +30,7 @@ from tests.unittest import TestCase | |||
class BatchingQueueTestCase(TestCase): | |||
def setUp(self): | |||
def setUp(self) -> None: | |||
self.clock, hs_clock = get_clock() | |||
# We ensure that we remove any existing metrics for "test_queue". | |||
@@ -37,25 +41,27 @@ class BatchingQueueTestCase(TestCase): | |||
except KeyError: | |||
pass | |||
self._pending_calls = [] | |||
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue) | |||
self._pending_calls: List[Tuple[List[str], defer.Deferred]] = [] | |||
self.queue: BatchingQueue[str, str] = BatchingQueue( | |||
"test_queue", hs_clock, self._process_queue | |||
) | |||
async def _process_queue(self, values): | |||
d = defer.Deferred() | |||
async def _process_queue(self, values: List[str]) -> str: | |||
d: "defer.Deferred[str]" = defer.Deferred() | |||
self._pending_calls.append((values, d)) | |||
return await make_deferred_yieldable(d) | |||
def _get_sample_with_name(self, metric, name) -> int: | |||
def _get_sample_with_name(self, metric: Gauge, name: str) -> float: | |||
"""For a prometheus metric get the value of the sample that has a | |||
matching "name" label. | |||
""" | |||
for sample in metric.collect()[0].samples: | |||
for sample in next(iter(metric.collect())).samples: | |||
if sample.labels.get("name") == name: | |||
return sample.value | |||
self.fail("Found no matching sample") | |||
def _assert_metrics(self, queued, keys, in_flight): | |||
def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None: | |||
"""Assert that the metrics are correct""" | |||
sample = self._get_sample_with_name(number_queued, self.queue._name) | |||
@@ -75,7 +81,7 @@ class BatchingQueueTestCase(TestCase): | |||
"number_in_flight", | |||
) | |||
def test_simple(self): | |||
def test_simple(self) -> None: | |||
"""Tests the basic case of calling `add_to_queue` once and having | |||
`_process_queue` return. | |||
""" | |||
@@ -106,7 +112,7 @@ class BatchingQueueTestCase(TestCase): | |||
self._assert_metrics(queued=0, keys=0, in_flight=0) | |||
def test_batching(self): | |||
def test_batching(self) -> None: | |||
"""Test that multiple calls at the same time get batched up into one | |||
call to `_process_queue`. | |||
""" | |||
@@ -134,7 +140,7 @@ class BatchingQueueTestCase(TestCase): | |||
self.assertEqual(self.successResultOf(queue_d2), "bar") | |||
self._assert_metrics(queued=0, keys=0, in_flight=0) | |||
def test_queuing(self): | |||
def test_queuing(self) -> None: | |||
"""Test that we queue up requests while a `_process_queue` is being | |||
called. | |||
""" | |||
@@ -184,7 +190,7 @@ class BatchingQueueTestCase(TestCase): | |||
self.assertEqual(self.successResultOf(queue_d3), "bar2") | |||
self._assert_metrics(queued=0, keys=0, in_flight=0) | |||
def test_different_keys(self): | |||
def test_different_keys(self) -> None: | |||
"""Test that calls to different keys get processed in parallel.""" | |||
self.assertFalse(self._pending_calls) | |||
@@ -1,5 +1,20 @@ | |||
# Copyright 2022 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from contextlib import contextmanager | |||
from typing import Generator, Optional | |||
from os import PathLike | |||
from typing import Generator, Optional, Union | |||
from unittest.mock import patch | |||
from synapse.util.check_dependencies import ( | |||
@@ -12,17 +27,17 @@ from tests.unittest import TestCase | |||
class DummyDistribution(metadata.Distribution): | |||
def __init__(self, version: object): | |||
def __init__(self, version: str): | |||
self._version = version | |||
@property | |||
def version(self): | |||
def version(self) -> str: | |||
return self._version | |||
def locate_file(self, path): | |||
def locate_file(self, path: Union[str, PathLike]) -> PathLike: | |||
raise NotImplementedError() | |||
def read_text(self, filename): | |||
def read_text(self, filename: str) -> None: | |||
raise NotImplementedError() | |||
@@ -30,7 +45,7 @@ old = DummyDistribution("0.1.2") | |||
old_release_candidate = DummyDistribution("0.1.2rc3") | |||
new = DummyDistribution("1.2.3") | |||
new_release_candidate = DummyDistribution("1.2.3rc4") | |||
distribution_with_no_version = DummyDistribution(None) | |||
distribution_with_no_version = DummyDistribution(None) # type: ignore[arg-type] | |||
# could probably use stdlib TestCase --- no need for twisted here | |||
@@ -45,7 +60,7 @@ class TestDependencyChecker(TestCase): | |||
If `distribution = None`, we pretend that the package is not installed. | |||
""" | |||
def mock_distribution(name: str): | |||
def mock_distribution(name: str) -> DummyDistribution: | |||
if distribution is None: | |||
raise metadata.PackageNotFoundError | |||
else: | |||
@@ -19,10 +19,12 @@ from tests import unittest | |||
class DictCacheTestCase(unittest.TestCase): | |||
def setUp(self): | |||
self.cache = DictionaryCache("foobar", max_entries=10) | |||
def setUp(self) -> None: | |||
self.cache: DictionaryCache[str, str, str] = DictionaryCache( | |||
"foobar", max_entries=10 | |||
) | |||
def test_simple_cache_hit_full(self): | |||
def test_simple_cache_hit_full(self) -> None: | |||
key = "test_simple_cache_hit_full" | |||
v = self.cache.get(key) | |||
@@ -37,7 +39,7 @@ class DictCacheTestCase(unittest.TestCase): | |||
c = self.cache.get(key) | |||
self.assertEqual(test_value, c.value) | |||
def test_simple_cache_hit_partial(self): | |||
def test_simple_cache_hit_partial(self) -> None: | |||
key = "test_simple_cache_hit_partial" | |||
seq = self.cache.sequence | |||
@@ -47,7 +49,7 @@ class DictCacheTestCase(unittest.TestCase): | |||
c = self.cache.get(key, ["test"]) | |||
self.assertEqual(test_value, c.value) | |||
def test_simple_cache_miss_partial(self): | |||
def test_simple_cache_miss_partial(self) -> None: | |||
key = "test_simple_cache_miss_partial" | |||
seq = self.cache.sequence | |||
@@ -57,7 +59,7 @@ class DictCacheTestCase(unittest.TestCase): | |||
c = self.cache.get(key, ["test2"]) | |||
self.assertEqual({}, c.value) | |||
def test_simple_cache_hit_miss_partial(self): | |||
def test_simple_cache_hit_miss_partial(self) -> None: | |||
key = "test_simple_cache_hit_miss_partial" | |||
seq = self.cache.sequence | |||
@@ -71,7 +73,7 @@ class DictCacheTestCase(unittest.TestCase): | |||
c = self.cache.get(key, ["test2"]) | |||
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) | |||
def test_multi_insert(self): | |||
def test_multi_insert(self) -> None: | |||
key = "test_simple_cache_hit_miss_partial" | |||
seq = self.cache.sequence | |||
@@ -92,7 +94,7 @@ class DictCacheTestCase(unittest.TestCase): | |||
) | |||
self.assertEqual(c.full, False) | |||
def test_invalidation(self): | |||
def test_invalidation(self) -> None: | |||
"""Test that the partial dict and full dicts get invalidated | |||
separately. | |||
""" | |||
@@ -106,7 +108,7 @@ class DictCacheTestCase(unittest.TestCase): | |||
# entry for "a" warm. | |||
for i in range(20): | |||
self.cache.get(key, ["a"]) | |||
self.cache.update(seq, f"key{i}", {1: 2}) | |||
self.cache.update(seq, f"key{i}", {"1": "2"}) | |||
# We should have evicted the full dict... | |||
r = self.cache.get(key) | |||
@@ -12,7 +12,9 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import List, cast | |||
from synapse.util import Clock | |||
from synapse.util.caches.expiringcache import ExpiringCache | |||
from tests.utils import MockClock | |||
@@ -21,17 +23,21 @@ from .. import unittest | |||
class ExpiringCacheTestCase(unittest.HomeserverTestCase): | |||
def test_get_set(self): | |||
def test_get_set(self) -> None: | |||
clock = MockClock() | |||
cache = ExpiringCache("test", clock, max_len=1) | |||
cache: ExpiringCache[str, str] = ExpiringCache( | |||
"test", cast(Clock, clock), max_len=1 | |||
) | |||
cache["key"] = "value" | |||
self.assertEqual(cache.get("key"), "value") | |||
self.assertEqual(cache["key"], "value") | |||
def test_eviction(self): | |||
def test_eviction(self) -> None: | |||
clock = MockClock() | |||
cache = ExpiringCache("test", clock, max_len=2) | |||
cache: ExpiringCache[str, str] = ExpiringCache( | |||
"test", cast(Clock, clock), max_len=2 | |||
) | |||
cache["key"] = "value" | |||
cache["key2"] = "value2" | |||
@@ -43,9 +49,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(cache.get("key2"), "value2") | |||
self.assertEqual(cache.get("key3"), "value3") | |||
def test_iterable_eviction(self): | |||
def test_iterable_eviction(self) -> None: | |||
clock = MockClock() | |||
cache = ExpiringCache("test", clock, max_len=5, iterable=True) | |||
cache: ExpiringCache[str, List[int]] = ExpiringCache( | |||
"test", cast(Clock, clock), max_len=5, iterable=True | |||
) | |||
cache["key"] = [1] | |||
cache["key2"] = [2, 3] | |||
@@ -61,9 +69,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(cache.get("key3"), [4, 5]) | |||
self.assertEqual(cache.get("key4"), [6, 7]) | |||
def test_time_eviction(self): | |||
def test_time_eviction(self) -> None: | |||
clock = MockClock() | |||
cache = ExpiringCache("test", clock, expiry_ms=1000) | |||
cache: ExpiringCache[str, int] = ExpiringCache( | |||
"test", cast(Clock, clock), expiry_ms=1000 | |||
) | |||
cache["key"] = 1 | |||
clock.advance_time(0.5) | |||
@@ -12,22 +12,28 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import threading | |||
from io import StringIO | |||
from io import BytesIO | |||
from typing import BinaryIO, Generator, Optional, cast | |||
from unittest.mock import NonCallableMock | |||
from twisted.internet import defer, reactor | |||
from zope.interface import implementer | |||
from twisted.internet import defer, reactor as _reactor | |||
from twisted.internet.interfaces import IPullProducer | |||
from synapse.types import ISynapseReactor | |||
from synapse.util.file_consumer import BackgroundFileConsumer | |||
from tests import unittest | |||
reactor = cast(ISynapseReactor, _reactor) | |||
class FileConsumerTests(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def test_pull_consumer(self): | |||
string_file = StringIO() | |||
def test_pull_consumer(self) -> Generator["defer.Deferred[object]", object, None]: | |||
string_file = BytesIO() | |||
consumer = BackgroundFileConsumer(string_file, reactor=reactor) | |||
try: | |||
@@ -35,55 +41,57 @@ class FileConsumerTests(unittest.TestCase): | |||
yield producer.register_with_consumer(consumer) | |||
yield producer.write_and_wait("Foo") | |||
yield producer.write_and_wait(b"Foo") | |||
self.assertEqual(string_file.getvalue(), "Foo") | |||
self.assertEqual(string_file.getvalue(), b"Foo") | |||
yield producer.write_and_wait("Bar") | |||
yield producer.write_and_wait(b"Bar") | |||
self.assertEqual(string_file.getvalue(), "FooBar") | |||
self.assertEqual(string_file.getvalue(), b"FooBar") | |||
finally: | |||
consumer.unregisterProducer() | |||
yield consumer.wait() | |||
yield consumer.wait() # type: ignore[misc] | |||
self.assertTrue(string_file.closed) | |||
@defer.inlineCallbacks | |||
def test_push_consumer(self): | |||
string_file = BlockingStringWrite() | |||
consumer = BackgroundFileConsumer(string_file, reactor=reactor) | |||
def test_push_consumer(self) -> Generator["defer.Deferred[object]", object, None]: | |||
string_file = BlockingBytesWrite() | |||
consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor) | |||
try: | |||
producer = NonCallableMock(spec_set=[]) | |||
consumer.registerProducer(producer, True) | |||
consumer.write("Foo") | |||
yield string_file.wait_for_n_writes(1) | |||
consumer.write(b"Foo") | |||
yield string_file.wait_for_n_writes(1) # type: ignore[misc] | |||
self.assertEqual(string_file.buffer, "Foo") | |||
self.assertEqual(string_file.buffer, b"Foo") | |||
consumer.write("Bar") | |||
yield string_file.wait_for_n_writes(2) | |||
consumer.write(b"Bar") | |||
yield string_file.wait_for_n_writes(2) # type: ignore[misc] | |||
self.assertEqual(string_file.buffer, "FooBar") | |||
self.assertEqual(string_file.buffer, b"FooBar") | |||
finally: | |||
consumer.unregisterProducer() | |||
yield consumer.wait() | |||
yield consumer.wait() # type: ignore[misc] | |||
self.assertTrue(string_file.closed) | |||
@defer.inlineCallbacks | |||
def test_push_producer_feedback(self): | |||
string_file = BlockingStringWrite() | |||
consumer = BackgroundFileConsumer(string_file, reactor=reactor) | |||
def test_push_producer_feedback( | |||
self, | |||
) -> Generator["defer.Deferred[object]", object, None]: | |||
string_file = BlockingBytesWrite() | |||
consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor) | |||
try: | |||
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) | |||
resume_deferred = defer.Deferred() | |||
resume_deferred: defer.Deferred = defer.Deferred() | |||
producer.resumeProducing.side_effect = lambda: resume_deferred.callback( | |||
None | |||
) | |||
@@ -93,65 +101,72 @@ class FileConsumerTests(unittest.TestCase): | |||
number_writes = 0 | |||
with string_file.write_lock: | |||
for _ in range(consumer._PAUSE_ON_QUEUE_SIZE): | |||
consumer.write("Foo") | |||
consumer.write(b"Foo") | |||
number_writes += 1 | |||
producer.pauseProducing.assert_called_once() | |||
yield string_file.wait_for_n_writes(number_writes) | |||
yield string_file.wait_for_n_writes(number_writes) # type: ignore[misc] | |||
yield resume_deferred | |||
producer.resumeProducing.assert_called_once() | |||
finally: | |||
consumer.unregisterProducer() | |||
yield consumer.wait() | |||
yield consumer.wait() # type: ignore[misc] | |||
self.assertTrue(string_file.closed) | |||
@implementer(IPullProducer) | |||
class DummyPullProducer: | |||
def __init__(self): | |||
self.consumer = None | |||
self.deferred = defer.Deferred() | |||
def __init__(self) -> None: | |||
self.consumer: Optional[BackgroundFileConsumer] = None | |||
self.deferred: "defer.Deferred[object]" = defer.Deferred() | |||
def resumeProducing(self): | |||
def resumeProducing(self) -> None: | |||
d = self.deferred | |||
self.deferred = defer.Deferred() | |||
d.callback(None) | |||
def write_and_wait(self, bytes): | |||
def stopProducing(self) -> None: | |||
raise RuntimeError("Unexpected call") | |||
def write_and_wait(self, write_bytes: bytes) -> "defer.Deferred[object]": | |||
assert self.consumer is not None | |||
d = self.deferred | |||
self.consumer.write(bytes) | |||
self.consumer.write(write_bytes) | |||
return d | |||
def register_with_consumer(self, consumer): | |||
def register_with_consumer( | |||
self, consumer: BackgroundFileConsumer | |||
) -> "defer.Deferred[object]": | |||
d = self.deferred | |||
self.consumer = consumer | |||
self.consumer.registerProducer(self, False) | |||
return d | |||
class BlockingStringWrite: | |||
def __init__(self): | |||
self.buffer = "" | |||
class BlockingBytesWrite: | |||
def __init__(self) -> None: | |||
self.buffer = b"" | |||
self.closed = False | |||
self.write_lock = threading.Lock() | |||
self._notify_write_deferred = None | |||
self._notify_write_deferred: Optional[defer.Deferred] = None | |||
self._number_of_writes = 0 | |||
def write(self, bytes): | |||
def write(self, write_bytes: bytes) -> None: | |||
with self.write_lock: | |||
self.buffer += bytes | |||
self.buffer += write_bytes | |||
self._number_of_writes += 1 | |||
reactor.callFromThread(self._notify_write) | |||
def close(self): | |||
def close(self) -> None: | |||
self.closed = True | |||
def _notify_write(self): | |||
def _notify_write(self) -> None: | |||
"Called by write to indicate a write happened" | |||
with self.write_lock: | |||
if not self._notify_write_deferred: | |||
@@ -161,7 +176,9 @@ class BlockingStringWrite: | |||
d.callback(None) | |||
@defer.inlineCallbacks | |||
def wait_for_n_writes(self, n): | |||
def wait_for_n_writes( | |||
self, n: int | |||
) -> Generator["defer.Deferred[object]", object, None]: | |||
"Wait for n writes to have happened" | |||
while True: | |||
with self.write_lock: | |||
@@ -19,7 +19,7 @@ from tests.unittest import TestCase | |||
class ChunkSeqTests(TestCase): | |||
def test_short_seq(self): | |||
def test_short_seq(self) -> None: | |||
parts = chunk_seq("123", 8) | |||
self.assertEqual( | |||
@@ -27,7 +27,7 @@ class ChunkSeqTests(TestCase): | |||
["123"], | |||
) | |||
def test_long_seq(self): | |||
def test_long_seq(self) -> None: | |||
parts = chunk_seq("abcdefghijklmnop", 8) | |||
self.assertEqual( | |||
@@ -35,7 +35,7 @@ class ChunkSeqTests(TestCase): | |||
["abcdefgh", "ijklmnop"], | |||
) | |||
def test_uneven_parts(self): | |||
def test_uneven_parts(self) -> None: | |||
parts = chunk_seq("abcdefghijklmnop", 5) | |||
self.assertEqual( | |||
@@ -43,7 +43,7 @@ class ChunkSeqTests(TestCase): | |||
["abcde", "fghij", "klmno", "p"], | |||
) | |||
def test_empty_input(self): | |||
def test_empty_input(self) -> None: | |||
parts: Iterable[Sequence] = chunk_seq([], 5) | |||
self.assertEqual( | |||
@@ -53,13 +53,13 @@ class ChunkSeqTests(TestCase): | |||
class SortTopologically(TestCase): | |||
def test_empty(self): | |||
def test_empty(self) -> None: | |||
"Test that an empty graph works correctly" | |||
graph: Dict[int, List[int]] = {} | |||
self.assertEqual(list(sorted_topologically([], graph)), []) | |||
def test_handle_empty_graph(self): | |||
def test_handle_empty_graph(self) -> None: | |||
"Test that a graph where a node doesn't have an entry is treated as empty" | |||
graph: Dict[int, List[int]] = {} | |||
@@ -67,7 +67,7 @@ class SortTopologically(TestCase): | |||
# For disconnected nodes the output is simply sorted. | |||
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) | |||
def test_disconnected(self): | |||
def test_disconnected(self) -> None: | |||
"Test that a graph with no edges work" | |||
graph: Dict[int, List[int]] = {1: [], 2: []} | |||
@@ -75,20 +75,20 @@ class SortTopologically(TestCase): | |||
# For disconnected nodes the output is simply sorted. | |||
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) | |||
def test_linear(self): | |||
def test_linear(self) -> None: | |||
"Test that a simple `4 -> 3 -> 2 -> 1` graph works" | |||
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} | |||
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) | |||
def test_subset(self): | |||
def test_subset(self) -> None: | |||
"Test that only sorting a subset of the graph works" | |||
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} | |||
self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4]) | |||
def test_fork(self): | |||
def test_fork(self) -> None: | |||
"Test that a forked graph works" | |||
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]} | |||
@@ -96,13 +96,13 @@ class SortTopologically(TestCase): | |||
# always get the same one. | |||
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) | |||
def test_duplicates(self): | |||
def test_duplicates(self) -> None: | |||
"Test that a graph with duplicate edges work" | |||
graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} | |||
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) | |||
def test_multiple_paths(self): | |||
def test_multiple_paths(self) -> None: | |||
"Test that a graph with multiple paths between two nodes work" | |||
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} | |||
@@ -1,5 +1,21 @@ | |||
# Copyright 2014-2022 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Callable, Generator, cast | |||
import twisted.python.failure | |||
from twisted.internet import defer, reactor | |||
from twisted.internet import defer, reactor as _reactor | |||
from synapse.logging.context import ( | |||
SENTINEL_CONTEXT, | |||
@@ -10,25 +26,30 @@ from synapse.logging.context import ( | |||
nested_logging_context, | |||
run_in_background, | |||
) | |||
from synapse.types import ISynapseReactor | |||
from synapse.util import Clock | |||
from .. import unittest | |||
reactor = cast(ISynapseReactor, _reactor) | |||
class LoggingContextTestCase(unittest.TestCase): | |||
def _check_test_key(self, value): | |||
self.assertEqual(current_context().name, value) | |||
def _check_test_key(self, value: str) -> None: | |||
context = current_context() | |||
assert isinstance(context, LoggingContext) | |||
self.assertEqual(context.name, value) | |||
def test_with_context(self): | |||
def test_with_context(self) -> None: | |||
with LoggingContext("test"): | |||
self._check_test_key("test") | |||
@defer.inlineCallbacks | |||
def test_sleep(self): | |||
def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]: | |||
clock = Clock(reactor) | |||
@defer.inlineCallbacks | |||
def competing_callback(): | |||
def competing_callback() -> Generator["defer.Deferred[object]", object, None]: | |||
with LoggingContext("competing"): | |||
yield clock.sleep(0) | |||
self._check_test_key("competing") | |||
@@ -39,17 +60,18 @@ class LoggingContextTestCase(unittest.TestCase): | |||
yield clock.sleep(0) | |||
self._check_test_key("one") | |||
def _test_run_in_background(self, function): | |||
def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred: | |||
sentinel_context = current_context() | |||
callback_completed = [False] | |||
callback_completed = False | |||
with LoggingContext("one"): | |||
# fire off function, but don't wait on it. | |||
d2 = run_in_background(function) | |||
def cb(res): | |||
callback_completed[0] = True | |||
def cb(res: object) -> object: | |||
nonlocal callback_completed | |||
callback_completed = True | |||
return res | |||
d2.addCallback(cb) | |||
@@ -60,8 +82,8 @@ class LoggingContextTestCase(unittest.TestCase): | |||
# the logcontext is left in a sane state. | |||
d2 = defer.Deferred() | |||
def check_logcontext(): | |||
if not callback_completed[0]: | |||
def check_logcontext() -> None: | |||
if not callback_completed: | |||
reactor.callLater(0.01, check_logcontext) | |||
return | |||
@@ -78,31 +100,31 @@ class LoggingContextTestCase(unittest.TestCase): | |||
# test is done once d2 finishes | |||
return d2 | |||
def test_run_in_background_with_blocking_fn(self): | |||
def test_run_in_background_with_blocking_fn(self) -> defer.Deferred: | |||
@defer.inlineCallbacks | |||
def blocking_function(): | |||
def blocking_function() -> Generator["defer.Deferred[object]", object, None]: | |||
yield Clock(reactor).sleep(0) | |||
return self._test_run_in_background(blocking_function) | |||
def test_run_in_background_with_non_blocking_fn(self): | |||
def test_run_in_background_with_non_blocking_fn(self) -> defer.Deferred: | |||
@defer.inlineCallbacks | |||
def nonblocking_function(): | |||
def nonblocking_function() -> Generator["defer.Deferred[object]", object, None]: | |||
with PreserveLoggingContext(): | |||
yield defer.succeed(None) | |||
return self._test_run_in_background(nonblocking_function) | |||
def test_run_in_background_with_chained_deferred(self): | |||
def test_run_in_background_with_chained_deferred(self) -> defer.Deferred: | |||
# a function which returns a deferred which looks like it has been | |||
# called, but is actually paused | |||
def testfunc(): | |||
def testfunc() -> defer.Deferred: | |||
return make_deferred_yieldable(_chained_deferred_function()) | |||
return self._test_run_in_background(testfunc) | |||
def test_run_in_background_with_coroutine(self): | |||
async def testfunc(): | |||
def test_run_in_background_with_coroutine(self) -> defer.Deferred: | |||
async def testfunc() -> None: | |||
self._check_test_key("one") | |||
d = Clock(reactor).sleep(0) | |||
self.assertIs(current_context(), SENTINEL_CONTEXT) | |||
@@ -111,18 +133,20 @@ class LoggingContextTestCase(unittest.TestCase): | |||
return self._test_run_in_background(testfunc) | |||
def test_run_in_background_with_nonblocking_coroutine(self): | |||
async def testfunc(): | |||
def test_run_in_background_with_nonblocking_coroutine(self) -> defer.Deferred: | |||
async def testfunc() -> None: | |||
self._check_test_key("one") | |||
return self._test_run_in_background(testfunc) | |||
@defer.inlineCallbacks | |||
def test_make_deferred_yieldable(self): | |||
def test_make_deferred_yieldable( | |||
self, | |||
) -> Generator["defer.Deferred[object]", object, None]: | |||
# a function which returns an incomplete deferred, but doesn't follow | |||
# the synapse rules. | |||
def blocking_function(): | |||
d = defer.Deferred() | |||
def blocking_function() -> defer.Deferred: | |||
d: defer.Deferred = defer.Deferred() | |||
reactor.callLater(0, d.callback, None) | |||
return d | |||
@@ -139,7 +163,9 @@ class LoggingContextTestCase(unittest.TestCase): | |||
self._check_test_key("one") | |||
@defer.inlineCallbacks | |||
def test_make_deferred_yieldable_with_chained_deferreds(self): | |||
def test_make_deferred_yieldable_with_chained_deferreds( | |||
self, | |||
) -> Generator["defer.Deferred[object]", object, None]: | |||
sentinel_context = current_context() | |||
with LoggingContext("one"): | |||
@@ -152,7 +178,7 @@ class LoggingContextTestCase(unittest.TestCase): | |||
# now it should be restored | |||
self._check_test_key("one") | |||
def test_nested_logging_context(self): | |||
def test_nested_logging_context(self) -> None: | |||
with LoggingContext("foo"): | |||
nested_context = nested_logging_context(suffix="bar") | |||
self.assertEqual(nested_context.name, "foo-bar") | |||
@@ -161,11 +187,11 @@ class LoggingContextTestCase(unittest.TestCase): | |||
# a function which returns a deferred which has been "called", but | |||
# which had a function which returned another incomplete deferred on | |||
# its callback list, so won't yet call any other new callbacks. | |||
def _chained_deferred_function(): | |||
def _chained_deferred_function() -> defer.Deferred: | |||
d = defer.succeed(None) | |||
def cb(res): | |||
d2 = defer.Deferred() | |||
def cb(res: object) -> defer.Deferred: | |||
d2: defer.Deferred = defer.Deferred() | |||
reactor.callLater(0, d2.callback, res) | |||
return d2 | |||
@@ -23,7 +23,7 @@ class TestException(Exception): | |||
class LogFormatterTestCase(unittest.TestCase): | |||
def test_formatter(self): | |||
def test_formatter(self) -> None: | |||
formatter = LogFormatter() | |||
try: | |||
@@ -13,10 +13,11 @@ | |||
# limitations under the License. | |||
from typing import List | |||
from typing import List, Tuple | |||
from unittest.mock import Mock, patch | |||
from synapse.metrics.jemalloc import JemallocStats | |||
from synapse.types import JsonDict | |||
from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries | |||
from synapse.util.caches.treecache import TreeCache | |||
@@ -25,14 +26,14 @@ from tests.unittest import override_config | |||
class LruCacheTestCase(unittest.HomeserverTestCase): | |||
def test_get_set(self): | |||
cache = LruCache(1) | |||
def test_get_set(self) -> None: | |||
cache: LruCache[str, str] = LruCache(1) | |||
cache["key"] = "value" | |||
self.assertEqual(cache.get("key"), "value") | |||
self.assertEqual(cache["key"], "value") | |||
def test_eviction(self): | |||
cache = LruCache(2) | |||
def test_eviction(self) -> None: | |||
cache: LruCache[int, int] = LruCache(2) | |||
cache[1] = 1 | |||
cache[2] = 2 | |||
@@ -45,8 +46,8 @@ class LruCacheTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(cache.get(2), 2) | |||
self.assertEqual(cache.get(3), 3) | |||
def test_setdefault(self): | |||
cache = LruCache(1) | |||
def test_setdefault(self) -> None: | |||
cache: LruCache[str, int] = LruCache(1) | |||
self.assertEqual(cache.setdefault("key", 1), 1) | |||
self.assertEqual(cache.get("key"), 1) | |||
self.assertEqual(cache.setdefault("key", 2), 1) | |||
@@ -54,14 +55,15 @@ class LruCacheTestCase(unittest.HomeserverTestCase): | |||
cache["key"] = 2 # Make sure overriding works. | |||
self.assertEqual(cache.get("key"), 2) | |||
def test_pop(self): | |||
cache = LruCache(1) | |||
def test_pop(self) -> None: | |||
cache: LruCache[str, int] = LruCache(1) | |||
cache["key"] = 1 | |||
self.assertEqual(cache.pop("key"), 1) | |||
self.assertEqual(cache.pop("key"), None) | |||
def test_del_multi(self): | |||
cache = LruCache(4, cache_type=TreeCache) | |||
def test_del_multi(self) -> None: | |||
# The type here isn't quite correct as they don't handle TreeCache well. | |||
cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache) | |||
cache[("animal", "cat")] = "mew" | |||
cache[("animal", "dog")] = "woof" | |||
cache[("vehicles", "car")] = "vroom" | |||
@@ -71,7 +73,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(cache.get(("animal", "cat")), "mew") | |||
self.assertEqual(cache.get(("vehicles", "car")), "vroom") | |||
cache.del_multi(("animal",)) | |||
cache.del_multi(("animal",)) # type: ignore[arg-type] | |||
self.assertEqual(len(cache), 2) | |||
self.assertEqual(cache.get(("animal", "cat")), None) | |||
self.assertEqual(cache.get(("animal", "dog")), None) | |||
@@ -79,22 +81,22 @@ class LruCacheTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(cache.get(("vehicles", "train")), "chuff") | |||
# Man from del_multi say "Yes". | |||
def test_clear(self): | |||
cache = LruCache(1) | |||
def test_clear(self) -> None: | |||
cache: LruCache[str, int] = LruCache(1) | |||
cache["key"] = 1 | |||
cache.clear() | |||
self.assertEqual(len(cache), 0) | |||
@override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) | |||
def test_special_size(self): | |||
cache = LruCache(10, "mycache") | |||
def test_special_size(self) -> None: | |||
cache: LruCache = LruCache(10, "mycache") | |||
self.assertEqual(cache.max_size, 100) | |||
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): | |||
def test_get(self): | |||
def test_get(self) -> None: | |||
m = Mock() | |||
cache = LruCache(1) | |||
cache: LruCache[str, str] = LruCache(1) | |||
cache.set("key", "value") | |||
self.assertFalse(m.called) | |||
@@ -111,9 +113,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): | |||
cache.set("key", "value") | |||
self.assertEqual(m.call_count, 1) | |||
def test_multi_get(self): | |||
def test_multi_get(self) -> None: | |||
m = Mock() | |||
cache = LruCache(1) | |||
cache: LruCache[str, str] = LruCache(1) | |||
cache.set("key", "value") | |||
self.assertFalse(m.called) | |||
@@ -130,9 +132,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): | |||
cache.set("key", "value") | |||
self.assertEqual(m.call_count, 1) | |||
def test_set(self): | |||
def test_set(self) -> None: | |||
m = Mock() | |||
cache = LruCache(1) | |||
cache: LruCache[str, str] = LruCache(1) | |||
cache.set("key", "value", callbacks=[m]) | |||
self.assertFalse(m.called) | |||
@@ -146,9 +148,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): | |||
cache.set("key", "value") | |||
self.assertEqual(m.call_count, 1) | |||
def test_pop(self): | |||
def test_pop(self) -> None: | |||
m = Mock() | |||
cache = LruCache(1) | |||
cache: LruCache[str, str] = LruCache(1) | |||
cache.set("key", "value", callbacks=[m]) | |||
self.assertFalse(m.called) | |||
@@ -162,12 +164,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): | |||
cache.pop("key") | |||
self.assertEqual(m.call_count, 1) | |||
def test_del_multi(self): | |||
def test_del_multi(self) -> None: | |||
m1 = Mock() | |||
m2 = Mock() | |||
m3 = Mock() | |||
m4 = Mock() | |||
cache = LruCache(4, cache_type=TreeCache) | |||
# The type here isn't quite correct as they don't handle TreeCache well. | |||
cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache) | |||
cache.set(("a", "1"), "value", callbacks=[m1]) | |||
cache.set(("a", "2"), "value", callbacks=[m2]) | |||
@@ -179,17 +182,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(m3.call_count, 0) | |||
self.assertEqual(m4.call_count, 0) | |||
cache.del_multi(("a",)) | |||
cache.del_multi(("a",)) # type: ignore[arg-type] | |||
self.assertEqual(m1.call_count, 1) | |||
self.assertEqual(m2.call_count, 1) | |||
self.assertEqual(m3.call_count, 0) | |||
self.assertEqual(m4.call_count, 0) | |||
def test_clear(self): | |||
def test_clear(self) -> None: | |||
m1 = Mock() | |||
m2 = Mock() | |||
cache = LruCache(5) | |||
cache: LruCache[str, str] = LruCache(5) | |||
cache.set("key1", "value", callbacks=[m1]) | |||
cache.set("key2", "value", callbacks=[m2]) | |||
@@ -202,11 +205,11 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(m1.call_count, 1) | |||
self.assertEqual(m2.call_count, 1) | |||
def test_eviction(self): | |||
def test_eviction(self) -> None: | |||
m1 = Mock(name="m1") | |||
m2 = Mock(name="m2") | |||
m3 = Mock(name="m3") | |||
cache = LruCache(2) | |||
cache: LruCache[str, str] = LruCache(2) | |||
cache.set("key1", "value", callbacks=[m1]) | |||
cache.set("key2", "value", callbacks=[m2]) | |||
@@ -241,8 +244,8 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): | |||
class LruCacheSizedTestCase(unittest.HomeserverTestCase): | |||
def test_evict(self): | |||
cache = LruCache(5, size_callback=len) | |||
def test_evict(self) -> None: | |||
cache: LruCache[str, List[int]] = LruCache(5, size_callback=len) | |||
cache["key1"] = [0] | |||
cache["key2"] = [1, 2] | |||
cache["key3"] = [3] | |||
@@ -269,6 +272,7 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): | |||
cache["key1"] = [] | |||
self.assertEqual(len(cache), 0) | |||
assert isinstance(cache.cache, dict) | |||
cache.cache["key1"].drop_from_cache() | |||
self.assertIsNone( | |||
cache.pop("key1"), "Cache entry should have been evicted but wasn't" | |||
@@ -278,17 +282,17 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): | |||
class TimeEvictionTestCase(unittest.HomeserverTestCase): | |||
"""Test that time based eviction works correctly.""" | |||
def default_config(self): | |||
def default_config(self) -> JsonDict: | |||
config = super().default_config() | |||
config.setdefault("caches", {})["expiry_time"] = "30m" | |||
return config | |||
def test_evict(self): | |||
def test_evict(self) -> None: | |||
setup_expire_lru_cache_entries(self.hs) | |||
cache = LruCache(5, clock=self.hs.get_clock()) | |||
cache: LruCache[str, int] = LruCache(5, clock=self.hs.get_clock()) | |||
# Check that we evict entries we haven't accessed for 30 minutes. | |||
cache["key1"] = 1 | |||
@@ -332,7 +336,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase): | |||
} | |||
) | |||
@patch("synapse.util.caches.lrucache.get_jemalloc_stats") | |||
def test_evict_memory(self, jemalloc_interface) -> None: | |||
def test_evict_memory(self, jemalloc_interface: Mock) -> None: | |||
mock_jemalloc_class = Mock(spec=JemallocStats) | |||
jemalloc_interface.return_value = mock_jemalloc_class | |||
@@ -340,7 +344,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase): | |||
mock_jemalloc_class.get_stat.return_value = 924288000 | |||
setup_expire_lru_cache_entries(self.hs) | |||
cache = LruCache(4, clock=self.hs.get_clock()) | |||
cache: LruCache[str, int] = LruCache(4, clock=self.hs.get_clock()) | |||
cache["key1"] = 1 | |||
cache["key2"] = 2 | |||
@@ -21,14 +21,14 @@ from tests.unittest import TestCase | |||
class MacaroonGeneratorTestCase(TestCase): | |||
def setUp(self): | |||
def setUp(self) -> None: | |||
self.reactor, hs_clock = get_clock() | |||
self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret") | |||
self.other_macaroon_generator = MacaroonGenerator( | |||
hs_clock, "tesths", b"anothersecretkey" | |||
) | |||
def test_guest_access_token(self): | |||
def test_guest_access_token(self) -> None: | |||
"""Test the generation and verification of guest access tokens""" | |||
token = self.macaroon_generator.generate_guest_access_token("@user:tesths") | |||
user_id = self.macaroon_generator.verify_guest_token(token) | |||
@@ -47,7 +47,7 @@ class MacaroonGeneratorTestCase(TestCase): | |||
with self.assertRaises(MacaroonVerificationFailedException): | |||
self.macaroon_generator.verify_guest_token(token) | |||
def test_delete_pusher_token(self): | |||
def test_delete_pusher_token(self) -> None: | |||
"""Test the generation and verification of delete_pusher tokens""" | |||
token = self.macaroon_generator.generate_delete_pusher_token( | |||
"@user:tesths", "m.mail", "john@example.com" | |||
@@ -84,7 +84,7 @@ class MacaroonGeneratorTestCase(TestCase): | |||
) | |||
self.assertEqual(user_id, "@user:tesths") | |||
def test_oidc_session_token(self): | |||
def test_oidc_session_token(self) -> None: | |||
"""Test the generation and verification of OIDC session cookies""" | |||
state = "arandomstate" | |||
session_data = OidcSessionData( | |||
@@ -13,16 +13,19 @@ | |||
# limitations under the License. | |||
from typing import Optional | |||
from twisted.internet.defer import Deferred | |||
from synapse.config.homeserver import HomeServerConfig | |||
from synapse.config.ratelimiting import FederationRatelimitSettings | |||
from synapse.util.ratelimitutils import FederationRateLimiter | |||
from tests.server import get_clock | |||
from tests.server import ThreadedMemoryReactorClock, get_clock | |||
from tests.unittest import TestCase | |||
from tests.utils import default_config | |||
class FederationRateLimiterTestCase(TestCase): | |||
def test_ratelimit(self): | |||
def test_ratelimit(self) -> None: | |||
"""A simple test with the default values""" | |||
reactor, clock = get_clock() | |||
rc_config = build_rc_config() | |||
@@ -32,7 +35,7 @@ class FederationRateLimiterTestCase(TestCase): | |||
# shouldn't block | |||
self.successResultOf(d1) | |||
def test_concurrent_limit(self): | |||
def test_concurrent_limit(self) -> None: | |||
"""Test what happens when we hit the concurrent limit""" | |||
reactor, clock = get_clock() | |||
rc_config = build_rc_config({"rc_federation": {"concurrent": 2}}) | |||
@@ -56,7 +59,7 @@ class FederationRateLimiterTestCase(TestCase): | |||
cm2.__exit__(None, None, None) | |||
self.successResultOf(d3) | |||
def test_sleep_limit(self): | |||
def test_sleep_limit(self) -> None: | |||
"""Test what happens when we hit the sleep limit""" | |||
reactor, clock = get_clock() | |||
rc_config = build_rc_config( | |||
@@ -79,7 +82,7 @@ class FederationRateLimiterTestCase(TestCase): | |||
self.assertAlmostEqual(sleep_time, 500, places=3) | |||
def _await_resolution(reactor, d): | |||
def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float: | |||
"""advance the clock until the deferred completes. | |||
Returns the number of milliseconds it took to complete. | |||
@@ -90,7 +93,7 @@ def _await_resolution(reactor, d): | |||
return (reactor.seconds() - start_time) * 1000 | |||
def build_rc_config(settings: Optional[dict] = None): | |||
def build_rc_config(settings: Optional[dict] = None) -> FederationRatelimitSettings: | |||
config_dict = default_config("test") | |||
config_dict.update(settings or {}) | |||
config = HomeServerConfig() | |||
@@ -22,7 +22,7 @@ from tests.unittest import HomeserverTestCase | |||
class RetryLimiterTestCase(HomeserverTestCase): | |||
def test_new_destination(self): | |||
def test_new_destination(self) -> None: | |||
"""A happy-path case with a new destination and a successful operation""" | |||
store = self.hs.get_datastores().main | |||
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) | |||
@@ -36,7 +36,7 @@ class RetryLimiterTestCase(HomeserverTestCase): | |||
new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) | |||
self.assertIsNone(new_timings) | |||
def test_limiter(self): | |||
def test_limiter(self) -> None: | |||
"""General test case which walks through the process of a failing request""" | |||
store = self.hs.get_datastores().main | |||
@@ -49,7 +49,7 @@ class ReadWriteLockTestCase(unittest.TestCase): | |||
acquired_d: "Deferred[None]" = Deferred() | |||
unblock_d: "Deferred[None]" = Deferred() | |||
async def reader_or_writer(): | |||
async def reader_or_writer() -> str: | |||
async with read_or_write(key): | |||
acquired_d.callback(None) | |||
await unblock_d | |||
@@ -134,7 +134,7 @@ class ReadWriteLockTestCase(unittest.TestCase): | |||
d.called, msg="deferred %d was unexpectedly resolved" % (i + n) | |||
) | |||
def test_rwlock(self): | |||
def test_rwlock(self) -> None: | |||
rwlock = ReadWriteLock() | |||
key = "key" | |||
@@ -197,7 +197,7 @@ class ReadWriteLockTestCase(unittest.TestCase): | |||
_, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader") | |||
self.assertTrue(acquired_d.called) | |||
def test_lock_handoff_to_nonblocking_writer(self): | |||
def test_lock_handoff_to_nonblocking_writer(self) -> None: | |||
"""Test a writer handing the lock to another writer that completes instantly.""" | |||
rwlock = ReadWriteLock() | |||
key = "key" | |||
@@ -216,7 +216,7 @@ class ReadWriteLockTestCase(unittest.TestCase): | |||
d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed") | |||
self.assertTrue(d3.called) | |||
def test_cancellation_while_holding_read_lock(self): | |||
def test_cancellation_while_holding_read_lock(self) -> None: | |||
"""Test cancellation while holding a read lock. | |||
A waiting writer should be given the lock when the reader holding the lock is | |||
@@ -242,7 +242,7 @@ class ReadWriteLockTestCase(unittest.TestCase): | |||
) | |||
self.assertEqual("write completed", self.successResultOf(writer_d)) | |||
def test_cancellation_while_holding_write_lock(self): | |||
def test_cancellation_while_holding_write_lock(self) -> None: | |||
"""Test cancellation while holding a write lock. | |||
A waiting reader should be given the lock when the writer holding the lock is | |||
@@ -268,7 +268,7 @@ class ReadWriteLockTestCase(unittest.TestCase): | |||
) | |||
self.assertEqual("read completed", self.successResultOf(reader_d)) | |||
def test_cancellation_while_waiting_for_read_lock(self): | |||
def test_cancellation_while_waiting_for_read_lock(self) -> None: | |||
"""Test cancellation while waiting for a read lock. | |||
Tests that cancelling a waiting reader: | |||
@@ -319,7 +319,7 @@ class ReadWriteLockTestCase(unittest.TestCase): | |||
) | |||
self.assertEqual("write 2 completed", self.successResultOf(writer2_d)) | |||
def test_cancellation_while_waiting_for_write_lock(self): | |||
def test_cancellation_while_waiting_for_write_lock(self) -> None: | |||
"""Test cancellation while waiting for a write lock. | |||
Tests that cancelling a waiting writer: | |||
@@ -8,7 +8,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): | |||
Tests for StreamChangeCache. | |||
""" | |||
def test_prefilled_cache(self): | |||
def test_prefilled_cache(self) -> None: | |||
""" | |||
Providing a prefilled cache to StreamChangeCache will result in a cache | |||
with the prefilled-cache entered in. | |||
@@ -16,7 +16,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): | |||
cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2}) | |||
self.assertTrue(cache.has_entity_changed("user@foo.com", 1)) | |||
def test_has_entity_changed(self): | |||
def test_has_entity_changed(self) -> None: | |||
""" | |||
StreamChangeCache.entity_has_changed will mark entities as changed, and | |||
has_entity_changed will observe the changed entities. | |||
@@ -52,7 +52,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): | |||
self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) | |||
self.assertTrue(cache.has_entity_changed("not@here.website", 0)) | |||
def test_entity_has_changed_pops_off_start(self): | |||
def test_entity_has_changed_pops_off_start(self) -> None: | |||
""" | |||
StreamChangeCache.entity_has_changed will respect the max size and | |||
purge the oldest items upon reaching that max size. | |||
@@ -86,7 +86,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): | |||
) | |||
self.assertIsNone(cache.get_all_entities_changed(1)) | |||
def test_get_all_entities_changed(self): | |||
def test_get_all_entities_changed(self) -> None: | |||
""" | |||
StreamChangeCache.get_all_entities_changed will return all changed | |||
entities since the given position. If the position is before the start | |||
@@ -142,7 +142,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): | |||
r = cache.get_all_entities_changed(3) | |||
self.assertTrue(r == ok1 or r == ok2) | |||
def test_has_any_entity_changed(self): | |||
def test_has_any_entity_changed(self) -> None: | |||
""" | |||
StreamChangeCache.has_any_entity_changed will return True if any | |||
entities have been changed since the provided stream position, and | |||
@@ -168,7 +168,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): | |||
self.assertFalse(cache.has_any_entity_changed(2)) | |||
self.assertFalse(cache.has_any_entity_changed(3)) | |||
def test_get_entities_changed(self): | |||
def test_get_entities_changed(self) -> None: | |||
""" | |||
StreamChangeCache.get_entities_changed will return the entities in the | |||
given list that have changed since the provided stream ID. If the | |||
@@ -228,7 +228,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): | |||
{"bar@baz.net"}, | |||
) | |||
def test_max_pos(self): | |||
def test_max_pos(self) -> None: | |||
""" | |||
StreamChangeCache.get_max_pos_of_last_change will return the most | |||
recent point where the entity could have changed. If the entity is not | |||
@@ -19,7 +19,7 @@ from .. import unittest | |||
class StringUtilsTestCase(unittest.TestCase): | |||
def test_client_secret_regex(self): | |||
def test_client_secret_regex(self) -> None: | |||
"""Ensure that client_secret does not contain illegal characters""" | |||
good = [ | |||
"abcde12345", | |||
@@ -46,7 +46,7 @@ class StringUtilsTestCase(unittest.TestCase): | |||
with self.assertRaises(SynapseError): | |||
assert_valid_client_secret(client_secret) | |||
def test_base62_encode(self): | |||
def test_base62_encode(self) -> None: | |||
self.assertEqual("0", base62_encode(0)) | |||
self.assertEqual("10", base62_encode(62)) | |||
self.assertEqual("1c", base62_encode(100)) | |||
@@ -18,31 +18,31 @@ from tests.unittest import HomeserverTestCase | |||
class CanonicaliseEmailTests(HomeserverTestCase): | |||
def test_no_at(self): | |||
def test_no_at(self) -> None: | |||
with self.assertRaises(ValueError): | |||
canonicalise_email("address-without-at.bar") | |||
def test_two_at(self): | |||
def test_two_at(self) -> None: | |||
with self.assertRaises(ValueError): | |||
canonicalise_email("foo@foo@test.bar") | |||
def test_bad_format(self): | |||
def test_bad_format(self) -> None: | |||
with self.assertRaises(ValueError): | |||
canonicalise_email("user@bad.example.net@good.example.com") | |||
def test_valid_format(self): | |||
def test_valid_format(self) -> None: | |||
self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar") | |||
def test_domain_to_lower(self): | |||
def test_domain_to_lower(self) -> None: | |||
self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar") | |||
def test_domain_with_umlaut(self): | |||
def test_domain_with_umlaut(self) -> None: | |||
self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com") | |||
def test_address_casefold(self): | |||
def test_address_casefold(self) -> None: | |||
self.assertEqual( | |||
canonicalise_email("Strauß@Example.com"), "strauss@example.com" | |||
) | |||
def test_address_trim(self): | |||
def test_address_trim(self) -> None: | |||
self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar") |
@@ -19,7 +19,7 @@ from .. import unittest | |||
class TreeCacheTestCase(unittest.TestCase): | |||
def test_get_set_onelevel(self): | |||
def test_get_set_onelevel(self) -> None: | |||
cache = TreeCache() | |||
cache[("a",)] = "A" | |||
cache[("b",)] = "B" | |||
@@ -27,7 +27,7 @@ class TreeCacheTestCase(unittest.TestCase): | |||
self.assertEqual(cache.get(("b",)), "B") | |||
self.assertEqual(len(cache), 2) | |||
def test_pop_onelevel(self): | |||
def test_pop_onelevel(self) -> None: | |||
cache = TreeCache() | |||
cache[("a",)] = "A" | |||
cache[("b",)] = "B" | |||
@@ -36,7 +36,7 @@ class TreeCacheTestCase(unittest.TestCase): | |||
self.assertEqual(cache.get(("b",)), "B") | |||
self.assertEqual(len(cache), 1) | |||
def test_get_set_twolevel(self): | |||
def test_get_set_twolevel(self) -> None: | |||
cache = TreeCache() | |||
cache[("a", "a")] = "AA" | |||
cache[("a", "b")] = "AB" | |||
@@ -46,7 +46,7 @@ class TreeCacheTestCase(unittest.TestCase): | |||
self.assertEqual(cache.get(("b", "a")), "BA") | |||
self.assertEqual(len(cache), 3) | |||
def test_pop_twolevel(self): | |||
def test_pop_twolevel(self) -> None: | |||
cache = TreeCache() | |||
cache[("a", "a")] = "AA" | |||
cache[("a", "b")] = "AB" | |||
@@ -58,7 +58,7 @@ class TreeCacheTestCase(unittest.TestCase): | |||
self.assertEqual(cache.pop(("b", "a")), None) | |||
self.assertEqual(len(cache), 1) | |||
def test_pop_mixedlevel(self): | |||
def test_pop_mixedlevel(self) -> None: | |||
cache = TreeCache() | |||
cache[("a", "a")] = "AA" | |||
cache[("a", "b")] = "AB" | |||
@@ -72,14 +72,14 @@ class TreeCacheTestCase(unittest.TestCase): | |||
self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) | |||
def test_clear(self): | |||
def test_clear(self) -> None: | |||
cache = TreeCache() | |||
cache[("a",)] = "A" | |||
cache[("b",)] = "B" | |||
cache.clear() | |||
self.assertEqual(len(cache), 0) | |||
def test_contains(self): | |||
def test_contains(self) -> None: | |||
cache = TreeCache() | |||
cache[("a",)] = "A" | |||
self.assertTrue(("a",) in cache) | |||
@@ -18,8 +18,8 @@ from .. import unittest | |||
class WheelTimerTestCase(unittest.TestCase): | |||
def test_single_insert_fetch(self): | |||
wheel = WheelTimer(bucket_size=5) | |||
def test_single_insert_fetch(self) -> None: | |||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5) | |||
obj = object() | |||
wheel.insert(100, obj, 150) | |||
@@ -32,8 +32,8 @@ class WheelTimerTestCase(unittest.TestCase): | |||
self.assertListEqual(wheel.fetch(156), [obj]) | |||
self.assertListEqual(wheel.fetch(170), []) | |||
def test_multi_insert(self): | |||
wheel = WheelTimer(bucket_size=5) | |||
def test_multi_insert(self) -> None: | |||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5) | |||
obj1 = object() | |||
obj2 = object() | |||
@@ -50,15 +50,15 @@ class WheelTimerTestCase(unittest.TestCase): | |||
self.assertListEqual(wheel.fetch(200), [obj3]) | |||
self.assertListEqual(wheel.fetch(210), []) | |||
def test_insert_past(self): | |||
wheel = WheelTimer(bucket_size=5) | |||
def test_insert_past(self) -> None: | |||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5) | |||
obj = object() | |||
wheel.insert(100, obj, 50) | |||
self.assertListEqual(wheel.fetch(120), [obj]) | |||
def test_insert_past_multi(self): | |||
wheel = WheelTimer(bucket_size=5) | |||
def test_insert_past_multi(self) -> None: | |||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5) | |||
obj1 = object() | |||
obj2 = object() | |||