@@ -0,0 +1 @@ | |||
Add missing type hints. |
@@ -59,11 +59,6 @@ exclude = (?x) | |||
|tests/server_notices/test_resource_limits_server_notices.py | |||
|tests/test_state.py | |||
|tests/test_terms_auth.py | |||
|tests/util/caches/test_cached_call.py | |||
|tests/util/caches/test_deferred_cache.py | |||
|tests/util/caches/test_descriptors.py | |||
|tests/util/caches/test_response_cache.py | |||
|tests/util/caches/test_ttlcache.py | |||
|tests/util/test_async_helpers.py | |||
|tests/util/test_batching_queue.py | |||
|tests/util/test_dict_cache.py | |||
@@ -133,6 +128,12 @@ disallow_untyped_defs = True | |||
[mypy-tests.federation.transport.test_client] | |||
disallow_untyped_defs = True | |||
[mypy-tests.util.caches.*] | |||
disallow_untyped_defs = True | |||
[mypy-tests.util.caches.test_descriptors] | |||
disallow_untyped_defs = False | |||
[mypy-tests.utils] | |||
disallow_untyped_defs = True | |||
@@ -11,6 +11,7 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import NoReturn | |||
from unittest.mock import Mock | |||
from twisted.internet import defer | |||
@@ -23,14 +24,14 @@ from tests.unittest import TestCase | |||
class CachedCallTestCase(TestCase): | |||
def test_get(self): | |||
def test_get(self) -> None: | |||
""" | |||
Happy-path test case: makes a couple of calls and makes sure they behave | |||
correctly | |||
""" | |||
d = Deferred() | |||
d: "Deferred[int]" = Deferred() | |||
async def f(): | |||
async def f() -> int: | |||
return await d | |||
slow_call = Mock(side_effect=f) | |||
@@ -43,7 +44,7 @@ class CachedCallTestCase(TestCase): | |||
# now fire off a couple of calls | |||
completed_results = [] | |||
async def r(): | |||
async def r() -> None: | |||
res = await cached_call.get() | |||
completed_results.append(res) | |||
@@ -69,12 +70,12 @@ class CachedCallTestCase(TestCase): | |||
self.assertEqual(r3, 123) | |||
slow_call.assert_not_called() | |||
def test_fast_call(self): | |||
def test_fast_call(self) -> None: | |||
""" | |||
Test the behaviour when the underlying function completes immediately | |||
""" | |||
async def f(): | |||
async def f() -> int: | |||
return 12 | |||
fast_call = Mock(side_effect=f) | |||
@@ -92,12 +93,12 @@ class CachedCallTestCase(TestCase): | |||
class RetryOnExceptionCachedCallTestCase(TestCase): | |||
def test_get(self): | |||
def test_get(self) -> None: | |||
# set up the RetryOnExceptionCachedCall around a function which will fail | |||
# (after a while) | |||
d = Deferred() | |||
d: "Deferred[int]" = Deferred() | |||
async def f1(): | |||
async def f1() -> NoReturn: | |||
await d | |||
raise ValueError("moo") | |||
@@ -110,7 +111,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase): | |||
# now fire off a couple of calls | |||
completed_results = [] | |||
async def r(): | |||
async def r() -> None: | |||
try: | |||
await cached_call.get() | |||
except Exception as e1: | |||
@@ -137,7 +138,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase): | |||
# to the getter | |||
d = Deferred() | |||
async def f2(): | |||
async def f2() -> int: | |||
return await d | |||
slow_call.reset_mock() | |||
@@ -13,6 +13,7 @@ | |||
# limitations under the License. | |||
from functools import partial | |||
from typing import List, Tuple | |||
from twisted.internet import defer | |||
@@ -22,20 +23,20 @@ from tests.unittest import TestCase | |||
class DeferredCacheTestCase(TestCase): | |||
def test_empty(self): | |||
cache = DeferredCache("test") | |||
def test_empty(self) -> None: | |||
cache: DeferredCache[str, int] = DeferredCache("test") | |||
with self.assertRaises(KeyError): | |||
cache.get("foo") | |||
def test_hit(self): | |||
cache = DeferredCache("test") | |||
def test_hit(self) -> None: | |||
cache: DeferredCache[str, int] = DeferredCache("test") | |||
cache.prefill("foo", 123) | |||
self.assertEqual(self.successResultOf(cache.get("foo")), 123) | |||
def test_hit_deferred(self): | |||
cache = DeferredCache("test") | |||
origin_d = defer.Deferred() | |||
def test_hit_deferred(self) -> None: | |||
cache: DeferredCache[str, int] = DeferredCache("test") | |||
origin_d: "defer.Deferred[int]" = defer.Deferred() | |||
set_d = cache.set("k1", origin_d) | |||
# get should return an incomplete deferred | |||
@@ -43,7 +44,7 @@ class DeferredCacheTestCase(TestCase): | |||
self.assertFalse(get_d.called) | |||
# add a callback that will make sure that the set_d gets called before the get_d | |||
def check1(r): | |||
def check1(r: str) -> str: | |||
self.assertTrue(set_d.called) | |||
return r | |||
@@ -55,16 +56,16 @@ class DeferredCacheTestCase(TestCase): | |||
self.assertEqual(self.successResultOf(set_d), 99) | |||
self.assertEqual(self.successResultOf(get_d), 99) | |||
def test_callbacks(self): | |||
def test_callbacks(self) -> None: | |||
"""Invalidation callbacks are called at the right time""" | |||
cache = DeferredCache("test") | |||
cache: DeferredCache[str, int] = DeferredCache("test") | |||
callbacks = set() | |||
# start with an entry, with a callback | |||
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) | |||
# now replace that entry with a pending result | |||
origin_d = defer.Deferred() | |||
origin_d: "defer.Deferred[int]" = defer.Deferred() | |||
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) | |||
# ... and also make a get request | |||
@@ -89,15 +90,15 @@ class DeferredCacheTestCase(TestCase): | |||
cache.prefill("k1", 30) | |||
self.assertEqual(callbacks, {"set", "get"}) | |||
def test_set_fail(self): | |||
cache = DeferredCache("test") | |||
def test_set_fail(self) -> None: | |||
cache: DeferredCache[str, int] = DeferredCache("test") | |||
callbacks = set() | |||
# start with an entry, with a callback | |||
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) | |||
# now replace that entry with a pending result | |||
origin_d = defer.Deferred() | |||
origin_d: defer.Deferred = defer.Deferred() | |||
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) | |||
# ... and also make a get request | |||
@@ -126,9 +127,9 @@ class DeferredCacheTestCase(TestCase): | |||
cache.prefill("k1", 30) | |||
self.assertEqual(callbacks, {"prefill", "get2"}) | |||
def test_get_immediate(self): | |||
cache = DeferredCache("test") | |||
d1 = defer.Deferred() | |||
def test_get_immediate(self) -> None: | |||
cache: DeferredCache[str, int] = DeferredCache("test") | |||
d1: "defer.Deferred[int]" = defer.Deferred() | |||
cache.set("key1", d1) | |||
# get_immediate should return default | |||
@@ -142,27 +143,27 @@ class DeferredCacheTestCase(TestCase): | |||
v = cache.get_immediate("key1", 1) | |||
self.assertEqual(v, 2) | |||
def test_invalidate(self): | |||
cache = DeferredCache("test") | |||
def test_invalidate(self) -> None: | |||
cache: DeferredCache[Tuple[str], int] = DeferredCache("test") | |||
cache.prefill(("foo",), 123) | |||
cache.invalidate(("foo",)) | |||
with self.assertRaises(KeyError): | |||
cache.get(("foo",)) | |||
def test_invalidate_all(self): | |||
cache = DeferredCache("testcache") | |||
def test_invalidate_all(self) -> None: | |||
cache: DeferredCache[str, str] = DeferredCache("testcache") | |||
callback_record = [False, False] | |||
def record_callback(idx): | |||
def record_callback(idx: int) -> None: | |||
callback_record[idx] = True | |||
# add a couple of pending entries | |||
d1 = defer.Deferred() | |||
d1: "defer.Deferred[str]" = defer.Deferred() | |||
cache.set("key1", d1, partial(record_callback, 0)) | |||
d2 = defer.Deferred() | |||
d2: "defer.Deferred[str]" = defer.Deferred() | |||
cache.set("key2", d2, partial(record_callback, 1)) | |||
# lookup should return pending deferreds | |||
@@ -193,8 +194,8 @@ class DeferredCacheTestCase(TestCase): | |||
with self.assertRaises(KeyError): | |||
cache.get("key1", None) | |||
def test_eviction(self): | |||
cache = DeferredCache( | |||
def test_eviction(self) -> None: | |||
cache: DeferredCache[int, str] = DeferredCache( | |||
"test", max_entries=2, apply_cache_factor_from_config=False | |||
) | |||
@@ -208,8 +209,8 @@ class DeferredCacheTestCase(TestCase): | |||
cache.get(2) | |||
cache.get(3) | |||
def test_eviction_lru(self): | |||
cache = DeferredCache( | |||
def test_eviction_lru(self) -> None: | |||
cache: DeferredCache[int, str] = DeferredCache( | |||
"test", max_entries=2, apply_cache_factor_from_config=False | |||
) | |||
@@ -227,8 +228,8 @@ class DeferredCacheTestCase(TestCase): | |||
cache.get(1) | |||
cache.get(3) | |||
def test_eviction_iterable(self): | |||
cache = DeferredCache( | |||
def test_eviction_iterable(self) -> None: | |||
cache: DeferredCache[int, List[str]] = DeferredCache( | |||
"test", | |||
max_entries=3, | |||
apply_cache_factor_from_config=False, | |||
@@ -13,11 +13,12 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import Iterable, Set, Tuple | |||
from typing import Iterable, Set, Tuple, cast | |||
from unittest import mock | |||
from twisted.internet import defer, reactor | |||
from twisted.internet.defer import CancelledError, Deferred | |||
from twisted.internet.interfaces import IReactorTime | |||
from synapse.api.errors import SynapseError | |||
from synapse.logging.context import ( | |||
@@ -37,8 +38,8 @@ logger = logging.getLogger(__name__) | |||
def run_on_reactor(): | |||
d = defer.Deferred() | |||
reactor.callLater(0, d.callback, 0) | |||
d: "Deferred[int]" = defer.Deferred() | |||
cast(IReactorTime, reactor).callLater(0, d.callback, 0) | |||
return make_deferred_yieldable(d) | |||
@@ -224,7 +225,8 @@ class DescriptorTestCase(unittest.TestCase): | |||
callbacks: Set[str] = set() | |||
# set off an asynchronous request | |||
obj.result = origin_d = defer.Deferred() | |||
origin_d: Deferred = defer.Deferred() | |||
obj.result = origin_d | |||
d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1")) | |||
self.assertFalse(d1.called) | |||
@@ -262,7 +264,7 @@ class DescriptorTestCase(unittest.TestCase): | |||
"""Check that logcontexts are set and restored correctly when | |||
using the cache.""" | |||
complete_lookup = defer.Deferred() | |||
complete_lookup: Deferred = defer.Deferred() | |||
class Cls: | |||
@descriptors.cached() | |||
@@ -772,10 +774,14 @@ class CachedListDescriptorTestCase(unittest.TestCase): | |||
@descriptors.cachedList(cached_method_name="fn", list_name="args1") | |||
async def list_fn(self, args1, arg2): | |||
assert current_context().name == "c1" | |||
context = current_context() | |||
assert isinstance(context, LoggingContext) | |||
assert context.name == "c1" | |||
# we want this to behave like an asynchronous function | |||
await run_on_reactor() | |||
assert current_context().name == "c1" | |||
context = current_context() | |||
assert isinstance(context, LoggingContext) | |||
assert context.name == "c1" | |||
return self.mock(args1, arg2) | |||
with LoggingContext("c1") as c1: | |||
@@ -834,7 +840,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): | |||
return self.mock(args1) | |||
obj = Cls() | |||
deferred_result = Deferred() | |||
deferred_result: "Deferred[dict]" = Deferred() | |||
obj.mock.return_value = deferred_result | |||
# start off several concurrent lookups of the same key | |||
@@ -35,7 +35,7 @@ class ResponseCacheTestCase(TestCase): | |||
(These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock) | |||
""" | |||
def setUp(self): | |||
def setUp(self) -> None: | |||
self.reactor, self.clock = get_clock() | |||
def with_cache(self, name: str, ms: int = 0) -> ResponseCache: | |||
@@ -49,7 +49,7 @@ class ResponseCacheTestCase(TestCase): | |||
await self.clock.sleep(1) | |||
return o | |||
def test_cache_hit(self): | |||
def test_cache_hit(self) -> None: | |||
cache = self.with_cache("keeping_cache", ms=9001) | |||
expected_result = "howdy" | |||
@@ -74,7 +74,7 @@ class ResponseCacheTestCase(TestCase): | |||
"cache should still have the result", | |||
) | |||
def test_cache_miss(self): | |||
def test_cache_miss(self) -> None: | |||
cache = self.with_cache("trashing_cache", ms=0) | |||
expected_result = "howdy" | |||
@@ -90,7 +90,7 @@ class ResponseCacheTestCase(TestCase): | |||
) | |||
self.assertCountEqual([], cache.keys(), "cache should not have the result now") | |||
def test_cache_expire(self): | |||
def test_cache_expire(self) -> None: | |||
cache = self.with_cache("short_cache", ms=1000) | |||
expected_result = "howdy" | |||
@@ -115,7 +115,7 @@ class ResponseCacheTestCase(TestCase): | |||
self.reactor.pump((2,)) | |||
self.assertCountEqual([], cache.keys(), "cache should not have the result now") | |||
def test_cache_wait_hit(self): | |||
def test_cache_wait_hit(self) -> None: | |||
cache = self.with_cache("neutral_cache") | |||
expected_result = "howdy" | |||
@@ -131,7 +131,7 @@ class ResponseCacheTestCase(TestCase): | |||
self.assertEqual(expected_result, self.successResultOf(wrap_d)) | |||
def test_cache_wait_expire(self): | |||
def test_cache_wait_expire(self) -> None: | |||
cache = self.with_cache("medium_cache", ms=3000) | |||
expected_result = "howdy" | |||
@@ -162,7 +162,7 @@ class ResponseCacheTestCase(TestCase): | |||
self.assertCountEqual([], cache.keys(), "cache should not have the result now") | |||
@parameterized.expand([(True,), (False,)]) | |||
def test_cache_context_nocache(self, should_cache: bool): | |||
def test_cache_context_nocache(self, should_cache: bool) -> None: | |||
"""If the callback clears the should_cache bit, the result should not be cached""" | |||
cache = self.with_cache("medium_cache", ms=3000) | |||
@@ -170,7 +170,7 @@ class ResponseCacheTestCase(TestCase): | |||
call_count = 0 | |||
async def non_caching(o: str, cache_context: ResponseCacheContext[int]): | |||
async def non_caching(o: str, cache_context: ResponseCacheContext[int]) -> str: | |||
nonlocal call_count | |||
call_count += 1 | |||
await self.clock.sleep(1) | |||
@@ -20,11 +20,11 @@ from tests import unittest | |||
class CacheTestCase(unittest.TestCase): | |||
def setUp(self): | |||
def setUp(self) -> None: | |||
self.mock_timer = Mock(side_effect=lambda: 100.0) | |||
self.cache = TTLCache("test_cache", self.mock_timer) | |||
self.cache: TTLCache[str, str] = TTLCache("test_cache", self.mock_timer) | |||
def test_get(self): | |||
def test_get(self) -> None: | |||
"""simple set/get tests""" | |||
self.cache.set("one", "1", 10) | |||
self.cache.set("two", "2", 20) | |||
@@ -59,7 +59,7 @@ class CacheTestCase(unittest.TestCase): | |||
self.assertEqual(self.cache._metrics.hits, 4) | |||
self.assertEqual(self.cache._metrics.misses, 5) | |||
def test_expiry(self): | |||
def test_expiry(self) -> None: | |||
self.cache.set("one", "1", 10) | |||
self.cache.set("two", "2", 20) | |||
self.cache.set("three", "3", 30) | |||