@@ -0,0 +1 @@ | |||
HTTP tests have been refactored to contain less boilerplate. |
@@ -19,24 +19,17 @@ import json | |||
from mock import Mock | |||
from synapse.http.server import JsonResource | |||
from synapse.rest.client.v1.admin import register_servlets | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.server import ( | |||
ThreadedMemoryReactorClock, | |||
make_request, | |||
render, | |||
setup_test_homeserver, | |||
) | |||
class UserRegisterTestCase(unittest.TestCase): | |||
def setUp(self): | |||
class UserRegisterTestCase(unittest.HomeserverTestCase): | |||
servlets = [register_servlets] | |||
def make_homeserver(self, reactor, clock): | |||
self.clock = ThreadedMemoryReactorClock() | |||
self.hs_clock = Clock(self.clock) | |||
self.url = "/_matrix/client/r0/admin/register" | |||
self.registration_handler = Mock() | |||
@@ -50,17 +43,14 @@ class UserRegisterTestCase(unittest.TestCase): | |||
self.secrets = Mock() | |||
self.hs = setup_test_homeserver( | |||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock | |||
) | |||
self.hs = self.setup_test_homeserver() | |||
self.hs.config.registration_shared_secret = u"shared" | |||
self.hs.get_media_repository = Mock() | |||
self.hs.get_deactivate_account_handler = Mock() | |||
self.resource = JsonResource(self.hs) | |||
register_servlets(self.hs, self.resource) | |||
return self.hs | |||
def test_disabled(self): | |||
""" | |||
@@ -69,8 +59,8 @@ class UserRegisterTestCase(unittest.TestCase): | |||
""" | |||
self.hs.config.registration_shared_secret = None | |||
request, channel = make_request("POST", self.url, b'{}') | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, b'{}') | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual( | |||
@@ -87,8 +77,8 @@ class UserRegisterTestCase(unittest.TestCase): | |||
self.hs.get_secrets = Mock(return_value=secrets) | |||
request, channel = make_request("GET", self.url) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("GET", self.url) | |||
self.render(request) | |||
self.assertEqual(channel.json_body, {"nonce": "abcd"}) | |||
@@ -97,25 +87,25 @@ class UserRegisterTestCase(unittest.TestCase): | |||
Calling GET on the endpoint will return a randomised nonce, which will | |||
only last for SALT_TIMEOUT (60s). | |||
""" | |||
request, channel = make_request("GET", self.url) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("GET", self.url) | |||
self.render(request) | |||
nonce = channel.json_body["nonce"] | |||
# 59 seconds | |||
self.clock.advance(59) | |||
self.reactor.advance(59) | |||
body = json.dumps({"nonce": nonce}) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual('username must be specified', channel.json_body["error"]) | |||
# 61 seconds | |||
self.clock.advance(2) | |||
self.reactor.advance(2) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual('unrecognised nonce', channel.json_body["error"]) | |||
@@ -124,8 +114,8 @@ class UserRegisterTestCase(unittest.TestCase): | |||
""" | |||
Only the provided nonce can be used, as it's checked in the MAC. | |||
""" | |||
request, channel = make_request("GET", self.url) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("GET", self.url) | |||
self.render(request) | |||
nonce = channel.json_body["nonce"] | |||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) | |||
@@ -141,8 +131,8 @@ class UserRegisterTestCase(unittest.TestCase): | |||
"mac": want_mac, | |||
} | |||
) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual("HMAC incorrect", channel.json_body["error"]) | |||
@@ -152,8 +142,8 @@ class UserRegisterTestCase(unittest.TestCase): | |||
When the correct nonce is provided, and the right key is provided, the | |||
user is registered. | |||
""" | |||
request, channel = make_request("GET", self.url) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("GET", self.url) | |||
self.render(request) | |||
nonce = channel.json_body["nonce"] | |||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) | |||
@@ -169,8 +159,8 @@ class UserRegisterTestCase(unittest.TestCase): | |||
"mac": want_mac, | |||
} | |||
) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual("@bob:test", channel.json_body["user_id"]) | |||
@@ -179,8 +169,8 @@ class UserRegisterTestCase(unittest.TestCase): | |||
""" | |||
A valid unrecognised nonce. | |||
""" | |||
request, channel = make_request("GET", self.url) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("GET", self.url) | |||
self.render(request) | |||
nonce = channel.json_body["nonce"] | |||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) | |||
@@ -196,15 +186,15 @@ class UserRegisterTestCase(unittest.TestCase): | |||
"mac": want_mac, | |||
} | |||
) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual("@bob:test", channel.json_body["user_id"]) | |||
# Now, try and reuse it | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual('unrecognised nonce', channel.json_body["error"]) | |||
@@ -217,8 +207,8 @@ class UserRegisterTestCase(unittest.TestCase): | |||
""" | |||
def nonce(): | |||
request, channel = make_request("GET", self.url) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("GET", self.url) | |||
self.render(request) | |||
return channel.json_body["nonce"] | |||
# | |||
@@ -227,8 +217,8 @@ class UserRegisterTestCase(unittest.TestCase): | |||
# Must be present | |||
body = json.dumps({}) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual('nonce must be specified', channel.json_body["error"]) | |||
@@ -239,32 +229,32 @@ class UserRegisterTestCase(unittest.TestCase): | |||
# Must be present | |||
body = json.dumps({"nonce": nonce()}) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual('username must be specified', channel.json_body["error"]) | |||
# Must be a string | |||
body = json.dumps({"nonce": nonce(), "username": 1234}) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual('Invalid username', channel.json_body["error"]) | |||
# Must not have null bytes | |||
body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"}) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual('Invalid username', channel.json_body["error"]) | |||
# Must not have null bytes | |||
body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual('Invalid username', channel.json_body["error"]) | |||
@@ -275,16 +265,16 @@ class UserRegisterTestCase(unittest.TestCase): | |||
# Must be present | |||
body = json.dumps({"nonce": nonce(), "username": "a"}) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual('password must be specified', channel.json_body["error"]) | |||
# Must be a string | |||
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual('Invalid password', channel.json_body["error"]) | |||
@@ -293,16 +283,16 @@ class UserRegisterTestCase(unittest.TestCase): | |||
body = json.dumps( | |||
{"nonce": nonce(), "username": "a", "password": u"abcd\u0000"} | |||
) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual('Invalid password', channel.json_body["error"]) | |||
# Super long | |||
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) | |||
request, channel = make_request("POST", self.url, body.encode('utf8')) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request("POST", self.url, body.encode('utf8')) | |||
self.render(request) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual('Invalid password', channel.json_body["error"]) |
@@ -45,11 +45,11 @@ class CreateUserServletTestCase(unittest.TestCase): | |||
) | |||
handlers = Mock(registration_handler=self.registration_handler) | |||
self.clock = MemoryReactorClock() | |||
self.hs_clock = Clock(self.clock) | |||
self.reactor = MemoryReactorClock() | |||
self.hs_clock = Clock(self.reactor) | |||
self.hs = self.hs = setup_test_homeserver( | |||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock | |||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor | |||
) | |||
self.hs.get_datastore = Mock(return_value=self.datastore) | |||
self.hs.get_handlers = Mock(return_value=handlers) | |||
@@ -76,8 +76,8 @@ class CreateUserServletTestCase(unittest.TestCase): | |||
return_value=(user_id, token) | |||
) | |||
request, channel = make_request(b"POST", url, request_data) | |||
render(request, res, self.clock) | |||
request, channel = make_request(self.reactor, b"POST", url, request_data) | |||
render(request, res, self.reactor) | |||
self.assertEquals(channel.result["code"], b"200") | |||
@@ -169,7 +169,7 @@ class RestHelper(object): | |||
path = path + "?access_token=%s" % tok | |||
request, channel = make_request( | |||
"POST", path, json.dumps(content).encode('utf8') | |||
self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8') | |||
) | |||
render(request, self.resource, self.hs.get_reactor()) | |||
@@ -217,7 +217,9 @@ class RestHelper(object): | |||
data = {"membership": membership} | |||
request, channel = make_request("PUT", path, json.dumps(data).encode('utf8')) | |||
request, channel = make_request( | |||
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8') | |||
) | |||
render(request, self.resource, self.hs.get_reactor()) | |||
@@ -228,18 +230,6 @@ class RestHelper(object): | |||
self.auth_user_id = temp_id | |||
@defer.inlineCallbacks | |||
def register(self, user_id): | |||
(code, response) = yield self.mock_resource.trigger( | |||
"POST", | |||
"/_matrix/client/r0/register", | |||
json.dumps( | |||
{"user": user_id, "password": "test", "type": "m.login.password"} | |||
), | |||
) | |||
self.assertEquals(200, code) | |||
defer.returnValue(response) | |||
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): | |||
if txn_id is None: | |||
txn_id = "m%s" % (str(time.time())) | |||
@@ -251,7 +241,9 @@ class RestHelper(object): | |||
if tok: | |||
path = path + "?access_token=%s" % tok | |||
request, channel = make_request("PUT", path, json.dumps(content).encode('utf8')) | |||
request, channel = make_request( | |||
self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8') | |||
) | |||
render(request, self.resource, self.hs.get_reactor()) | |||
assert int(channel.result["code"]) == expect_code, ( | |||
@@ -13,84 +13,47 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import synapse.types | |||
from synapse.api.errors import Codes | |||
from synapse.http.server import JsonResource | |||
from synapse.rest.client.v2_alpha import filter | |||
from synapse.types import UserID | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.server import ( | |||
ThreadedMemoryReactorClock as MemoryReactorClock, | |||
make_request, | |||
render, | |||
setup_test_homeserver, | |||
) | |||
PATH_PREFIX = "/_matrix/client/v2_alpha" | |||
class FilterTestCase(unittest.TestCase): | |||
class FilterTestCase(unittest.HomeserverTestCase): | |||
USER_ID = "@apple:test" | |||
user_id = "@apple:test" | |||
hijack_auth = True | |||
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} | |||
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' | |||
TO_REGISTER = [filter] | |||
servlets = [filter.register_servlets] | |||
def setUp(self): | |||
self.clock = MemoryReactorClock() | |||
self.hs_clock = Clock(self.clock) | |||
self.hs = setup_test_homeserver( | |||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock | |||
) | |||
self.auth = self.hs.get_auth() | |||
def get_user_by_access_token(token=None, allow_guest=False): | |||
return { | |||
"user": UserID.from_string(self.USER_ID), | |||
"token_id": 1, | |||
"is_guest": False, | |||
} | |||
def get_user_by_req(request, allow_guest=False, rights="access"): | |||
return synapse.types.create_requester( | |||
UserID.from_string(self.USER_ID), 1, False, None | |||
) | |||
self.auth.get_user_by_access_token = get_user_by_access_token | |||
self.auth.get_user_by_req = get_user_by_req | |||
self.store = self.hs.get_datastore() | |||
self.filtering = self.hs.get_filtering() | |||
self.resource = JsonResource(self.hs) | |||
for r in self.TO_REGISTER: | |||
r.register_servlets(self.hs, self.resource) | |||
def prepare(self, reactor, clock, hs): | |||
self.filtering = hs.get_filtering() | |||
self.store = hs.get_datastore() | |||
def test_add_filter(self): | |||
request, channel = make_request( | |||
request, channel = self.make_request( | |||
"POST", | |||
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), | |||
"/_matrix/client/r0/user/%s/filter" % (self.user_id), | |||
self.EXAMPLE_FILTER_JSON, | |||
) | |||
render(request, self.resource, self.clock) | |||
self.render(request) | |||
self.assertEqual(channel.result["code"], b"200") | |||
self.assertEqual(channel.json_body, {"filter_id": "0"}) | |||
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) | |||
self.clock.advance(0) | |||
self.pump() | |||
self.assertEquals(filter.result, self.EXAMPLE_FILTER) | |||
def test_add_filter_for_other_user(self): | |||
request, channel = make_request( | |||
request, channel = self.make_request( | |||
"POST", | |||
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), | |||
self.EXAMPLE_FILTER_JSON, | |||
) | |||
render(request, self.resource, self.clock) | |||
self.render(request) | |||
self.assertEqual(channel.result["code"], b"403") | |||
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) | |||
@@ -98,12 +61,12 @@ class FilterTestCase(unittest.TestCase): | |||
def test_add_filter_non_local_user(self): | |||
_is_mine = self.hs.is_mine | |||
self.hs.is_mine = lambda target_user: False | |||
request, channel = make_request( | |||
request, channel = self.make_request( | |||
"POST", | |||
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), | |||
"/_matrix/client/r0/user/%s/filter" % (self.user_id), | |||
self.EXAMPLE_FILTER_JSON, | |||
) | |||
render(request, self.resource, self.clock) | |||
self.render(request) | |||
self.hs.is_mine = _is_mine | |||
self.assertEqual(channel.result["code"], b"403") | |||
@@ -113,21 +76,21 @@ class FilterTestCase(unittest.TestCase): | |||
filter_id = self.filtering.add_user_filter( | |||
user_localpart="apple", user_filter=self.EXAMPLE_FILTER | |||
) | |||
self.clock.advance(1) | |||
self.reactor.advance(1) | |||
filter_id = filter_id.result | |||
request, channel = make_request( | |||
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) | |||
request, channel = self.make_request( | |||
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) | |||
) | |||
render(request, self.resource, self.clock) | |||
self.render(request) | |||
self.assertEqual(channel.result["code"], b"200") | |||
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) | |||
def test_get_filter_non_existant(self): | |||
request, channel = make_request( | |||
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) | |||
request, channel = self.make_request( | |||
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) | |||
) | |||
render(request, self.resource, self.clock) | |||
self.render(request) | |||
self.assertEqual(channel.result["code"], b"400") | |||
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) | |||
@@ -135,18 +98,18 @@ class FilterTestCase(unittest.TestCase): | |||
# Currently invalid params do not have an appropriate errcode | |||
# in errors.py | |||
def test_get_filter_invalid_id(self): | |||
request, channel = make_request( | |||
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) | |||
request, channel = self.make_request( | |||
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) | |||
) | |||
render(request, self.resource, self.clock) | |||
self.render(request) | |||
self.assertEqual(channel.result["code"], b"400") | |||
# No ID also returns an invalid_id error | |||
def test_get_filter_no_id(self): | |||
request, channel = make_request( | |||
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) | |||
request, channel = self.make_request( | |||
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) | |||
) | |||
render(request, self.resource, self.clock) | |||
self.render(request) | |||
self.assertEqual(channel.result["code"], b"400") |
@@ -3,22 +3,19 @@ import json | |||
from mock import Mock | |||
from twisted.python import failure | |||
from twisted.test.proto_helpers import MemoryReactorClock | |||
from synapse.api.errors import InteractiveAuthIncompleteError | |||
from synapse.http.server import JsonResource | |||
from synapse.rest.client.v2_alpha.register import register_servlets | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.server import make_request, render, setup_test_homeserver | |||
class RegisterRestServletTestCase(unittest.TestCase): | |||
def setUp(self): | |||
class RegisterRestServletTestCase(unittest.HomeserverTestCase): | |||
servlets = [register_servlets] | |||
def make_homeserver(self, reactor, clock): | |||
self.clock = MemoryReactorClock() | |||
self.hs_clock = Clock(self.clock) | |||
self.url = b"/_matrix/client/r0/register" | |||
self.appservice = None | |||
@@ -46,9 +43,7 @@ class RegisterRestServletTestCase(unittest.TestCase): | |||
identity_handler=self.identity_handler, | |||
login_handler=self.login_handler, | |||
) | |||
self.hs = setup_test_homeserver( | |||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock | |||
) | |||
self.hs = self.setup_test_homeserver() | |||
self.hs.get_auth = Mock(return_value=self.auth) | |||
self.hs.get_handlers = Mock(return_value=self.handlers) | |||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler) | |||
@@ -58,8 +53,7 @@ class RegisterRestServletTestCase(unittest.TestCase): | |||
self.hs.config.registrations_require_3pid = [] | |||
self.hs.config.auto_join_rooms = [] | |||
self.resource = JsonResource(self.hs) | |||
register_servlets(self.hs, self.resource) | |||
return self.hs | |||
def test_POST_appservice_registration_valid(self): | |||
user_id = "@kermit:muppet" | |||
@@ -69,10 +63,10 @@ class RegisterRestServletTestCase(unittest.TestCase): | |||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) | |||
request_data = json.dumps({"username": "kermit"}) | |||
request, channel = make_request( | |||
request, channel = self.make_request( | |||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data | |||
) | |||
render(request, self.resource, self.clock) | |||
self.render(request) | |||
self.assertEquals(channel.result["code"], b"200", channel.result) | |||
det_data = { | |||
@@ -85,25 +79,25 @@ class RegisterRestServletTestCase(unittest.TestCase): | |||
def test_POST_appservice_registration_invalid(self): | |||
self.appservice = None # no application service exists | |||
request_data = json.dumps({"username": "kermit"}) | |||
request, channel = make_request( | |||
request, channel = self.make_request( | |||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data | |||
) | |||
render(request, self.resource, self.clock) | |||
self.render(request) | |||
self.assertEquals(channel.result["code"], b"401", channel.result) | |||
def test_POST_bad_password(self): | |||
request_data = json.dumps({"username": "kermit", "password": 666}) | |||
request, channel = make_request(b"POST", self.url, request_data) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request(b"POST", self.url, request_data) | |||
self.render(request) | |||
self.assertEquals(channel.result["code"], b"400", channel.result) | |||
self.assertEquals(channel.json_body["error"], "Invalid password") | |||
def test_POST_bad_username(self): | |||
request_data = json.dumps({"username": 777, "password": "monkey"}) | |||
request, channel = make_request(b"POST", self.url, request_data) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request(b"POST", self.url, request_data) | |||
self.render(request) | |||
self.assertEquals(channel.result["code"], b"400", channel.result) | |||
self.assertEquals(channel.json_body["error"], "Invalid username") | |||
@@ -121,8 +115,8 @@ class RegisterRestServletTestCase(unittest.TestCase): | |||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) | |||
self.device_handler.check_device_registered = Mock(return_value=device_id) | |||
request, channel = make_request(b"POST", self.url, request_data) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request(b"POST", self.url, request_data) | |||
self.render(request) | |||
det_data = { | |||
"user_id": user_id, | |||
@@ -143,8 +137,8 @@ class RegisterRestServletTestCase(unittest.TestCase): | |||
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) | |||
self.registration_handler.register = Mock(return_value=("@user:id", "t")) | |||
request, channel = make_request(b"POST", self.url, request_data) | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request(b"POST", self.url, request_data) | |||
self.render(request) | |||
self.assertEquals(channel.result["code"], b"403", channel.result) | |||
self.assertEquals(channel.json_body["error"], "Registration has been disabled") | |||
@@ -155,8 +149,8 @@ class RegisterRestServletTestCase(unittest.TestCase): | |||
self.hs.config.allow_guest_access = True | |||
self.registration_handler.register = Mock(return_value=(user_id, None)) | |||
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") | |||
self.render(request) | |||
det_data = { | |||
"user_id": user_id, | |||
@@ -169,8 +163,8 @@ class RegisterRestServletTestCase(unittest.TestCase): | |||
def test_POST_disabled_guest_registration(self): | |||
self.hs.config.allow_guest_access = False | |||
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") | |||
render(request, self.resource, self.clock) | |||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") | |||
self.render(request) | |||
self.assertEquals(channel.result["code"], b"403", channel.result) | |||
self.assertEquals(channel.json_body["error"], "Guest access is disabled") |
@@ -34,6 +34,7 @@ class FakeChannel(object): | |||
wire). | |||
""" | |||
_reactor = attr.ib() | |||
result = attr.ib(default=attr.Factory(dict)) | |||
_producer = None | |||
@@ -63,6 +64,15 @@ class FakeChannel(object): | |||
def registerProducer(self, producer, streaming): | |||
self._producer = producer | |||
self.producerStreaming = streaming | |||
def _produce(): | |||
if self._producer: | |||
self._producer.resumeProducing() | |||
self._reactor.callLater(0.1, _produce) | |||
if not streaming: | |||
self._reactor.callLater(0.0, _produce) | |||
def unregisterProducer(self): | |||
if self._producer is None: | |||
@@ -105,7 +115,13 @@ class FakeSite: | |||
def make_request( | |||
method, path, content=b"", access_token=None, request=SynapseRequest, shorthand=True | |||
reactor, | |||
method, | |||
path, | |||
content=b"", | |||
access_token=None, | |||
request=SynapseRequest, | |||
shorthand=True, | |||
): | |||
""" | |||
Make a web request using the given method and path, feed it the | |||
@@ -138,7 +154,7 @@ def make_request( | |||
content = content.encode('utf8') | |||
site = FakeSite() | |||
channel = FakeChannel() | |||
channel = FakeChannel(reactor) | |||
req = request(site, channel) | |||
req.process = lambda: b"" | |||
@@ -21,30 +21,20 @@ from mock import Mock, NonCallableMock | |||
from synapse.api.constants import LoginType | |||
from synapse.api.errors import Codes, HttpResponseException, SynapseError | |||
from synapse.http.server import JsonResource | |||
from synapse.rest.client.v2_alpha import register, sync | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.server import ( | |||
ThreadedMemoryReactorClock, | |||
make_request, | |||
render, | |||
setup_test_homeserver, | |||
) | |||
class TestMauLimit(unittest.TestCase): | |||
def setUp(self): | |||
self.reactor = ThreadedMemoryReactorClock() | |||
self.clock = Clock(self.reactor) | |||
class TestMauLimit(unittest.HomeserverTestCase): | |||
self.hs = setup_test_homeserver( | |||
self.addCleanup, | |||
servlets = [register.register_servlets, sync.register_servlets] | |||
def make_homeserver(self, reactor, clock): | |||
self.hs = self.setup_test_homeserver( | |||
"red", | |||
http_client=None, | |||
clock=self.clock, | |||
reactor=self.reactor, | |||
federation_client=Mock(), | |||
ratelimiter=NonCallableMock(spec_set=["send_message"]), | |||
) | |||
@@ -63,10 +53,7 @@ class TestMauLimit(unittest.TestCase): | |||
self.hs.config.server_notices_mxid_display_name = None | |||
self.hs.config.server_notices_mxid_avatar_url = None | |||
self.hs.config.server_notices_room_name = "Test Server Notice Room" | |||
self.resource = JsonResource(self.hs) | |||
register.register_servlets(self.hs, self.resource) | |||
sync.register_servlets(self.hs, self.resource) | |||
return self.hs | |||
def test_simple_deny_mau(self): | |||
# Create and sync so that the MAU counts get updated | |||
@@ -193,8 +180,8 @@ class TestMauLimit(unittest.TestCase): | |||
} | |||
) | |||
request, channel = make_request("POST", "/register", request_data) | |||
render(request, self.resource, self.reactor) | |||
request, channel = self.make_request("POST", "/register", request_data) | |||
self.render(request) | |||
if channel.code != 200: | |||
raise HttpResponseException( | |||
@@ -206,10 +193,10 @@ class TestMauLimit(unittest.TestCase): | |||
return access_token | |||
def do_sync_for_user(self, token): | |||
request, channel = make_request( | |||
request, channel = self.make_request( | |||
"GET", "/sync", access_token=token | |||
) | |||
render(request, self.resource, self.reactor) | |||
self.render(request) | |||
if channel.code != 200: | |||
raise HttpResponseException( | |||
@@ -57,7 +57,9 @@ class JsonResourceTests(unittest.TestCase): | |||
"GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback | |||
) | |||
request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83") | |||
request, channel = make_request( | |||
self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" | |||
) | |||
render(request, res, self.reactor) | |||
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) | |||
@@ -75,7 +77,7 @@ class JsonResourceTests(unittest.TestCase): | |||
res = JsonResource(self.homeserver) | |||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) | |||
request, channel = make_request(b"GET", b"/_matrix/foo") | |||
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") | |||
render(request, res, self.reactor) | |||
self.assertEqual(channel.result["code"], b'500') | |||
@@ -98,7 +100,7 @@ class JsonResourceTests(unittest.TestCase): | |||
res = JsonResource(self.homeserver) | |||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) | |||
request, channel = make_request(b"GET", b"/_matrix/foo") | |||
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") | |||
render(request, res, self.reactor) | |||
self.assertEqual(channel.result["code"], b'500') | |||
@@ -115,7 +117,7 @@ class JsonResourceTests(unittest.TestCase): | |||
res = JsonResource(self.homeserver) | |||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) | |||
request, channel = make_request(b"GET", b"/_matrix/foo") | |||
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") | |||
render(request, res, self.reactor) | |||
self.assertEqual(channel.result["code"], b'403') | |||
@@ -136,7 +138,7 @@ class JsonResourceTests(unittest.TestCase): | |||
res = JsonResource(self.homeserver) | |||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) | |||
request, channel = make_request(b"GET", b"/_matrix/foobar") | |||
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") | |||
render(request, res, self.reactor) | |||
self.assertEqual(channel.result["code"], b'400') | |||
@@ -23,7 +23,6 @@ from synapse.rest.client.v2_alpha.register import register_servlets | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.server import make_request | |||
class TermsTestCase(unittest.HomeserverTestCase): | |||
@@ -92,7 +91,7 @@ class TermsTestCase(unittest.HomeserverTestCase): | |||
self.registration_handler.check_username = Mock(return_value=True) | |||
request, channel = make_request(b"POST", self.url, request_data) | |||
request, channel = self.make_request(b"POST", self.url, request_data) | |||
self.render(request) | |||
# We don't bother checking that the response is correct - we'll leave that to | |||
@@ -110,7 +109,7 @@ class TermsTestCase(unittest.HomeserverTestCase): | |||
}, | |||
} | |||
) | |||
request, channel = make_request(b"POST", self.url, request_data) | |||
request, channel = self.make_request(b"POST", self.url, request_data) | |||
self.render(request) | |||
# We're interested in getting a response that looks like a successful | |||
@@ -189,11 +189,11 @@ class HomeserverTestCase(TestCase): | |||
for servlet in self.servlets: | |||
servlet(self.hs, self.resource) | |||
if hasattr(self, "user_id"): | |||
from tests.rest.client.v1.utils import RestHelper | |||
from tests.rest.client.v1.utils import RestHelper | |||
self.helper = RestHelper(self.hs, self.resource, self.user_id) | |||
self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) | |||
if hasattr(self, "user_id"): | |||
if self.hijack_auth: | |||
def get_user_by_access_token(token=None, allow_guest=False): | |||
@@ -285,7 +285,9 @@ class HomeserverTestCase(TestCase): | |||
if isinstance(content, dict): | |||
content = json.dumps(content).encode('utf8') | |||
return make_request(method, path, content, access_token, request, shorthand) | |||
return make_request( | |||
self.reactor, method, path, content, access_token, request, shorthand | |||
) | |||
def render(self, request): | |||
""" | |||