|
|
@@ -19,6 +19,7 @@ from twisted.internet.defer import Deferred, ensureDeferred |
|
|
|
from twisted.test.proto_helpers import MemoryReactor |
|
|
|
|
|
|
|
from synapse.api.constants import EventTypes |
|
|
|
from synapse.storage.databases.state.store import MAX_INFLIGHT_REQUESTS_PER_GROUP |
|
|
|
from synapse.storage.state import StateFilter |
|
|
|
from synapse.types import StateMap |
|
|
|
from synapse.util import Clock |
|
|
@@ -281,3 +282,71 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase): |
|
|
|
|
|
|
|
self.assertEqual(self.get_success(req1), FAKE_STATE) |
|
|
|
self.assertEqual(self.get_success(req2), FAKE_STATE) |
|
|
|
|
|
|
|
def test_inflight_requests_capped(self) -> None: |
|
|
|
""" |
|
|
|
Tests that the number of in-flight requests is capped to 5. |
|
|
|
|
|
|
|
- requests several pieces of state separately |
|
|
|
(5 to hit the limit, 1 to 'shunt out', another that comes after the |
|
|
|
group has been 'shunted out') |
|
|
|
- checks to see that the torrent of requests is shunted out by |
|
|
|
rewriting one of the filters as the 'all' state filter |
|
|
|
- requests after that one do not cause any additional queries |
|
|
|
""" |
|
|
|
# 5 at the time of writing. |
|
|
|
CAP_COUNT = MAX_INFLIGHT_REQUESTS_PER_GROUP |
|
|
|
|
|
|
|
reqs = [] |
|
|
|
|
|
|
|
# Request 7 different keys (1 to 7) of the `some.state` type. |
|
|
|
for req_id in range(CAP_COUNT + 2): |
|
|
|
reqs.append( |
|
|
|
ensureDeferred( |
|
|
|
self.state_datastore._get_state_for_group_using_inflight_cache( |
|
|
|
42, |
|
|
|
StateFilter.freeze( |
|
|
|
{"some.state": {str(req_id + 1)}}, include_others=False |
|
|
|
), |
|
|
|
) |
|
|
|
) |
|
|
|
) |
|
|
|
self.pump(by=0.1) |
|
|
|
|
|
|
|
# There should only be 6 calls to the database, not 7. |
|
|
|
self.assertEqual(len(self.get_state_group_calls), CAP_COUNT + 1) |
|
|
|
|
|
|
|
# Assert that the first 5 are exact requests for the individual pieces |
|
|
|
# wanted |
|
|
|
for req_id in range(CAP_COUNT): |
|
|
|
groups, sf, d = self.get_state_group_calls[req_id] |
|
|
|
self.assertEqual( |
|
|
|
sf, |
|
|
|
StateFilter.freeze( |
|
|
|
{"some.state": {str(req_id + 1)}}, include_others=False |
|
|
|
), |
|
|
|
) |
|
|
|
|
|
|
|
# The 6th request should be the 'all' state filter |
|
|
|
groups, sf, d = self.get_state_group_calls[CAP_COUNT] |
|
|
|
self.assertEqual(sf, StateFilter.all()) |
|
|
|
|
|
|
|
# Complete the queries and check which requests complete as a result |
|
|
|
for req_id in range(CAP_COUNT): |
|
|
|
# This request should not have been completed yet |
|
|
|
self.assertFalse(reqs[req_id].called) |
|
|
|
|
|
|
|
groups, sf, d = self.get_state_group_calls[req_id] |
|
|
|
self._complete_request_fake(groups, sf, d) |
|
|
|
|
|
|
|
# This should have only completed this one request |
|
|
|
self.assertTrue(reqs[req_id].called) |
|
|
|
|
|
|
|
# Now complete the final query; the last 2 requests should complete |
|
|
|
# as a result |
|
|
|
self.assertFalse(reqs[CAP_COUNT].called) |
|
|
|
self.assertFalse(reqs[CAP_COUNT + 1].called) |
|
|
|
groups, sf, d = self.get_state_group_calls[CAP_COUNT] |
|
|
|
self._complete_request_fake(groups, sf, d) |
|
|
|
self.assertTrue(reqs[CAP_COUNT].called) |
|
|
|
self.assertTrue(reqs[CAP_COUNT + 1].called) |