@@ -35,10 +35,6 @@ matrix: | |||
- python: 3.6 | |||
env: TOX_ENV=check-newsfragment | |||
allow_failures: | |||
- python: 2.7 | |||
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4" | |||
install: | |||
- pip install tox | |||
@@ -0,0 +1 @@ | |||
The test suite now passes on PostgreSQL. |
@@ -1,5 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
# Copyright 2016 OpenMarket Ltd | |||
# Copyright 2018 New Vector Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
@@ -13,79 +14,79 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.internet import defer | |||
import synapse.api.errors | |||
import synapse.handlers.device | |||
import synapse.storage | |||
from tests import unittest, utils | |||
from tests import unittest | |||
user1 = "@boris:aaa" | |||
user2 = "@theresa:bbb" | |||
class DeviceTestCase(unittest.TestCase): | |||
def __init__(self, *args, **kwargs): | |||
super(DeviceTestCase, self).__init__(*args, **kwargs) | |||
self.store = None # type: synapse.storage.DataStore | |||
self.handler = None # type: synapse.handlers.device.DeviceHandler | |||
self.clock = None # type: utils.MockClock | |||
@defer.inlineCallbacks | |||
def setUp(self): | |||
hs = yield utils.setup_test_homeserver(self.addCleanup) | |||
class DeviceTestCase(unittest.HomeserverTestCase): | |||
def make_homeserver(self, reactor, clock): | |||
hs = self.setup_test_homeserver("server", http_client=None) | |||
self.handler = hs.get_device_handler() | |||
self.store = hs.get_datastore() | |||
self.clock = hs.get_clock() | |||
return hs | |||
def prepare(self, reactor, clock, hs): | |||
# These tests assume that it starts 1000 seconds in. | |||
self.reactor.advance(1000) | |||
@defer.inlineCallbacks | |||
def test_device_is_created_if_doesnt_exist(self): | |||
res = yield self.handler.check_device_registered( | |||
user_id="@boris:foo", | |||
device_id="fco", | |||
initial_device_display_name="display name", | |||
res = self.get_success( | |||
self.handler.check_device_registered( | |||
user_id="@boris:foo", | |||
device_id="fco", | |||
initial_device_display_name="display name", | |||
) | |||
) | |||
self.assertEqual(res, "fco") | |||
dev = yield self.handler.store.get_device("@boris:foo", "fco") | |||
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) | |||
self.assertEqual(dev["display_name"], "display name") | |||
@defer.inlineCallbacks | |||
def test_device_is_preserved_if_exists(self): | |||
res1 = yield self.handler.check_device_registered( | |||
user_id="@boris:foo", | |||
device_id="fco", | |||
initial_device_display_name="display name", | |||
res1 = self.get_success( | |||
self.handler.check_device_registered( | |||
user_id="@boris:foo", | |||
device_id="fco", | |||
initial_device_display_name="display name", | |||
) | |||
) | |||
self.assertEqual(res1, "fco") | |||
res2 = yield self.handler.check_device_registered( | |||
user_id="@boris:foo", | |||
device_id="fco", | |||
initial_device_display_name="new display name", | |||
res2 = self.get_success( | |||
self.handler.check_device_registered( | |||
user_id="@boris:foo", | |||
device_id="fco", | |||
initial_device_display_name="new display name", | |||
) | |||
) | |||
self.assertEqual(res2, "fco") | |||
dev = yield self.handler.store.get_device("@boris:foo", "fco") | |||
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) | |||
self.assertEqual(dev["display_name"], "display name") | |||
@defer.inlineCallbacks | |||
def test_device_id_is_made_up_if_unspecified(self): | |||
device_id = yield self.handler.check_device_registered( | |||
user_id="@theresa:foo", | |||
device_id=None, | |||
initial_device_display_name="display", | |||
device_id = self.get_success( | |||
self.handler.check_device_registered( | |||
user_id="@theresa:foo", | |||
device_id=None, | |||
initial_device_display_name="display", | |||
) | |||
) | |||
dev = yield self.handler.store.get_device("@theresa:foo", device_id) | |||
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) | |||
self.assertEqual(dev["display_name"], "display") | |||
@defer.inlineCallbacks | |||
def test_get_devices_by_user(self): | |||
yield self._record_users() | |||
self._record_users() | |||
res = self.get_success(self.handler.get_devices_by_user(user1)) | |||
res = yield self.handler.get_devices_by_user(user1) | |||
self.assertEqual(3, len(res)) | |||
device_map = {d["device_id"]: d for d in res} | |||
self.assertDictContainsSubset( | |||
@@ -119,11 +120,10 @@ class DeviceTestCase(unittest.TestCase): | |||
device_map["abc"], | |||
) | |||
@defer.inlineCallbacks | |||
def test_get_device(self): | |||
yield self._record_users() | |||
self._record_users() | |||
res = yield self.handler.get_device(user1, "abc") | |||
res = self.get_success(self.handler.get_device(user1, "abc")) | |||
self.assertDictContainsSubset( | |||
{ | |||
"user_id": user1, | |||
@@ -135,59 +135,66 @@ class DeviceTestCase(unittest.TestCase): | |||
res, | |||
) | |||
@defer.inlineCallbacks | |||
def test_delete_device(self): | |||
yield self._record_users() | |||
self._record_users() | |||
# delete the device | |||
yield self.handler.delete_device(user1, "abc") | |||
self.get_success(self.handler.delete_device(user1, "abc")) | |||
# check the device was deleted | |||
with self.assertRaises(synapse.api.errors.NotFoundError): | |||
yield self.handler.get_device(user1, "abc") | |||
res = self.handler.get_device(user1, "abc") | |||
self.pump() | |||
self.assertIsInstance( | |||
self.failureResultOf(res).value, synapse.api.errors.NotFoundError | |||
) | |||
# we'd like to check the access token was invalidated, but that's a | |||
# bit of a PITA. | |||
@defer.inlineCallbacks | |||
def test_update_device(self): | |||
yield self._record_users() | |||
self._record_users() | |||
update = {"display_name": "new display"} | |||
yield self.handler.update_device(user1, "abc", update) | |||
self.get_success(self.handler.update_device(user1, "abc", update)) | |||
res = yield self.handler.get_device(user1, "abc") | |||
res = self.get_success(self.handler.get_device(user1, "abc")) | |||
self.assertEqual(res["display_name"], "new display") | |||
@defer.inlineCallbacks | |||
def test_update_unknown_device(self): | |||
update = {"display_name": "new_display"} | |||
with self.assertRaises(synapse.api.errors.NotFoundError): | |||
yield self.handler.update_device("user_id", "unknown_device_id", update) | |||
res = self.handler.update_device("user_id", "unknown_device_id", update) | |||
self.pump() | |||
self.assertIsInstance( | |||
self.failureResultOf(res).value, synapse.api.errors.NotFoundError | |||
) | |||
@defer.inlineCallbacks | |||
def _record_users(self): | |||
# check this works for both devices which have a recorded client_ip, | |||
# and those which don't. | |||
yield self._record_user(user1, "xyz", "display 0") | |||
yield self._record_user(user1, "fco", "display 1", "token1", "ip1") | |||
yield self._record_user(user1, "abc", "display 2", "token2", "ip2") | |||
yield self._record_user(user1, "abc", "display 2", "token3", "ip3") | |||
self._record_user(user1, "xyz", "display 0") | |||
self._record_user(user1, "fco", "display 1", "token1", "ip1") | |||
self._record_user(user1, "abc", "display 2", "token2", "ip2") | |||
self._record_user(user1, "abc", "display 2", "token3", "ip3") | |||
self._record_user(user2, "def", "dispkay", "token4", "ip4") | |||
yield self._record_user(user2, "def", "dispkay", "token4", "ip4") | |||
self.reactor.advance(10000) | |||
@defer.inlineCallbacks | |||
def _record_user( | |||
self, user_id, device_id, display_name, access_token=None, ip=None | |||
): | |||
device_id = yield self.handler.check_device_registered( | |||
user_id=user_id, | |||
device_id=device_id, | |||
initial_device_display_name=display_name, | |||
device_id = self.get_success( | |||
self.handler.check_device_registered( | |||
user_id=user_id, | |||
device_id=device_id, | |||
initial_device_display_name=display_name, | |||
) | |||
) | |||
if ip is not None: | |||
yield self.store.insert_client_ip( | |||
user_id, access_token, ip, "user_agent", device_id | |||
self.get_success( | |||
self.store.insert_client_ip( | |||
user_id, access_token, ip, "user_agent", device_id | |||
) | |||
) | |||
self.clock.advance_time(1000) | |||
self.reactor.advance(1000) |
@@ -1,4 +1,5 @@ | |||
# Copyright 2016 OpenMarket Ltd | |||
# Copyright 2018 New Vector Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
@@ -11,89 +12,91 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import tempfile | |||
from mock import Mock, NonCallableMock | |||
from twisted.internet import defer, reactor | |||
from twisted.internet.defer import Deferred | |||
import attr | |||
from synapse.replication.tcp.client import ( | |||
ReplicationClientFactory, | |||
ReplicationClientHandler, | |||
) | |||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory | |||
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable | |||
from tests import unittest | |||
from tests.utils import setup_test_homeserver | |||
class TestReplicationClientHandler(ReplicationClientHandler): | |||
"""Overrides on_rdata so that we can wait for it to happen""" | |||
class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): | |||
def make_homeserver(self, reactor, clock): | |||
def __init__(self, store): | |||
super(TestReplicationClientHandler, self).__init__(store) | |||
self._rdata_awaiters = [] | |||
def await_replication(self): | |||
d = Deferred() | |||
self._rdata_awaiters.append(d) | |||
return make_deferred_yieldable(d) | |||
def on_rdata(self, stream_name, token, rows): | |||
awaiters = self._rdata_awaiters | |||
self._rdata_awaiters = [] | |||
super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows) | |||
with PreserveLoggingContext(): | |||
for a in awaiters: | |||
a.callback(None) | |||
class BaseSlavedStoreTestCase(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def setUp(self): | |||
self.hs = yield setup_test_homeserver( | |||
self.addCleanup, | |||
hs = self.setup_test_homeserver( | |||
"blue", | |||
http_client=None, | |||
federation_client=Mock(), | |||
ratelimiter=NonCallableMock(spec_set=["send_message"]), | |||
) | |||
self.hs.get_ratelimiter().send_message.return_value = (True, 0) | |||
hs.get_ratelimiter().send_message.return_value = (True, 0) | |||
return hs | |||
def prepare(self, reactor, clock, hs): | |||
self.master_store = self.hs.get_datastore() | |||
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) | |||
self.event_id = 0 | |||
server_factory = ReplicationStreamProtocolFactory(self.hs) | |||
# XXX: mktemp is unsafe and should never be used. but we're just a test. | |||
path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket") | |||
listener = reactor.listenUNIX(path, server_factory) | |||
self.addCleanup(listener.stopListening) | |||
self.streamer = server_factory.streamer | |||
self.replication_handler = TestReplicationClientHandler(self.slaved_store) | |||
self.replication_handler = ReplicationClientHandler(self.slaved_store) | |||
client_factory = ReplicationClientFactory( | |||
self.hs, "client_name", self.replication_handler | |||
) | |||
client_connector = reactor.connectUNIX(path, client_factory) | |||
self.addCleanup(client_factory.stopTrying) | |||
self.addCleanup(client_connector.disconnect) | |||
server = server_factory.buildProtocol(None) | |||
client = client_factory.buildProtocol(None) | |||
@attr.s | |||
class FakeTransport(object): | |||
other = attr.ib() | |||
disconnecting = False | |||
buffer = attr.ib(default=b'') | |||
def registerProducer(self, producer, streaming): | |||
self.producer = producer | |||
def _produce(): | |||
self.producer.resumeProducing() | |||
reactor.callLater(0.1, _produce) | |||
reactor.callLater(0.0, _produce) | |||
def write(self, byt): | |||
self.buffer = self.buffer + byt | |||
if getattr(self.other, "transport") is not None: | |||
self.other.dataReceived(self.buffer) | |||
self.buffer = b"" | |||
def writeSequence(self, seq): | |||
for x in seq: | |||
self.write(x) | |||
client.makeConnection(FakeTransport(server)) | |||
server.makeConnection(FakeTransport(client)) | |||
def replicate(self): | |||
"""Tell the master side of replication that something has happened, and then | |||
wait for the replication to occur. | |||
""" | |||
# xxx: should we be more specific in what we wait for? | |||
d = self.replication_handler.await_replication() | |||
self.streamer.on_notifier_poke() | |||
return d | |||
self.pump(0.1) | |||
@defer.inlineCallbacks | |||
def check(self, method, args, expected_result=None): | |||
master_result = yield getattr(self.master_store, method)(*args) | |||
slaved_result = yield getattr(self.slaved_store, method)(*args) | |||
master_result = self.get_success(getattr(self.master_store, method)(*args)) | |||
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args)) | |||
if expected_result is not None: | |||
self.assertEqual(master_result, expected_result) | |||
self.assertEqual(slaved_result, expected_result) | |||
@@ -12,9 +12,6 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.internet import defer | |||
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore | |||
from ._base import BaseSlavedStoreTestCase | |||
@@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): | |||
STORE_TYPE = SlavedAccountDataStore | |||
@defer.inlineCallbacks | |||
def test_user_account_data(self): | |||
yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) | |||
yield self.replicate() | |||
yield self.check( | |||
self.get_success( | |||
self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) | |||
) | |||
self.replicate() | |||
self.check( | |||
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1} | |||
) | |||
yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) | |||
yield self.replicate() | |||
yield self.check( | |||
self.get_success( | |||
self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) | |||
) | |||
self.replicate() | |||
self.check( | |||
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2} | |||
) |
@@ -12,8 +12,6 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.internet import defer | |||
from synapse.events import FrozenEvent, _EventInternalMetadata | |||
from synapse.events.snapshot import EventContext | |||
from synapse.replication.slave.storage.events import SlavedEventStore | |||
@@ -55,70 +53,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): | |||
def tearDown(self): | |||
[unpatch() for unpatch in self.unpatches] | |||
@defer.inlineCallbacks | |||
def test_get_latest_event_ids_in_room(self): | |||
create = yield self.persist(type="m.room.create", key="", creator=USER_ID) | |||
yield self.replicate() | |||
yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) | |||
create = self.persist(type="m.room.create", key="", creator=USER_ID) | |||
self.replicate() | |||
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) | |||
join = yield self.persist( | |||
join = self.persist( | |||
type="m.room.member", | |||
key=USER_ID, | |||
membership="join", | |||
prev_events=[(create.event_id, {})], | |||
) | |||
yield self.replicate() | |||
yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) | |||
self.replicate() | |||
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) | |||
@defer.inlineCallbacks | |||
def test_redactions(self): | |||
yield self.persist(type="m.room.create", key="", creator=USER_ID) | |||
yield self.persist(type="m.room.member", key=USER_ID, membership="join") | |||
self.persist(type="m.room.create", key="", creator=USER_ID) | |||
self.persist(type="m.room.member", key=USER_ID, membership="join") | |||
msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") | |||
yield self.replicate() | |||
yield self.check("get_event", [msg.event_id], msg) | |||
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") | |||
self.replicate() | |||
self.check("get_event", [msg.event_id], msg) | |||
redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id) | |||
yield self.replicate() | |||
redaction = self.persist(type="m.room.redaction", redacts=msg.event_id) | |||
self.replicate() | |||
msg_dict = msg.get_dict() | |||
msg_dict["content"] = {} | |||
msg_dict["unsigned"]["redacted_by"] = redaction.event_id | |||
msg_dict["unsigned"]["redacted_because"] = redaction | |||
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) | |||
yield self.check("get_event", [msg.event_id], redacted) | |||
self.check("get_event", [msg.event_id], redacted) | |||
@defer.inlineCallbacks | |||
def test_backfilled_redactions(self): | |||
yield self.persist(type="m.room.create", key="", creator=USER_ID) | |||
yield self.persist(type="m.room.member", key=USER_ID, membership="join") | |||
self.persist(type="m.room.create", key="", creator=USER_ID) | |||
self.persist(type="m.room.member", key=USER_ID, membership="join") | |||
msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") | |||
yield self.replicate() | |||
yield self.check("get_event", [msg.event_id], msg) | |||
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") | |||
self.replicate() | |||
self.check("get_event", [msg.event_id], msg) | |||
redaction = yield self.persist( | |||
redaction = self.persist( | |||
type="m.room.redaction", redacts=msg.event_id, backfill=True | |||
) | |||
yield self.replicate() | |||
self.replicate() | |||
msg_dict = msg.get_dict() | |||
msg_dict["content"] = {} | |||
msg_dict["unsigned"]["redacted_by"] = redaction.event_id | |||
msg_dict["unsigned"]["redacted_because"] = redaction | |||
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) | |||
yield self.check("get_event", [msg.event_id], redacted) | |||
self.check("get_event", [msg.event_id], redacted) | |||
@defer.inlineCallbacks | |||
def test_invites(self): | |||
yield self.persist(type="m.room.create", key="", creator=USER_ID) | |||
yield self.check("get_invited_rooms_for_user", [USER_ID_2], []) | |||
event = yield self.persist( | |||
type="m.room.member", key=USER_ID_2, membership="invite" | |||
) | |||
yield self.replicate() | |||
yield self.check( | |||
self.persist(type="m.room.create", key="", creator=USER_ID) | |||
self.check("get_invited_rooms_for_user", [USER_ID_2], []) | |||
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") | |||
self.replicate() | |||
self.check( | |||
"get_invited_rooms_for_user", | |||
[USER_ID_2], | |||
[ | |||
@@ -132,37 +126,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): | |||
], | |||
) | |||
@defer.inlineCallbacks | |||
def test_push_actions_for_user(self): | |||
yield self.persist(type="m.room.create", key="", creator=USER_ID) | |||
yield self.persist(type="m.room.join", key=USER_ID, membership="join") | |||
yield self.persist( | |||
self.persist(type="m.room.create", key="", creator=USER_ID) | |||
self.persist(type="m.room.join", key=USER_ID, membership="join") | |||
self.persist( | |||
type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join" | |||
) | |||
event1 = yield self.persist( | |||
type="m.room.message", msgtype="m.text", body="hello" | |||
) | |||
yield self.replicate() | |||
yield self.check( | |||
event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello") | |||
self.replicate() | |||
self.check( | |||
"get_unread_event_push_actions_by_room_for_user", | |||
[ROOM_ID, USER_ID_2, event1.event_id], | |||
{"highlight_count": 0, "notify_count": 0}, | |||
) | |||
yield self.persist( | |||
self.persist( | |||
type="m.room.message", | |||
msgtype="m.text", | |||
body="world", | |||
push_actions=[(USER_ID_2, ["notify"])], | |||
) | |||
yield self.replicate() | |||
yield self.check( | |||
self.replicate() | |||
self.check( | |||
"get_unread_event_push_actions_by_room_for_user", | |||
[ROOM_ID, USER_ID_2, event1.event_id], | |||
{"highlight_count": 0, "notify_count": 1}, | |||
) | |||
yield self.persist( | |||
self.persist( | |||
type="m.room.message", | |||
msgtype="m.text", | |||
body="world", | |||
@@ -170,8 +161,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): | |||
(USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}]) | |||
], | |||
) | |||
yield self.replicate() | |||
yield self.check( | |||
self.replicate() | |||
self.check( | |||
"get_unread_event_push_actions_by_room_for_user", | |||
[ROOM_ID, USER_ID_2, event1.event_id], | |||
{"highlight_count": 1, "notify_count": 2}, | |||
@@ -179,7 +170,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): | |||
event_id = 0 | |||
@defer.inlineCallbacks | |||
def persist( | |||
self, | |||
sender=USER_ID, | |||
@@ -206,8 +196,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): | |||
depth = self.event_id | |||
if not prev_events: | |||
latest_event_ids = yield self.master_store.get_latest_event_ids_in_room( | |||
room_id | |||
latest_event_ids = self.get_success( | |||
self.master_store.get_latest_event_ids_in_room(room_id) | |||
) | |||
prev_events = [(ev_id, {}) for ev_id in latest_event_ids] | |||
@@ -240,19 +230,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): | |||
) | |||
else: | |||
state_handler = self.hs.get_state_handler() | |||
context = yield state_handler.compute_event_context(event) | |||
context = self.get_success(state_handler.compute_event_context(event)) | |||
yield self.master_store.add_push_actions_to_staging( | |||
self.master_store.add_push_actions_to_staging( | |||
event.event_id, {user_id: actions for user_id, actions in push_actions} | |||
) | |||
ordering = None | |||
if backfill: | |||
yield self.master_store.persist_events([(event, context)], backfilled=True) | |||
self.get_success( | |||
self.master_store.persist_events([(event, context)], backfilled=True) | |||
) | |||
else: | |||
ordering, _ = yield self.master_store.persist_event(event, context) | |||
ordering, _ = self.get_success( | |||
self.master_store.persist_event(event, context) | |||
) | |||
if ordering: | |||
event.internal_metadata.stream_ordering = ordering | |||
defer.returnValue(event) | |||
return event |
@@ -12,8 +12,6 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.internet import defer | |||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore | |||
from ._base import BaseSlavedStoreTestCase | |||
@@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): | |||
STORE_TYPE = SlavedReceiptsStore | |||
@defer.inlineCallbacks | |||
def test_receipt(self): | |||
yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) | |||
yield self.master_store.insert_receipt( | |||
ROOM_ID, "m.read", USER_ID, [EVENT_ID], {} | |||
) | |||
yield self.replicate() | |||
yield self.check( | |||
"get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID} | |||
self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) | |||
self.get_success( | |||
self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}) | |||
) | |||
self.replicate() | |||
self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}) |
@@ -232,6 +232,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): | |||
clock.threadpool = ThreadPool() | |||
pool.threadpool = ThreadPool() | |||
pool.running = True | |||
return d | |||
@@ -37,18 +37,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def setUp(self): | |||
self.as_yaml_files = [] | |||
config = Mock( | |||
app_service_config_files=self.as_yaml_files, | |||
event_cache_size=1, | |||
password_providers=[], | |||
) | |||
hs = yield setup_test_homeserver( | |||
self.addCleanup, | |||
config=config, | |||
federation_sender=Mock(), | |||
federation_client=Mock(), | |||
self.addCleanup, federation_sender=Mock(), federation_client=Mock() | |||
) | |||
hs.config.app_service_config_files = self.as_yaml_files | |||
hs.config.event_cache_size = 1 | |||
hs.config.password_providers = [] | |||
self.as_token = "token1" | |||
self.as_url = "some_url" | |||
self.as_id = "as1" | |||
@@ -58,7 +54,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): | |||
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") | |||
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") | |||
# must be done after inserts | |||
self.store = ApplicationServiceStore(None, hs) | |||
self.store = ApplicationServiceStore(hs.get_db_conn(), hs) | |||
def tearDown(self): | |||
# TODO: suboptimal that we need to create files for tests! | |||
@@ -105,18 +101,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): | |||
def setUp(self): | |||
self.as_yaml_files = [] | |||
config = Mock( | |||
app_service_config_files=self.as_yaml_files, | |||
event_cache_size=1, | |||
password_providers=[], | |||
) | |||
hs = yield setup_test_homeserver( | |||
self.addCleanup, | |||
config=config, | |||
federation_sender=Mock(), | |||
federation_client=Mock(), | |||
self.addCleanup, federation_sender=Mock(), federation_client=Mock() | |||
) | |||
hs.config.app_service_config_files = self.as_yaml_files | |||
hs.config.event_cache_size = 1 | |||
hs.config.password_providers = [] | |||
self.db_pool = hs.get_db_pool() | |||
self.engine = hs.database_engine | |||
self.as_list = [ | |||
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, | |||
@@ -129,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): | |||
self.as_yaml_files = [] | |||
self.store = TestTransactionStore(None, hs) | |||
self.store = TestTransactionStore(hs.get_db_conn(), hs) | |||
def _add_service(self, url, as_token, id): | |||
as_yaml = dict( | |||
@@ -146,29 +140,35 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): | |||
self.as_yaml_files.append(as_token) | |||
def _set_state(self, id, state, txn=None): | |||
return self.db_pool.runQuery( | |||
"INSERT INTO application_services_state(as_id, state, last_txn) " | |||
"VALUES(?,?,?)", | |||
return self.db_pool.runOperation( | |||
self.engine.convert_param_style( | |||
"INSERT INTO application_services_state(as_id, state, last_txn) " | |||
"VALUES(?,?,?)" | |||
), | |||
(id, state, txn), | |||
) | |||
def _insert_txn(self, as_id, txn_id, events): | |||
return self.db_pool.runQuery( | |||
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) " | |||
"VALUES(?,?,?)", | |||
return self.db_pool.runOperation( | |||
self.engine.convert_param_style( | |||
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) " | |||
"VALUES(?,?,?)" | |||
), | |||
(as_id, txn_id, json.dumps([e.event_id for e in events])), | |||
) | |||
def _set_last_txn(self, as_id, txn_id): | |||
return self.db_pool.runQuery( | |||
"INSERT INTO application_services_state(as_id, last_txn, state) " | |||
"VALUES(?,?,?)", | |||
return self.db_pool.runOperation( | |||
self.engine.convert_param_style( | |||
"INSERT INTO application_services_state(as_id, last_txn, state) " | |||
"VALUES(?,?,?)" | |||
), | |||
(as_id, txn_id, ApplicationServiceState.UP), | |||
) | |||
@defer.inlineCallbacks | |||
def test_get_appservice_state_none(self): | |||
service = Mock(id=999) | |||
service = Mock(id="999") | |||
state = yield self.store.get_appservice_state(service) | |||
self.assertEquals(None, state) | |||
@@ -200,7 +200,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): | |||
service = Mock(id=self.as_list[1]["id"]) | |||
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) | |||
rows = yield self.db_pool.runQuery( | |||
"SELECT as_id FROM application_services_state WHERE state=?", | |||
self.engine.convert_param_style( | |||
"SELECT as_id FROM application_services_state WHERE state=?" | |||
), | |||
(ApplicationServiceState.DOWN,), | |||
) | |||
self.assertEquals(service.id, rows[0][0]) | |||
@@ -212,7 +214,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): | |||
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) | |||
yield self.store.set_appservice_state(service, ApplicationServiceState.UP) | |||
rows = yield self.db_pool.runQuery( | |||
"SELECT as_id FROM application_services_state WHERE state=?", | |||
self.engine.convert_param_style( | |||
"SELECT as_id FROM application_services_state WHERE state=?" | |||
), | |||
(ApplicationServiceState.UP,), | |||
) | |||
self.assertEquals(service.id, rows[0][0]) | |||
@@ -279,14 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): | |||
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) | |||
res = yield self.db_pool.runQuery( | |||
"SELECT last_txn FROM application_services_state WHERE as_id=?", | |||
self.engine.convert_param_style( | |||
"SELECT last_txn FROM application_services_state WHERE as_id=?" | |||
), | |||
(service.id,), | |||
) | |||
self.assertEquals(1, len(res)) | |||
self.assertEquals(txn_id, res[0][0]) | |||
res = yield self.db_pool.runQuery( | |||
"SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,) | |||
self.engine.convert_param_style( | |||
"SELECT * FROM application_services_txns WHERE txn_id=?" | |||
), | |||
(txn_id,), | |||
) | |||
self.assertEquals(0, len(res)) | |||
@@ -300,7 +309,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): | |||
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) | |||
res = yield self.db_pool.runQuery( | |||
"SELECT last_txn, state FROM application_services_state WHERE " "as_id=?", | |||
self.engine.convert_param_style( | |||
"SELECT last_txn, state FROM application_services_state WHERE as_id=?" | |||
), | |||
(service.id,), | |||
) | |||
self.assertEquals(1, len(res)) | |||
@@ -308,7 +319,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): | |||
self.assertEquals(ApplicationServiceState.UP, res[0][1]) | |||
res = yield self.db_pool.runQuery( | |||
"SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,) | |||
self.engine.convert_param_style( | |||
"SELECT * FROM application_services_txns WHERE txn_id=?" | |||
), | |||
(txn_id,), | |||
) | |||
self.assertEquals(0, len(res)) | |||
@@ -394,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): | |||
f1 = self._write_config(suffix="1") | |||
f2 = self._write_config(suffix="2") | |||
config = Mock( | |||
app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] | |||
) | |||
hs = yield setup_test_homeserver( | |||
self.addCleanup, | |||
config=config, | |||
datastore=Mock(), | |||
federation_sender=Mock(), | |||
federation_client=Mock(), | |||
self.addCleanup, federation_sender=Mock(), federation_client=Mock() | |||
) | |||
ApplicationServiceStore(None, hs) | |||
hs.config.app_service_config_files = [f1, f2] | |||
hs.config.event_cache_size = 1 | |||
hs.config.password_providers = [] | |||
ApplicationServiceStore(hs.get_db_conn(), hs) | |||
@defer.inlineCallbacks | |||
def test_duplicate_ids(self): | |||
f1 = self._write_config(id="id", suffix="1") | |||
f2 = self._write_config(id="id", suffix="2") | |||
config = Mock( | |||
app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] | |||
) | |||
hs = yield setup_test_homeserver( | |||
self.addCleanup, | |||
config=config, | |||
datastore=Mock(), | |||
federation_sender=Mock(), | |||
federation_client=Mock(), | |||
self.addCleanup, federation_sender=Mock(), federation_client=Mock() | |||
) | |||
hs.config.app_service_config_files = [f1, f2] | |||
hs.config.event_cache_size = 1 | |||
hs.config.password_providers = [] | |||
with self.assertRaises(ConfigError) as cm: | |||
ApplicationServiceStore(None, hs) | |||
ApplicationServiceStore(hs.get_db_conn(), hs) | |||
e = cm.exception | |||
self.assertIn(f1, str(e)) | |||
@@ -436,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): | |||
f1 = self._write_config(as_token="as_token", suffix="1") | |||
f2 = self._write_config(as_token="as_token", suffix="2") | |||
config = Mock( | |||
app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] | |||
) | |||
hs = yield setup_test_homeserver( | |||
self.addCleanup, | |||
config=config, | |||
datastore=Mock(), | |||
federation_sender=Mock(), | |||
federation_client=Mock(), | |||
self.addCleanup, federation_sender=Mock(), federation_client=Mock() | |||
) | |||
hs.config.app_service_config_files = [f1, f2] | |||
hs.config.event_cache_size = 1 | |||
hs.config.password_providers = [] | |||
with self.assertRaises(ConfigError) as cm: | |||
ApplicationServiceStore(None, hs) | |||
ApplicationServiceStore(hs.get_db_conn(), hs) | |||
e = cm.exception | |||
self.assertIn(f1, str(e)) | |||
@@ -16,7 +16,6 @@ | |||
from twisted.internet import defer | |||
from synapse.storage.directory import DirectoryStore | |||
from synapse.types import RoomAlias, RoomID | |||
from tests import unittest | |||
@@ -28,7 +27,7 @@ class DirectoryStoreTestCase(unittest.TestCase): | |||
def setUp(self): | |||
hs = yield setup_test_homeserver(self.addCleanup) | |||
self.store = DirectoryStore(None, hs) | |||
self.store = hs.get_datastore() | |||
self.room = RoomID.from_string("!abcde:test") | |||
self.alias = RoomAlias.from_string("#my-room:test") | |||
@@ -37,10 +37,10 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): | |||
( | |||
"INSERT INTO events (" | |||
" room_id, event_id, type, depth, topological_ordering," | |||
" content, processed, outlier) " | |||
"VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)" | |||
" content, processed, outlier, stream_ordering) " | |||
"VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)" | |||
), | |||
(room_id, event_id, i, i, True, False), | |||
(room_id, event_id, i, i, True, False, i), | |||
) | |||
txn.execute( | |||
@@ -13,25 +13,22 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.internet import defer | |||
import tests.unittest | |||
import tests.utils | |||
from tests.utils import setup_test_homeserver | |||
from tests.unittest import HomeserverTestCase | |||
FORTY_DAYS = 40 * 24 * 60 * 60 | |||
class MonthlyActiveUsersTestCase(tests.unittest.TestCase): | |||
def __init__(self, *args, **kwargs): | |||
super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs) | |||
class MonthlyActiveUsersTestCase(HomeserverTestCase): | |||
def make_homeserver(self, reactor, clock): | |||
hs = self.setup_test_homeserver() | |||
self.store = hs.get_datastore() | |||
# Advance the clock a bit | |||
reactor.advance(FORTY_DAYS) | |||
@defer.inlineCallbacks | |||
def setUp(self): | |||
self.hs = yield setup_test_homeserver(self.addCleanup) | |||
self.store = self.hs.get_datastore() | |||
return hs | |||
@defer.inlineCallbacks | |||
def test_initialise_reserved_users(self): | |||
self.hs.config.max_mau_value = 5 | |||
user1 = "@user1:server" | |||
@@ -44,88 +41,101 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): | |||
] | |||
user_num = len(threepids) | |||
yield self.store.register(user_id=user1, token="123", password_hash=None) | |||
yield self.store.register(user_id=user2, token="456", password_hash=None) | |||
self.store.register(user_id=user1, token="123", password_hash=None) | |||
self.store.register(user_id=user2, token="456", password_hash=None) | |||
self.pump() | |||
now = int(self.hs.get_clock().time_msec()) | |||
yield self.store.user_add_threepid(user1, "email", user1_email, now, now) | |||
yield self.store.user_add_threepid(user2, "email", user2_email, now, now) | |||
yield self.store.initialise_reserved_users(threepids) | |||
self.store.user_add_threepid(user1, "email", user1_email, now, now) | |||
self.store.user_add_threepid(user2, "email", user2_email, now, now) | |||
self.store.initialise_reserved_users(threepids) | |||
self.pump() | |||
active_count = yield self.store.get_monthly_active_count() | |||
active_count = self.store.get_monthly_active_count() | |||
# Test total counts | |||
self.assertEquals(active_count, user_num) | |||
self.assertEquals(self.get_success(active_count), user_num) | |||
# Test user is marked as active | |||
timestamp = yield self.store.user_last_seen_monthly_active(user1) | |||
self.assertTrue(timestamp) | |||
timestamp = yield self.store.user_last_seen_monthly_active(user2) | |||
self.assertTrue(timestamp) | |||
timestamp = self.store.user_last_seen_monthly_active(user1) | |||
self.assertTrue(self.get_success(timestamp)) | |||
timestamp = self.store.user_last_seen_monthly_active(user2) | |||
self.assertTrue(self.get_success(timestamp)) | |||
# Test that users are never removed from the db. | |||
self.hs.config.max_mau_value = 0 | |||
self.hs.get_clock().advance_time(FORTY_DAYS) | |||
self.reactor.advance(FORTY_DAYS) | |||
yield self.store.reap_monthly_active_users() | |||
self.store.reap_monthly_active_users() | |||
self.pump() | |||
active_count = yield self.store.get_monthly_active_count() | |||
self.assertEquals(active_count, user_num) | |||
active_count = self.store.get_monthly_active_count() | |||
self.assertEquals(self.get_success(active_count), user_num) | |||
# Test that regalar users are removed from the db | |||
ru_count = 2 | |||
yield self.store.upsert_monthly_active_user("@ru1:server") | |||
yield self.store.upsert_monthly_active_user("@ru2:server") | |||
active_count = yield self.store.get_monthly_active_count() | |||
self.store.upsert_monthly_active_user("@ru1:server") | |||
self.store.upsert_monthly_active_user("@ru2:server") | |||
self.pump() | |||
self.assertEqual(active_count, user_num + ru_count) | |||
active_count = self.store.get_monthly_active_count() | |||
self.assertEqual(self.get_success(active_count), user_num + ru_count) | |||
self.hs.config.max_mau_value = user_num | |||
yield self.store.reap_monthly_active_users() | |||
self.store.reap_monthly_active_users() | |||
self.pump() | |||
active_count = yield self.store.get_monthly_active_count() | |||
self.assertEquals(active_count, user_num) | |||
active_count = self.store.get_monthly_active_count() | |||
self.assertEquals(self.get_success(active_count), user_num) | |||
@defer.inlineCallbacks | |||
def test_can_insert_and_count_mau(self): | |||
count = yield self.store.get_monthly_active_count() | |||
self.assertEqual(0, count) | |||
count = self.store.get_monthly_active_count() | |||
self.assertEqual(0, self.get_success(count)) | |||
yield self.store.upsert_monthly_active_user("@user:server") | |||
count = yield self.store.get_monthly_active_count() | |||
self.store.upsert_monthly_active_user("@user:server") | |||
self.pump() | |||
self.assertEqual(1, count) | |||
count = self.store.get_monthly_active_count() | |||
self.assertEqual(1, self.get_success(count)) | |||
@defer.inlineCallbacks | |||
def test_user_last_seen_monthly_active(self): | |||
user_id1 = "@user1:server" | |||
user_id2 = "@user2:server" | |||
user_id3 = "@user3:server" | |||
result = yield self.store.user_last_seen_monthly_active(user_id1) | |||
self.assertFalse(result == 0) | |||
yield self.store.upsert_monthly_active_user(user_id1) | |||
yield self.store.upsert_monthly_active_user(user_id2) | |||
result = yield self.store.user_last_seen_monthly_active(user_id1) | |||
self.assertTrue(result > 0) | |||
result = yield self.store.user_last_seen_monthly_active(user_id3) | |||
self.assertFalse(result == 0) | |||
result = self.store.user_last_seen_monthly_active(user_id1) | |||
self.assertFalse(self.get_success(result) == 0) | |||
self.store.upsert_monthly_active_user(user_id1) | |||
self.store.upsert_monthly_active_user(user_id2) | |||
self.pump() | |||
result = self.store.user_last_seen_monthly_active(user_id1) | |||
self.assertGreater(self.get_success(result), 0) | |||
result = self.store.user_last_seen_monthly_active(user_id3) | |||
self.assertNotEqual(self.get_success(result), 0) | |||
@defer.inlineCallbacks | |||
def test_reap_monthly_active_users(self): | |||
self.hs.config.max_mau_value = 5 | |||
initial_users = 10 | |||
for i in range(initial_users): | |||
yield self.store.upsert_monthly_active_user("@user%d:server" % i) | |||
count = yield self.store.get_monthly_active_count() | |||
self.assertTrue(count, initial_users) | |||
yield self.store.reap_monthly_active_users() | |||
count = yield self.store.get_monthly_active_count() | |||
self.assertEquals(count, initial_users - self.hs.config.max_mau_value) | |||
self.hs.get_clock().advance_time(FORTY_DAYS) | |||
yield self.store.reap_monthly_active_users() | |||
count = yield self.store.get_monthly_active_count() | |||
self.assertEquals(count, 0) | |||
self.store.upsert_monthly_active_user("@user%d:server" % i) | |||
self.pump() | |||
count = self.store.get_monthly_active_count() | |||
self.assertTrue(self.get_success(count), initial_users) | |||
self.store.reap_monthly_active_users() | |||
self.pump() | |||
count = self.store.get_monthly_active_count() | |||
self.assertEquals( | |||
self.get_success(count), initial_users - self.hs.config.max_mau_value | |||
) | |||
self.reactor.advance(FORTY_DAYS) | |||
self.store.reap_monthly_active_users() | |||
self.pump() | |||
count = self.store.get_monthly_active_count() | |||
self.assertEquals(self.get_success(count), 0) |
@@ -16,19 +16,18 @@ | |||
from twisted.internet import defer | |||
from synapse.storage.presence import PresenceStore | |||
from synapse.types import UserID | |||
from tests import unittest | |||
from tests.utils import MockClock, setup_test_homeserver | |||
from tests.utils import setup_test_homeserver | |||
class PresenceStoreTestCase(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def setUp(self): | |||
hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock()) | |||
hs = yield setup_test_homeserver(self.addCleanup) | |||
self.store = PresenceStore(None, hs) | |||
self.store = hs.get_datastore() | |||
self.u_apple = UserID.from_string("@apple:test") | |||
self.u_banana = UserID.from_string("@banana:test") | |||
@@ -28,7 +28,7 @@ class ProfileStoreTestCase(unittest.TestCase): | |||
def setUp(self): | |||
hs = yield setup_test_homeserver(self.addCleanup) | |||
self.store = ProfileStore(None, hs) | |||
self.store = ProfileStore(hs.get_db_conn(), hs) | |||
self.u_frank = UserID.from_string("@frank:test") | |||
@@ -30,7 +30,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def setUp(self): | |||
self.hs = yield setup_test_homeserver(self.addCleanup) | |||
self.store = UserDirectoryStore(None, self.hs) | |||
self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs) | |||
# alice and bob are both in !room_id. bobby is not but shares | |||
# a homeserver with alice. | |||
@@ -96,7 +96,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): | |||
events_to_filter.append(evt) | |||
# the erasey user gets erased | |||
self.hs.get_datastore().mark_user_erased("@erased:local_hs") | |||
yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") | |||
# ... and the filtering happens. | |||
filtered = yield filter_events_for_server( | |||
@@ -22,6 +22,7 @@ from canonicaljson import json | |||
import twisted | |||
import twisted.logger | |||
from twisted.internet.defer import Deferred | |||
from twisted.trial import unittest | |||
from synapse.http.server import JsonResource | |||
@@ -281,12 +282,14 @@ class HomeserverTestCase(TestCase): | |||
kwargs.update(self._hs_args) | |||
return setup_test_homeserver(self.addCleanup, *args, **kwargs) | |||
def pump(self): | |||
def pump(self, by=0.0): | |||
""" | |||
Pump the reactor enough that Deferreds will fire. | |||
""" | |||
self.reactor.pump([0.0] * 100) | |||
self.reactor.pump([by] * 100) | |||
def get_success(self, d): | |||
if not isinstance(d, Deferred): | |||
return d | |||
self.pump() | |||
return self.successResultOf(d) |
@@ -30,8 +30,8 @@ from synapse.config.server import ServerConfig | |||
from synapse.federation.transport import server | |||
from synapse.http.server import HttpServer | |||
from synapse.server import HomeServer | |||
from synapse.storage import DataStore, PostgresEngine | |||
from synapse.storage.engines import create_engine | |||
from synapse.storage import DataStore | |||
from synapse.storage.engines import PostgresEngine, create_engine | |||
from synapse.storage.prepare_database import ( | |||
_get_or_create_schema_state, | |||
_setup_new_database, | |||
@@ -42,6 +42,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter | |||
# set this to True to run the tests against postgres instead of sqlite. | |||
USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) | |||
LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False) | |||
POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres") | |||
POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) | |||
@@ -244,8 +245,9 @@ def setup_test_homeserver( | |||
cur.close() | |||
db_conn.close() | |||
# Register the cleanup hook | |||
cleanup_func(cleanup) | |||
if not LEAVE_DB: | |||
# Register the cleanup hook | |||
cleanup_func(cleanup) | |||
hs.setup() | |||
else: | |||