|
|
@@ -8,7 +8,10 @@ from tests import unittest |
|
|
|
class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
def test_allowed_via_can_do_action(self): |
|
|
|
limiter = Ratelimiter( |
|
|
|
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 |
|
|
|
store=self.hs.get_datastores().main, |
|
|
|
clock=self.clock, |
|
|
|
rate_hz=0.1, |
|
|
|
burst_count=1, |
|
|
|
) |
|
|
|
allowed, time_allowed = self.get_success_or_raise( |
|
|
|
limiter.can_do_action(None, key="test_id", _time_now_s=0) |
|
|
@@ -30,7 +33,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
|
|
|
|
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): |
|
|
|
appservice = ApplicationService( |
|
|
|
None, |
|
|
|
token="fake_token", |
|
|
|
id="foo", |
|
|
|
rate_limited=True, |
|
|
|
sender="@as:example.com", |
|
|
@@ -38,7 +41,10 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
as_requester = create_requester("@user:example.com", app_service=appservice) |
|
|
|
|
|
|
|
limiter = Ratelimiter( |
|
|
|
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 |
|
|
|
store=self.hs.get_datastores().main, |
|
|
|
clock=self.clock, |
|
|
|
rate_hz=0.1, |
|
|
|
burst_count=1, |
|
|
|
) |
|
|
|
allowed, time_allowed = self.get_success_or_raise( |
|
|
|
limiter.can_do_action(as_requester, _time_now_s=0) |
|
|
@@ -60,7 +66,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
|
|
|
|
def test_allowed_appservice_via_can_requester_do_action(self): |
|
|
|
appservice = ApplicationService( |
|
|
|
None, |
|
|
|
token="fake_token", |
|
|
|
id="foo", |
|
|
|
rate_limited=False, |
|
|
|
sender="@as:example.com", |
|
|
@@ -68,7 +74,10 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
as_requester = create_requester("@user:example.com", app_service=appservice) |
|
|
|
|
|
|
|
limiter = Ratelimiter( |
|
|
|
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 |
|
|
|
store=self.hs.get_datastores().main, |
|
|
|
clock=self.clock, |
|
|
|
rate_hz=0.1, |
|
|
|
burst_count=1, |
|
|
|
) |
|
|
|
allowed, time_allowed = self.get_success_or_raise( |
|
|
|
limiter.can_do_action(as_requester, _time_now_s=0) |
|
|
@@ -90,7 +99,10 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
|
|
|
|
def test_allowed_via_ratelimit(self): |
|
|
|
limiter = Ratelimiter( |
|
|
|
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 |
|
|
|
store=self.hs.get_datastores().main, |
|
|
|
clock=self.clock, |
|
|
|
rate_hz=0.1, |
|
|
|
burst_count=1, |
|
|
|
) |
|
|
|
|
|
|
|
# Shouldn't raise |
|
|
@@ -114,7 +126,10 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
""" |
|
|
|
# Create a Ratelimiter with a very low allowed rate_hz and burst_count |
|
|
|
limiter = Ratelimiter( |
|
|
|
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 |
|
|
|
store=self.hs.get_datastores().main, |
|
|
|
clock=self.clock, |
|
|
|
rate_hz=0.1, |
|
|
|
burst_count=1, |
|
|
|
) |
|
|
|
|
|
|
|
# First attempt should be allowed |
|
|
@@ -160,7 +175,10 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
""" |
|
|
|
# Create a Ratelimiter with a very low allowed rate_hz and burst_count |
|
|
|
limiter = Ratelimiter( |
|
|
|
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 |
|
|
|
store=self.hs.get_datastores().main, |
|
|
|
clock=self.clock, |
|
|
|
rate_hz=0.1, |
|
|
|
burst_count=1, |
|
|
|
) |
|
|
|
|
|
|
|
# First attempt should be allowed |
|
|
@@ -188,7 +206,10 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
|
|
|
|
def test_pruning(self): |
|
|
|
limiter = Ratelimiter( |
|
|
|
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 |
|
|
|
store=self.hs.get_datastores().main, |
|
|
|
clock=self.clock, |
|
|
|
rate_hz=0.1, |
|
|
|
burst_count=1, |
|
|
|
) |
|
|
|
self.get_success_or_raise( |
|
|
|
limiter.can_do_action(None, key="test_id_1", _time_now_s=0) |
|
|
@@ -223,7 +244,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1) |
|
|
|
limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1) |
|
|
|
|
|
|
|
# Shouldn't raise |
|
|
|
for _ in range(20): |
|
|
@@ -231,7 +252,10 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
|
|
|
|
def test_multiple_actions(self): |
|
|
|
limiter = Ratelimiter( |
|
|
|
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 |
|
|
|
store=self.hs.get_datastores().main, |
|
|
|
clock=self.clock, |
|
|
|
rate_hz=0.1, |
|
|
|
burst_count=3, |
|
|
|
) |
|
|
|
# Test that 4 actions aren't allowed with a maximum burst of 3. |
|
|
|
allowed, time_allowed = self.get_success_or_raise( |
|
|
@@ -295,7 +319,10 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
extra tokens by timing requests. |
|
|
|
""" |
|
|
|
limiter = Ratelimiter( |
|
|
|
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 |
|
|
|
store=self.hs.get_datastores().main, |
|
|
|
clock=self.clock, |
|
|
|
rate_hz=0.1, |
|
|
|
burst_count=3, |
|
|
|
) |
|
|
|
|
|
|
|
def consume_at(time: float) -> bool: |
|
|
@@ -317,7 +344,10 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
|
|
|
|
def test_record_action_which_doesnt_fill_bucket(self) -> None: |
|
|
|
limiter = Ratelimiter( |
|
|
|
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 |
|
|
|
store=self.hs.get_datastores().main, |
|
|
|
clock=self.clock, |
|
|
|
rate_hz=0.1, |
|
|
|
burst_count=3, |
|
|
|
) |
|
|
|
|
|
|
|
# Observe two actions, leaving room in the bucket for one more. |
|
|
@@ -337,7 +367,10 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
|
|
|
|
def test_record_action_which_fills_bucket(self) -> None: |
|
|
|
limiter = Ratelimiter( |
|
|
|
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 |
|
|
|
store=self.hs.get_datastores().main, |
|
|
|
clock=self.clock, |
|
|
|
rate_hz=0.1, |
|
|
|
burst_count=3, |
|
|
|
) |
|
|
|
|
|
|
|
# Observe three actions, filling up the bucket. |
|
|
@@ -363,7 +396,10 @@ class TestRatelimiter(unittest.HomeserverTestCase): |
|
|
|
|
|
|
|
def test_record_action_which_overfills_bucket(self) -> None: |
|
|
|
limiter = Ratelimiter( |
|
|
|
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 |
|
|
|
store=self.hs.get_datastores().main, |
|
|
|
clock=self.clock, |
|
|
|
rate_hz=0.1, |
|
|
|
burst_count=3, |
|
|
|
) |
|
|
|
|
|
|
|
# Observe four actions, exceeding the bucket. |
|
|
|