@@ -0,0 +1 @@ | |||
Rename storage layer objects to be more sensible. |
@@ -7,6 +7,6 @@ who are present in a publicly viewable room present on the server. | |||
The directory info is stored in various tables, which can (typically after | |||
DB corruption) get stale or out of sync. If this happens, for now the | |||
solution to fix it is to execute the SQL [here](../synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql) | |||
solution to fix it is to execute the SQL [here](../synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql) | |||
and then restart synapse. This should then start a background task to | |||
flush the current tables and regenerate the directory. |
@@ -40,7 +40,7 @@ class MockHomeserver(HomeServer): | |||
config.server_name, reactor=reactor, config=config, **kwargs | |||
) | |||
self.version_string = "Synapse/"+get_version_string(synapse) | |||
self.version_string = "Synapse/" + get_version_string(synapse) | |||
if __name__ == "__main__": | |||
@@ -86,7 +86,7 @@ if __name__ == "__main__": | |||
store = hs.get_datastore() | |||
async def run_background_updates(): | |||
await store.db.updates.run_background_updates(sleep=False) | |||
await store.db_pool.updates.run_background_updates(sleep=False) | |||
# Stop the reactor to exit the script once every background update is run. | |||
reactor.stop() | |||
@@ -35,31 +35,29 @@ from synapse.logging.context import ( | |||
make_deferred_yieldable, | |||
run_in_background, | |||
) | |||
from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore | |||
from synapse.storage.data_stores.main.deviceinbox import ( | |||
DeviceInboxBackgroundUpdateStore, | |||
) | |||
from synapse.storage.data_stores.main.devices import DeviceBackgroundUpdateStore | |||
from synapse.storage.data_stores.main.events_bg_updates import ( | |||
from synapse.storage.database import DatabasePool, make_conn | |||
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore | |||
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore | |||
from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore | |||
from synapse.storage.databases.main.events_bg_updates import ( | |||
EventsBackgroundUpdatesStore, | |||
) | |||
from synapse.storage.data_stores.main.media_repository import ( | |||
from synapse.storage.databases.main.media_repository import ( | |||
MediaRepositoryBackgroundUpdateStore, | |||
) | |||
from synapse.storage.data_stores.main.registration import ( | |||
from synapse.storage.databases.main.registration import ( | |||
RegistrationBackgroundUpdateStore, | |||
find_max_generated_user_id_localpart, | |||
) | |||
from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore | |||
from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore | |||
from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore | |||
from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore | |||
from synapse.storage.data_stores.main.stats import StatsStore | |||
from synapse.storage.data_stores.main.user_directory import ( | |||
from synapse.storage.databases.main.room import RoomBackgroundUpdateStore | |||
from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore | |||
from synapse.storage.databases.main.search import SearchBackgroundUpdateStore | |||
from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore | |||
from synapse.storage.databases.main.stats import StatsStore | |||
from synapse.storage.databases.main.user_directory import ( | |||
UserDirectoryBackgroundUpdateStore, | |||
) | |||
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore | |||
from synapse.storage.database import Database, make_conn | |||
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore | |||
from synapse.storage.engines import create_engine | |||
from synapse.storage.prepare_database import prepare_database | |||
from synapse.util import Clock | |||
@@ -175,14 +173,14 @@ class Store( | |||
StatsStore, | |||
): | |||
def execute(self, f, *args, **kwargs): | |||
return self.db.runInteraction(f.__name__, f, *args, **kwargs) | |||
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) | |||
def execute_sql(self, sql, *args): | |||
def r(txn): | |||
txn.execute(sql, args) | |||
return txn.fetchall() | |||
return self.db.runInteraction("execute_sql", r) | |||
return self.db_pool.runInteraction("execute_sql", r) | |||
def insert_many_txn(self, txn, table, headers, rows): | |||
sql = "INSERT INTO %s (%s) VALUES (%s)" % ( | |||
@@ -227,7 +225,7 @@ class Porter(object): | |||
async def setup_table(self, table): | |||
if table in APPEND_ONLY_TABLES: | |||
# It's safe to just carry on inserting. | |||
row = await self.postgres_store.db.simple_select_one( | |||
row = await self.postgres_store.db_pool.simple_select_one( | |||
table="port_from_sqlite3", | |||
keyvalues={"table_name": table}, | |||
retcols=("forward_rowid", "backward_rowid"), | |||
@@ -244,7 +242,7 @@ class Porter(object): | |||
) = await self._setup_sent_transactions() | |||
backward_chunk = 0 | |||
else: | |||
await self.postgres_store.db.simple_insert( | |||
await self.postgres_store.db_pool.simple_insert( | |||
table="port_from_sqlite3", | |||
values={ | |||
"table_name": table, | |||
@@ -274,7 +272,7 @@ class Porter(object): | |||
await self.postgres_store.execute(delete_all) | |||
await self.postgres_store.db.simple_insert( | |||
await self.postgres_store.db_pool.simple_insert( | |||
table="port_from_sqlite3", | |||
values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, | |||
) | |||
@@ -318,7 +316,7 @@ class Porter(object): | |||
if table == "user_directory_stream_pos": | |||
# We need to make sure there is a single row, `(X, null), as that is | |||
# what synapse expects to be there. | |||
await self.postgres_store.db.simple_insert( | |||
await self.postgres_store.db_pool.simple_insert( | |||
table=table, values={"stream_id": None} | |||
) | |||
self.progress.update(table, table_size) # Mark table as done | |||
@@ -359,7 +357,7 @@ class Porter(object): | |||
return headers, forward_rows, backward_rows | |||
headers, frows, brows = await self.sqlite_store.db.runInteraction( | |||
headers, frows, brows = await self.sqlite_store.db_pool.runInteraction( | |||
"select", r | |||
) | |||
@@ -375,7 +373,7 @@ class Porter(object): | |||
def insert(txn): | |||
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) | |||
self.postgres_store.db.simple_update_one_txn( | |||
self.postgres_store.db_pool.simple_update_one_txn( | |||
txn, | |||
table="port_from_sqlite3", | |||
keyvalues={"table_name": table}, | |||
@@ -413,7 +411,7 @@ class Porter(object): | |||
return headers, rows | |||
headers, rows = await self.sqlite_store.db.runInteraction("select", r) | |||
headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r) | |||
if rows: | |||
forward_chunk = rows[-1][0] + 1 | |||
@@ -451,7 +449,7 @@ class Porter(object): | |||
], | |||
) | |||
self.postgres_store.db.simple_update_one_txn( | |||
self.postgres_store.db_pool.simple_update_one_txn( | |||
txn, | |||
table="port_from_sqlite3", | |||
keyvalues={"table_name": "event_search"}, | |||
@@ -494,7 +492,7 @@ class Porter(object): | |||
db_conn, allow_outdated_version=allow_outdated_version | |||
) | |||
prepare_database(db_conn, engine, config=self.hs_config) | |||
store = Store(Database(hs, db_config, engine), db_conn, hs) | |||
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) | |||
db_conn.commit() | |||
return store | |||
@@ -502,7 +500,7 @@ class Porter(object): | |||
async def run_background_updates_on_postgres(self): | |||
# Manually apply all background updates on the PostgreSQL database. | |||
postgres_ready = ( | |||
await self.postgres_store.db.updates.has_completed_background_updates() | |||
await self.postgres_store.db_pool.updates.has_completed_background_updates() | |||
) | |||
if not postgres_ready: | |||
@@ -511,9 +509,9 @@ class Porter(object): | |||
self.progress.set_state("Running background updates on PostgreSQL") | |||
while not postgres_ready: | |||
await self.postgres_store.db.updates.do_next_background_update(100) | |||
await self.postgres_store.db_pool.updates.do_next_background_update(100) | |||
postgres_ready = await ( | |||
self.postgres_store.db.updates.has_completed_background_updates() | |||
self.postgres_store.db_pool.updates.has_completed_background_updates() | |||
) | |||
async def run(self): | |||
@@ -534,7 +532,7 @@ class Porter(object): | |||
# Check if all background updates are done, abort if not. | |||
updates_complete = ( | |||
await self.sqlite_store.db.updates.has_completed_background_updates() | |||
await self.sqlite_store.db_pool.updates.has_completed_background_updates() | |||
) | |||
if not updates_complete: | |||
end_error = ( | |||
@@ -576,22 +574,24 @@ class Porter(object): | |||
) | |||
try: | |||
await self.postgres_store.db.runInteraction("alter_table", alter_table) | |||
await self.postgres_store.db_pool.runInteraction( | |||
"alter_table", alter_table | |||
) | |||
except Exception: | |||
# On Error Resume Next | |||
pass | |||
await self.postgres_store.db.runInteraction( | |||
await self.postgres_store.db_pool.runInteraction( | |||
"create_port_table", create_port_table | |||
) | |||
# Step 2. Get tables. | |||
self.progress.set_state("Fetching tables") | |||
sqlite_tables = await self.sqlite_store.db.simple_select_onecol( | |||
sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol( | |||
table="sqlite_master", keyvalues={"type": "table"}, retcol="name" | |||
) | |||
postgres_tables = await self.postgres_store.db.simple_select_onecol( | |||
postgres_tables = await self.postgres_store.db_pool.simple_select_onecol( | |||
table="information_schema.tables", | |||
keyvalues={}, | |||
retcol="distinct table_name", | |||
@@ -692,7 +692,7 @@ class Porter(object): | |||
return headers, [r for r in rows if r[ts_ind] < yesterday] | |||
headers, rows = await self.sqlite_store.db.runInteraction("select", r) | |||
headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r) | |||
rows = self._convert_rows("sent_transactions", headers, rows) | |||
@@ -725,7 +725,7 @@ class Porter(object): | |||
next_chunk = await self.sqlite_store.execute(get_start_id) | |||
next_chunk = max(max_inserted_rowid + 1, next_chunk) | |||
await self.postgres_store.db.simple_insert( | |||
await self.postgres_store.db_pool.simple_insert( | |||
table="port_from_sqlite3", | |||
values={ | |||
"table_name": "sent_transactions", | |||
@@ -794,14 +794,14 @@ class Porter(object): | |||
next_id = curr_id + 1 | |||
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) | |||
return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) | |||
return self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r) | |||
def _setup_user_id_seq(self): | |||
def r(txn): | |||
next_id = find_max_generated_user_id_localpart(txn) + 1 | |||
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) | |||
return self.postgres_store.db.runInteraction("setup_user_id_seq", r) | |||
return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) | |||
############################################## | |||
@@ -268,7 +268,7 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): | |||
# It is now safe to start your Synapse. | |||
hs.start_listening(listeners) | |||
hs.get_datastore().db.start_profiling() | |||
hs.get_datastore().db_pool.start_profiling() | |||
hs.get_pusherpool().start() | |||
setup_sentry(hs) | |||
@@ -125,15 +125,15 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet | |||
from synapse.rest.client.versions import VersionsRestServlet | |||
from synapse.rest.key.v2 import KeyApiV2Resource | |||
from synapse.server import HomeServer | |||
from synapse.storage.data_stores.main.censor_events import CensorEventsStore | |||
from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore | |||
from synapse.storage.data_stores.main.monthly_active_users import ( | |||
from synapse.storage.databases.main.censor_events import CensorEventsStore | |||
from synapse.storage.databases.main.media_repository import MediaRepositoryStore | |||
from synapse.storage.databases.main.monthly_active_users import ( | |||
MonthlyActiveUsersWorkerStore, | |||
) | |||
from synapse.storage.data_stores.main.presence import UserPresenceState | |||
from synapse.storage.data_stores.main.search import SearchWorkerStore | |||
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore | |||
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore | |||
from synapse.storage.databases.main.presence import UserPresenceState | |||
from synapse.storage.databases.main.search import SearchWorkerStore | |||
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore | |||
from synapse.storage.databases.main.user_directory import UserDirectoryStore | |||
from synapse.types import ReadReceipt | |||
from synapse.util.async_helpers import Linearizer | |||
from synapse.util.httpresourcetree import create_resource_tree | |||
@@ -441,7 +441,7 @@ def setup(config_options): | |||
_base.start(hs, config.listeners) | |||
hs.get_datastore().db.updates.start_doing_background_updates() | |||
hs.get_datastore().db_pool.updates.start_doing_background_updates() | |||
except Exception: | |||
# Print the exception and bail out. | |||
print("Error during startup:", file=sys.stderr) | |||
@@ -551,8 +551,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): | |||
# | |||
# This only reports info about the *main* database. | |||
stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ | |||
stats["database_server_version"] = hs.get_datastore().db.engine.server_version | |||
stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__ | |||
stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version | |||
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) | |||
try: | |||
@@ -100,7 +100,10 @@ class DatabaseConnectionConfig: | |||
self.name = name | |||
self.config = db_config | |||
self.data_stores = data_stores | |||
# The `data_stores` config is actually talking about `databases` (we | |||
# changed the name). | |||
self.databases = data_stores | |||
class DatabaseConfig(Config): | |||
@@ -23,7 +23,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background | |||
from synapse.types import StateMap | |||
if TYPE_CHECKING: | |||
from synapse.storage.data_stores.main import DataStore | |||
from synapse.storage.databases.main import DataStore | |||
@attr.s(slots=True) | |||
@@ -71,7 +71,7 @@ from synapse.replication.http.federation import ( | |||
) | |||
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet | |||
from synapse.state import StateResolutionStore, resolve_events_with_store | |||
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour | |||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour | |||
from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id | |||
from synapse.util.async_helpers import Linearizer, concurrently_execute | |||
from synapse.util.distributor import user_joined_room | |||
@@ -45,7 +45,7 @@ from synapse.events.validator import EventValidator | |||
from synapse.logging.context import run_in_background | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet | |||
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour | |||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour | |||
from synapse.storage.state import StateFilter | |||
from synapse.types import ( | |||
Collection, | |||
@@ -38,7 +38,7 @@ from synapse.logging.utils import log_function | |||
from synapse.metrics import LaterGauge | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.state import StateHandler | |||
from synapse.storage.data_stores.main import DataStore | |||
from synapse.storage.databases.main import DataStore | |||
from synapse.storage.presence import UserPresenceState | |||
from synapse.types import JsonDict, UserID, get_domain_from_id | |||
from synapse.util.async_helpers import Linearizer | |||
@@ -319,7 +319,7 @@ class PresenceHandler(BasePresenceHandler): | |||
is some spurious presence changes that will self-correct. | |||
""" | |||
# If the DB pool has already terminated, don't try updating | |||
if not self.store.db.is_running(): | |||
if not self.store.db_pool.is_running(): | |||
return | |||
logger.info( | |||
@@ -219,7 +219,7 @@ class ModuleApi(object): | |||
Returns: | |||
Deferred[object]: result of func | |||
""" | |||
return self._store.db.runInteraction(desc, func, *args, **kwargs) | |||
return self._store.db_pool.runInteraction(desc, func, *args, **kwargs) | |||
def complete_sso_login( | |||
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str | |||
@@ -16,8 +16,8 @@ | |||
import logging | |||
from typing import Optional | |||
from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore | |||
from synapse.storage.engines import PostgresEngine | |||
from synapse.storage.util.id_generators import MultiWriterIdGenerator | |||
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) | |||
class BaseSlavedStore(CacheInvalidationWorkerStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(BaseSlavedStore, self).__init__(database, db_conn, hs) | |||
if isinstance(self.database_engine, PostgresEngine): | |||
self._cache_id_gen = MultiWriterIdGenerator( | |||
@@ -17,13 +17,13 @@ | |||
from synapse.replication.slave.storage._base import BaseSlavedStore | |||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |||
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream | |||
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore | |||
from synapse.storage.data_stores.main.tags import TagsWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.account_data import AccountDataWorkerStore | |||
from synapse.storage.databases.main.tags import TagsWorkerStore | |||
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
self._account_data_id_gen = SlavedIdTracker( | |||
db_conn, | |||
"account_data", | |||
@@ -14,7 +14,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.storage.data_stores.main.appservice import ( | |||
from synapse.storage.databases.main.appservice import ( | |||
ApplicationServiceTransactionWorkerStore, | |||
ApplicationServiceWorkerStore, | |||
) | |||
@@ -13,15 +13,15 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY | |||
from synapse.util.caches.descriptors import Cache | |||
from ._base import BaseSlavedStore | |||
class SlavedClientIpStore(BaseSlavedStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(SlavedClientIpStore, self).__init__(database, db_conn, hs) | |||
self.client_ip_last_seen = Cache( | |||
@@ -16,14 +16,14 @@ | |||
from synapse.replication.slave.storage._base import BaseSlavedStore | |||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |||
from synapse.replication.tcp.streams import ToDeviceStream | |||
from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore | |||
from synapse.util.caches.expiringcache import ExpiringCache | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) | |||
self._device_inbox_id_gen = SlavedIdTracker( | |||
db_conn, "device_inbox", "stream_id" | |||
@@ -16,14 +16,14 @@ | |||
from synapse.replication.slave.storage._base import BaseSlavedStore | |||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |||
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream | |||
from synapse.storage.data_stores.main.devices import DeviceWorkerStore | |||
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.devices import DeviceWorkerStore | |||
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(SlavedDeviceStore, self).__init__(database, db_conn, hs) | |||
self.hs = hs | |||
@@ -13,7 +13,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.storage.data_stores.main.directory import DirectoryWorkerStore | |||
from synapse.storage.databases.main.directory import DirectoryWorkerStore | |||
from ._base import BaseSlavedStore | |||
@@ -15,18 +15,18 @@ | |||
# limitations under the License. | |||
import logging | |||
from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore | |||
from synapse.storage.data_stores.main.event_push_actions import ( | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore | |||
from synapse.storage.databases.main.event_push_actions import ( | |||
EventPushActionsWorkerStore, | |||
) | |||
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore | |||
from synapse.storage.data_stores.main.relations import RelationsWorkerStore | |||
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore | |||
from synapse.storage.data_stores.main.signatures import SignatureWorkerStore | |||
from synapse.storage.data_stores.main.state import StateGroupWorkerStore | |||
from synapse.storage.data_stores.main.stream import StreamWorkerStore | |||
from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.databases.main.events_worker import EventsWorkerStore | |||
from synapse.storage.databases.main.relations import RelationsWorkerStore | |||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore | |||
from synapse.storage.databases.main.signatures import SignatureWorkerStore | |||
from synapse.storage.databases.main.state import StateGroupWorkerStore | |||
from synapse.storage.databases.main.stream import StreamWorkerStore | |||
from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
from ._base import BaseSlavedStore | |||
@@ -55,11 +55,11 @@ class SlavedEventStore( | |||
RelationsWorkerStore, | |||
BaseSlavedStore, | |||
): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(SlavedEventStore, self).__init__(database, db_conn, hs) | |||
events_max = self._stream_id_gen.get_current_token() | |||
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( | |||
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( | |||
db_conn, | |||
"current_state_delta_stream", | |||
entity_column="room_id", | |||
@@ -13,14 +13,14 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.storage.data_stores.main.filtering import FilteringStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.filtering import FilteringStore | |||
from ._base import BaseSlavedStore | |||
class SlavedFilteringStore(BaseSlavedStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(SlavedFilteringStore, self).__init__(database, db_conn, hs) | |||
# Filters are immutable so this cache doesn't need to be expired | |||
@@ -16,13 +16,13 @@ | |||
from synapse.replication.slave.storage._base import BaseSlavedStore | |||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |||
from synapse.replication.tcp.streams import GroupServerStream | |||
from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.group_server import GroupServerWorkerStore | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) | |||
self.hs = hs | |||
@@ -13,7 +13,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.storage.data_stores.main.keys import KeyStore | |||
from synapse.storage.databases.main.keys import KeyStore | |||
# KeyStore isn't really safe to use from a worker, but for now we do so and hope that | |||
# the races it creates aren't too bad. | |||
@@ -15,8 +15,8 @@ | |||
from synapse.replication.tcp.streams import PresenceStream | |||
from synapse.storage import DataStore | |||
from synapse.storage.data_stores.main.presence import PresenceStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.presence import PresenceStore | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
from ._base import BaseSlavedStore | |||
@@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker | |||
class SlavedPresenceStore(BaseSlavedStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(SlavedPresenceStore, self).__init__(database, db_conn, hs) | |||
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") | |||
@@ -14,7 +14,7 @@ | |||
# limitations under the License. | |||
from synapse.replication.slave.storage._base import BaseSlavedStore | |||
from synapse.storage.data_stores.main.profile import ProfileWorkerStore | |||
from synapse.storage.databases.main.profile import ProfileWorkerStore | |||
class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): | |||
@@ -15,7 +15,7 @@ | |||
# limitations under the License. | |||
from synapse.replication.tcp.streams import PushRulesStream | |||
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore | |||
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore | |||
from .events import SlavedEventStore | |||
@@ -15,15 +15,15 @@ | |||
# limitations under the License. | |||
from synapse.replication.tcp.streams import PushersStream | |||
from synapse.storage.data_stores.main.pusher import PusherWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.pusher import PusherWorkerStore | |||
from ._base import BaseSlavedStore | |||
from ._slaved_id_tracker import SlavedIdTracker | |||
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(SlavedPusherStore, self).__init__(database, db_conn, hs) | |||
self._pushers_id_gen = SlavedIdTracker( | |||
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] | |||
@@ -15,15 +15,15 @@ | |||
# limitations under the License. | |||
from synapse.replication.tcp.streams import ReceiptsStream | |||
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore | |||
from ._base import BaseSlavedStore | |||
from ._slaved_id_tracker import SlavedIdTracker | |||
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
# We instantiate this first as the ReceiptsWorkerStore constructor | |||
# needs to be able to call get_max_receipt_stream_id | |||
self._receipts_id_gen = SlavedIdTracker( | |||
@@ -13,7 +13,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.storage.data_stores.main.registration import RegistrationWorkerStore | |||
from synapse.storage.databases.main.registration import RegistrationWorkerStore | |||
from ._base import BaseSlavedStore | |||
@@ -14,15 +14,15 @@ | |||
# limitations under the License. | |||
from synapse.replication.tcp.streams import PublicRoomsStream | |||
from synapse.storage.data_stores.main.room import RoomWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.room import RoomWorkerStore | |||
from ._base import BaseSlavedStore | |||
from ._slaved_id_tracker import SlavedIdTracker | |||
class RoomStore(RoomWorkerStore, BaseSlavedStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(RoomStore, self).__init__(database, db_conn, hs) | |||
self._public_room_id_gen = SlavedIdTracker( | |||
db_conn, "public_room_list_stream", "stream_id" | |||
@@ -13,7 +13,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.storage.data_stores.main.transactions import TransactionStore | |||
from synapse.storage.databases.main.transactions import TransactionStore | |||
from ._base import BaseSlavedStore | |||
@@ -31,7 +31,7 @@ from synapse.rest.admin._base import ( | |||
assert_user_is_admin, | |||
historical_admin_path_patterns, | |||
) | |||
from synapse.storage.data_stores.main.room import RoomSortOrder | |||
from synapse.storage.databases.main.room import RoomSortOrder | |||
from synapse.types import RoomAlias, RoomID, UserID, create_requester | |||
logger = logging.getLogger(__name__) | |||
@@ -586,7 +586,7 @@ class PreviewUrlResource(DirectServeJsonResource): | |||
logger.debug("Running url preview cache expiry") | |||
if not (await self.store.db.updates.has_completed_background_updates()): | |||
if not (await self.store.db_pool.updates.has_completed_background_updates()): | |||
logger.info("Still running DB updates; skipping expiry") | |||
return | |||
@@ -105,7 +105,7 @@ from synapse.server_notices.worker_server_notices_sender import ( | |||
WorkerServerNoticesSender, | |||
) | |||
from synapse.state import StateHandler, StateResolutionHandler | |||
from synapse.storage import DataStore, DataStores, Storage | |||
from synapse.storage import Databases, DataStore, Storage | |||
from synapse.streams.events import EventSources | |||
from synapse.util import Clock | |||
from synapse.util.distributor import Distributor | |||
@@ -280,7 +280,7 @@ class HomeServer(object): | |||
def setup(self): | |||
logger.info("Setting up.") | |||
self.start_time = int(self.get_clock().time()) | |||
self.datastores = DataStores(self.DATASTORE_CLASS, self) | |||
self.datastores = Databases(self.DATASTORE_CLASS, self) | |||
logger.info("Finished setting up.") | |||
def setup_master(self): | |||
@@ -28,7 +28,7 @@ from synapse.events import EventBase | |||
from synapse.events.snapshot import EventContext | |||
from synapse.logging.utils import log_function | |||
from synapse.state import v1, v2 | |||
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour | |||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour | |||
from synapse.storage.roommember import ProfileInfo | |||
from synapse.types import StateMap | |||
from synapse.util import Clock | |||
@@ -17,18 +17,19 @@ | |||
""" | |||
The storage layer is split up into multiple parts to allow Synapse to run | |||
against different configurations of databases (e.g. single or multiple | |||
databases). The `Database` class represents a single physical database. The | |||
`data_stores` are classes that talk directly to a `Database` instance and have | |||
associated schemas, background updates, etc. On top of those there are classes | |||
that provide high level interfaces that combine calls to multiple `data_stores`. | |||
databases). The `DatabasePool` class represents connections to a single physical | |||
database. The `databases` are classes that talk directly to a `DatabasePool` | |||
instance and have associated schemas, background updates, etc. On top of those | |||
there are classes that provide high level interfaces that combine calls to | |||
multiple `databases`. | |||
There are also schemas that get applied to every database, regardless of the | |||
data stores associated with them (e.g. the schema version tables), which are | |||
stored in `synapse.storage.schema`. | |||
""" | |||
from synapse.storage.data_stores import DataStores | |||
from synapse.storage.data_stores.main import DataStore | |||
from synapse.storage.databases import Databases | |||
from synapse.storage.databases.main import DataStore | |||
from synapse.storage.persist_events import EventsPersistenceStorage | |||
from synapse.storage.purge_events import PurgeEventsStorage | |||
from synapse.storage.state import StateGroupStorage | |||
@@ -40,7 +41,7 @@ class Storage(object): | |||
"""The high level interfaces for talking to various storage layers. | |||
""" | |||
def __init__(self, hs, stores: DataStores): | |||
def __init__(self, hs, stores: Databases): | |||
# We include the main data store here mainly so that we don't have to | |||
# rewrite all the existing code to split it into high vs low level | |||
# interfaces. | |||
@@ -23,7 +23,7 @@ from canonicaljson import json | |||
from synapse.storage.database import LoggingTransaction # noqa: F401 | |||
from synapse.storage.database import make_in_list_sql_clause # noqa: F401 | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.types import Collection, get_domain_from_id | |||
logger = logging.getLogger(__name__) | |||
@@ -37,11 +37,11 @@ class SQLBaseStore(metaclass=ABCMeta): | |||
per data store (and not one per physical database). | |||
""" | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
self.hs = hs | |||
self._clock = hs.get_clock() | |||
self.database_engine = database.engine | |||
self.db = database | |||
self.db_pool = database | |||
self.rand = random.SystemRandom() | |||
def process_replication_rows(self, stream_name, instance_name, token, rows): | |||
@@ -88,7 +88,7 @@ class BackgroundUpdater(object): | |||
def __init__(self, hs, database): | |||
self._clock = hs.get_clock() | |||
self.db = database | |||
self.db_pool = database | |||
# if a background update is currently running, its name. | |||
self._current_background_update = None # type: Optional[str] | |||
@@ -139,7 +139,7 @@ class BackgroundUpdater(object): | |||
# otherwise, check if there are updates to be run. This is important, | |||
# as we may be running on a worker which doesn't perform the bg updates | |||
# itself, but still wants to wait for them to happen. | |||
updates = await self.db.simple_select_onecol( | |||
updates = await self.db_pool.simple_select_onecol( | |||
"background_updates", | |||
keyvalues=None, | |||
retcol="1", | |||
@@ -160,7 +160,7 @@ class BackgroundUpdater(object): | |||
if update_name == self._current_background_update: | |||
return False | |||
update_exists = await self.db.simple_select_one_onecol( | |||
update_exists = await self.db_pool.simple_select_one_onecol( | |||
"background_updates", | |||
keyvalues={"update_name": update_name}, | |||
retcol="1", | |||
@@ -189,10 +189,10 @@ class BackgroundUpdater(object): | |||
ORDER BY ordering, update_name | |||
""" | |||
) | |||
return self.db.cursor_to_dict(txn) | |||
return self.db_pool.cursor_to_dict(txn) | |||
if not self._current_background_update: | |||
all_pending_updates = await self.db.runInteraction( | |||
all_pending_updates = await self.db_pool.runInteraction( | |||
"background_updates", get_background_updates_txn, | |||
) | |||
if not all_pending_updates: | |||
@@ -243,7 +243,7 @@ class BackgroundUpdater(object): | |||
else: | |||
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE | |||
progress_json = await self.db.simple_select_one_onecol( | |||
progress_json = await self.db_pool.simple_select_one_onecol( | |||
"background_updates", | |||
keyvalues={"update_name": update_name}, | |||
retcol="progress_json", | |||
@@ -402,7 +402,7 @@ class BackgroundUpdater(object): | |||
logger.debug("[SQL] %s", sql) | |||
c.execute(sql) | |||
if isinstance(self.db.engine, engines.PostgresEngine): | |||
if isinstance(self.db_pool.engine, engines.PostgresEngine): | |||
runner = create_index_psql | |||
elif psql_only: | |||
runner = None | |||
@@ -413,7 +413,7 @@ class BackgroundUpdater(object): | |||
def updater(progress, batch_size): | |||
if runner is not None: | |||
logger.info("Adding index %s to %s", index_name, table) | |||
yield self.db.runWithConnection(runner) | |||
yield self.db_pool.runWithConnection(runner) | |||
yield self._end_background_update(update_name) | |||
return 1 | |||
@@ -433,7 +433,7 @@ class BackgroundUpdater(object): | |||
% update_name | |||
) | |||
self._current_background_update = None | |||
return self.db.simple_delete_one( | |||
return self.db_pool.simple_delete_one( | |||
"background_updates", keyvalues={"update_name": update_name} | |||
) | |||
@@ -445,7 +445,7 @@ class BackgroundUpdater(object): | |||
progress: The progress of the update. | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"background_update_progress", | |||
self._background_update_progress_txn, | |||
update_name, | |||
@@ -463,7 +463,7 @@ class BackgroundUpdater(object): | |||
progress_json = json.dumps(progress) | |||
self.db.simple_update_one_txn( | |||
self.db_pool.simple_update_one_txn( | |||
txn, | |||
"background_updates", | |||
keyvalues={"update_name": update_name}, | |||
@@ -279,7 +279,7 @@ class PerformanceCounters(object): | |||
return top_n_counters | |||
class Database(object): | |||
class DatabasePool(object): | |||
"""Wraps a single physical database and connection pool. | |||
A single database may be used by multiple data stores. | |||
@@ -15,17 +15,17 @@ | |||
import logging | |||
from synapse.storage.data_stores.main.events import PersistEventsStore | |||
from synapse.storage.data_stores.state import StateGroupDataStore | |||
from synapse.storage.database import Database, make_conn | |||
from synapse.storage.database import DatabasePool, make_conn | |||
from synapse.storage.databases.main.events import PersistEventsStore | |||
from synapse.storage.databases.state import StateGroupDataStore | |||
from synapse.storage.engines import create_engine | |||
from synapse.storage.prepare_database import prepare_database | |||
logger = logging.getLogger(__name__) | |||
class DataStores(object): | |||
"""The various data stores. | |||
class Databases(object): | |||
"""The various databases. | |||
These are low level interfaces to physical databases. | |||
@@ -51,12 +51,12 @@ class DataStores(object): | |||
engine.check_database(db_conn) | |||
prepare_database( | |||
db_conn, engine, hs.config, data_stores=database_config.data_stores, | |||
db_conn, engine, hs.config, databases=database_config.databases, | |||
) | |||
database = Database(hs, database_config, engine) | |||
database = DatabasePool(hs, database_config, engine) | |||
if "main" in database_config.data_stores: | |||
if "main" in database_config.databases: | |||
logger.info("Starting 'main' data store") | |||
# Sanity check we don't try and configure the main store on | |||
@@ -73,7 +73,7 @@ class DataStores(object): | |||
hs, database, self.main | |||
) | |||
if "state" in database_config.data_stores: | |||
if "state" in database_config.databases: | |||
logger.info("Starting 'state' data store") | |||
# Sanity check we don't try and configure the state store on |
@@ -21,7 +21,7 @@ import time | |||
from synapse.api.constants import PresenceState | |||
from synapse.config.homeserver import HomeServerConfig | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.engines import PostgresEngine | |||
from synapse.storage.util.id_generators import ( | |||
IdGenerator, | |||
@@ -119,7 +119,7 @@ class DataStore( | |||
CacheInvalidationWorkerStore, | |||
ServerMetricsStore, | |||
): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
self.hs = hs | |||
self._clock = hs.get_clock() | |||
self.database_engine = database.engine | |||
@@ -174,7 +174,7 @@ class DataStore( | |||
self._presence_on_startup = self._get_active_presence(db_conn) | |||
presence_cache_prefill, min_presence_val = self.db.get_cache_dict( | |||
presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( | |||
db_conn, | |||
"presence_stream", | |||
entity_column="user_id", | |||
@@ -188,7 +188,7 @@ class DataStore( | |||
) | |||
max_device_inbox_id = self._device_inbox_id_gen.get_current_token() | |||
device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( | |||
device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( | |||
db_conn, | |||
"device_inbox", | |||
entity_column="user_id", | |||
@@ -203,7 +203,7 @@ class DataStore( | |||
) | |||
# The federation outbox and the local device inbox uses the same | |||
# stream_id generator. | |||
device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( | |||
device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( | |||
db_conn, | |||
"device_federation_outbox", | |||
entity_column="destination", | |||
@@ -229,7 +229,7 @@ class DataStore( | |||
) | |||
events_max = self._stream_id_gen.get_current_token() | |||
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( | |||
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( | |||
db_conn, | |||
"current_state_delta_stream", | |||
entity_column="room_id", | |||
@@ -243,7 +243,7 @@ class DataStore( | |||
prefilled_cache=curr_state_delta_prefill, | |||
) | |||
_group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( | |||
_group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict( | |||
db_conn, | |||
"local_group_updates", | |||
entity_column="user_id", | |||
@@ -282,7 +282,7 @@ class DataStore( | |||
txn = db_conn.cursor() | |||
txn.execute(sql, (PresenceState.OFFLINE,)) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
txn.close() | |||
for row in rows: | |||
@@ -295,7 +295,9 @@ class DataStore( | |||
Counts the number of users who used this homeserver in the last 24 hours. | |||
""" | |||
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) | |||
return self.db.runInteraction("count_daily_users", self._count_users, yesterday) | |||
return self.db_pool.runInteraction( | |||
"count_daily_users", self._count_users, yesterday | |||
) | |||
def count_monthly_users(self): | |||
""" | |||
@@ -305,7 +307,7 @@ class DataStore( | |||
amongst other things, includes a 3 day grace period before a user counts. | |||
""" | |||
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"count_monthly_users", self._count_users, thirty_days_ago | |||
) | |||
@@ -405,7 +407,7 @@ class DataStore( | |||
return results | |||
return self.db.runInteraction("count_r30_users", _count_r30_users) | |||
return self.db_pool.runInteraction("count_r30_users", _count_r30_users) | |||
def _get_start_of_day(self): | |||
""" | |||
@@ -470,7 +472,7 @@ class DataStore( | |||
# frequently | |||
self._last_user_visit_update = now | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"generate_user_daily_visits", _generate_user_daily_visits | |||
) | |||
@@ -481,7 +483,7 @@ class DataStore( | |||
Returns: | |||
defer.Deferred: resolves to list[dict[str, Any]] | |||
""" | |||
return self.db.simple_select_list( | |||
return self.db_pool.simple_select_list( | |||
table="users", | |||
keyvalues={}, | |||
retcols=[ | |||
@@ -543,10 +545,12 @@ class DataStore( | |||
where_clause | |||
) | |||
txn.execute(sql, args) | |||
users = self.db.cursor_to_dict(txn) | |||
users = self.db_pool.cursor_to_dict(txn) | |||
return users, count | |||
return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn) | |||
return self.db_pool.runInteraction( | |||
"get_users_paginate_txn", get_users_paginate_txn | |||
) | |||
def search_users(self, term): | |||
"""Function to search users list for one or more users with | |||
@@ -558,7 +562,7 @@ class DataStore( | |||
Returns: | |||
defer.Deferred: resolves to list[dict[str, Any]] | |||
""" | |||
return self.db.simple_search_list( | |||
return self.db_pool.simple_search_list( | |||
table="users", | |||
term=term, | |||
col="name", |
@@ -23,7 +23,7 @@ from canonicaljson import json | |||
from twisted.internet import defer | |||
from synapse.storage._base import SQLBaseStore, db_to_json | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.util.id_generators import StreamIdGenerator | |||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
@@ -40,7 +40,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
# the abstract methods being implemented. | |||
__metaclass__ = abc.ABCMeta | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
account_max = self.get_max_account_data_stream_id() | |||
self._account_data_stream_cache = StreamChangeCache( | |||
"AccountDataAndTagsChangeCache", account_max | |||
@@ -69,7 +69,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
""" | |||
def get_account_data_for_user_txn(txn): | |||
rows = self.db.simple_select_list_txn( | |||
rows = self.db_pool.simple_select_list_txn( | |||
txn, | |||
"account_data", | |||
{"user_id": user_id}, | |||
@@ -80,7 +80,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
row["account_data_type"]: db_to_json(row["content"]) for row in rows | |||
} | |||
rows = self.db.simple_select_list_txn( | |||
rows = self.db_pool.simple_select_list_txn( | |||
txn, | |||
"room_account_data", | |||
{"user_id": user_id}, | |||
@@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
return global_account_data, by_room | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_account_data_for_user", get_account_data_for_user_txn | |||
) | |||
@@ -104,7 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
Returns: | |||
Deferred: A dict | |||
""" | |||
result = yield self.db.simple_select_one_onecol( | |||
result = yield self.db_pool.simple_select_one_onecol( | |||
table="account_data", | |||
keyvalues={"user_id": user_id, "account_data_type": data_type}, | |||
retcol="content", | |||
@@ -129,7 +129,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
""" | |||
def get_account_data_for_room_txn(txn): | |||
rows = self.db.simple_select_list_txn( | |||
rows = self.db_pool.simple_select_list_txn( | |||
txn, | |||
"room_account_data", | |||
{"user_id": user_id, "room_id": room_id}, | |||
@@ -140,7 +140,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
row["account_data_type"]: db_to_json(row["content"]) for row in rows | |||
} | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_account_data_for_room", get_account_data_for_room_txn | |||
) | |||
@@ -158,7 +158,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
""" | |||
def get_account_data_for_room_and_type_txn(txn): | |||
content_json = self.db.simple_select_one_onecol_txn( | |||
content_json = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="room_account_data", | |||
keyvalues={ | |||
@@ -172,7 +172,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
return db_to_json(content_json) if content_json else None | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn | |||
) | |||
@@ -202,7 +202,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (last_id, current_id, limit)) | |||
return txn.fetchall() | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_updated_global_account_data", get_updated_global_account_data_txn | |||
) | |||
@@ -232,7 +232,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (last_id, current_id, limit)) | |||
return txn.fetchall() | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_updated_room_account_data", get_updated_room_account_data_txn | |||
) | |||
@@ -277,7 +277,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
if not changed: | |||
return defer.succeed(({}, {})) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn | |||
) | |||
@@ -295,7 +295,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
class AccountDataStore(AccountDataWorkerStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
self._account_data_id_gen = StreamIdGenerator( | |||
db_conn, | |||
"account_data_max_stream_id", | |||
@@ -333,7 +333,7 @@ class AccountDataStore(AccountDataWorkerStore): | |||
# no need to lock here as room_account_data has a unique constraint | |||
# on (user_id, room_id, account_data_type) so simple_upsert will | |||
# retry if there is a conflict. | |||
yield self.db.simple_upsert( | |||
yield self.db_pool.simple_upsert( | |||
desc="add_room_account_data", | |||
table="room_account_data", | |||
keyvalues={ | |||
@@ -379,7 +379,7 @@ class AccountDataStore(AccountDataWorkerStore): | |||
# no need to lock here as account_data has a unique constraint on | |||
# (user_id, account_data_type) so simple_upsert will retry if | |||
# there is a conflict. | |||
yield self.db.simple_upsert( | |||
yield self.db_pool.simple_upsert( | |||
desc="add_user_account_data", | |||
table="account_data", | |||
keyvalues={"user_id": user_id, "account_data_type": account_data_type}, | |||
@@ -427,4 +427,4 @@ class AccountDataStore(AccountDataWorkerStore): | |||
) | |||
txn.execute(update_max_id_sql, (next_id, next_id)) | |||
return self.db.runInteraction("update_account_data_max_stream_id", _update) | |||
return self.db_pool.runInteraction("update_account_data_max_stream_id", _update) |
@@ -23,8 +23,8 @@ from twisted.internet import defer | |||
from synapse.appservice import AppServiceTransaction | |||
from synapse.config.appservice import load_appservices | |||
from synapse.storage._base import SQLBaseStore, db_to_json | |||
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.events_worker import EventsWorkerStore | |||
logger = logging.getLogger(__name__) | |||
@@ -49,7 +49,7 @@ def _make_exclusive_regex(services_cache): | |||
class ApplicationServiceWorkerStore(SQLBaseStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
self.services_cache = load_appservices( | |||
hs.hostname, hs.config.app_service_config_files | |||
) | |||
@@ -134,7 +134,7 @@ class ApplicationServiceTransactionWorkerStore( | |||
A Deferred which resolves to a list of ApplicationServices, which | |||
may be empty. | |||
""" | |||
results = yield self.db.simple_select_list( | |||
results = yield self.db_pool.simple_select_list( | |||
"application_services_state", {"state": state}, ["as_id"] | |||
) | |||
# NB: This assumes this class is linked with ApplicationServiceStore | |||
@@ -156,7 +156,7 @@ class ApplicationServiceTransactionWorkerStore( | |||
Returns: | |||
A Deferred which resolves to ApplicationServiceState. | |||
""" | |||
result = yield self.db.simple_select_one( | |||
result = yield self.db_pool.simple_select_one( | |||
"application_services_state", | |||
{"as_id": service.id}, | |||
["state"], | |||
@@ -176,7 +176,7 @@ class ApplicationServiceTransactionWorkerStore( | |||
Returns: | |||
A Deferred which resolves when the state was set successfully. | |||
""" | |||
return self.db.simple_upsert( | |||
return self.db_pool.simple_upsert( | |||
"application_services_state", {"as_id": service.id}, {"state": state} | |||
) | |||
@@ -217,7 +217,9 @@ class ApplicationServiceTransactionWorkerStore( | |||
) | |||
return AppServiceTransaction(service=service, id=new_txn_id, events=events) | |||
return self.db.runInteraction("create_appservice_txn", _create_appservice_txn) | |||
return self.db_pool.runInteraction( | |||
"create_appservice_txn", _create_appservice_txn | |||
) | |||
def complete_appservice_txn(self, txn_id, service): | |||
"""Completes an application service transaction. | |||
@@ -250,7 +252,7 @@ class ApplicationServiceTransactionWorkerStore( | |||
) | |||
# Set current txn_id for AS to 'txn_id' | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
"application_services_state", | |||
{"as_id": service.id}, | |||
@@ -258,13 +260,13 @@ class ApplicationServiceTransactionWorkerStore( | |||
) | |||
# Delete txn | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
"application_services_txns", | |||
{"txn_id": txn_id, "as_id": service.id}, | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"complete_appservice_txn", _complete_appservice_txn | |||
) | |||
@@ -288,7 +290,7 @@ class ApplicationServiceTransactionWorkerStore( | |||
" ORDER BY txn_id ASC LIMIT 1", | |||
(service.id,), | |||
) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
if not rows: | |||
return None | |||
@@ -296,7 +298,7 @@ class ApplicationServiceTransactionWorkerStore( | |||
return entry | |||
entry = yield self.db.runInteraction( | |||
entry = yield self.db_pool.runInteraction( | |||
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn | |||
) | |||
@@ -326,7 +328,7 @@ class ApplicationServiceTransactionWorkerStore( | |||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"set_appservice_last_pos", set_appservice_last_pos_txn | |||
) | |||
@@ -355,7 +357,7 @@ class ApplicationServiceTransactionWorkerStore( | |||
return upper_bound, [row[1] for row in rows] | |||
upper_bound, event_ids = yield self.db.runInteraction( | |||
upper_bound, event_ids = yield self.db_pool.runInteraction( | |||
"get_new_events_for_appservice", get_new_events_for_appservice_txn | |||
) | |||
@@ -26,7 +26,7 @@ from synapse.replication.tcp.streams.events import ( | |||
EventsStreamEventRow, | |||
) | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.engines import PostgresEngine | |||
from synapse.util.iterutils import batch_iter | |||
@@ -39,7 +39,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" | |||
class CacheInvalidationWorkerStore(SQLBaseStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super().__init__(database, db_conn, hs) | |||
self._instance_name = hs.get_instance_name() | |||
@@ -92,7 +92,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
return updates, upto_token, limited | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_all_updated_caches", get_all_updated_caches_txn | |||
) | |||
@@ -203,7 +203,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
return | |||
cache_func.invalidate(keys) | |||
await self.db.runInteraction( | |||
await self.db_pool.runInteraction( | |||
"invalidate_cache_and_stream", | |||
self._send_invalidation_to_replication, | |||
cache_func.__name__, | |||
@@ -288,7 +288,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
if keys is not None: | |||
keys = list(keys) | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="cache_invalidation_stream_by_instance", | |||
values={ |
@@ -21,10 +21,10 @@ from twisted.internet import defer | |||
from synapse.events.utils import prune_event_dict | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore | |||
from synapse.storage.data_stores.main.events import encode_json | |||
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore | |||
from synapse.storage.databases.main.events import encode_json | |||
from synapse.storage.databases.main.events_worker import EventsWorkerStore | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) | |||
class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): | |||
def __init__(self, database: Database, db_conn, hs: "HomeServer"): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
def _censor_redactions(): | |||
@@ -56,7 +56,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase | |||
return | |||
if not ( | |||
await self.db.updates.has_completed_background_update( | |||
await self.db_pool.updates.has_completed_background_update( | |||
"redactions_have_censored_ts_idx" | |||
) | |||
): | |||
@@ -85,7 +85,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase | |||
LIMIT ? | |||
""" | |||
rows = await self.db.execute( | |||
rows = await self.db_pool.execute( | |||
"_censor_redactions_fetch", None, sql, before_ts, 100 | |||
) | |||
@@ -123,14 +123,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase | |||
if pruned_json: | |||
self._censor_event_txn(txn, event_id, pruned_json) | |||
self.db.simple_update_one_txn( | |||
self.db_pool.simple_update_one_txn( | |||
txn, | |||
table="redactions", | |||
keyvalues={"event_id": redaction_id}, | |||
updatevalues={"have_censored": True}, | |||
) | |||
await self.db.runInteraction("_update_censor_txn", _update_censor_txn) | |||
await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn) | |||
def _censor_event_txn(self, txn, event_id, pruned_json): | |||
"""Censor an event by replacing its JSON in the event_json table with the | |||
@@ -141,7 +141,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase | |||
event_id (str): The ID of the event to censor. | |||
pruned_json (str): The pruned JSON | |||
""" | |||
self.db.simple_update_one_txn( | |||
self.db_pool.simple_update_one_txn( | |||
txn, | |||
table="event_json", | |||
keyvalues={"event_id": event_id}, | |||
@@ -193,7 +193,9 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase | |||
txn, "_get_event_cache", (event.event_id,) | |||
) | |||
yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) | |||
yield self.db_pool.runInteraction( | |||
"delete_expired_event", delete_expired_event_txn | |||
) | |||
def _delete_event_expiry_txn(self, txn, event_id): | |||
"""Delete the expiry timestamp associated with an event ID without deleting the | |||
@@ -203,6 +205,6 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase | |||
txn (LoggingTransaction): The transaction to use to perform the deletion. | |||
event_id (str): The event ID to delete the associated expiry timestamp of. | |||
""" | |||
return self.db.simple_delete_txn( | |||
return self.db_pool.simple_delete_txn( | |||
txn=txn, table="event_expiry", keyvalues={"event_id": event_id} | |||
) |
@@ -19,7 +19,7 @@ from twisted.internet import defer | |||
from synapse.metrics.background_process_metrics import wrap_as_background_process | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.database import Database, make_tuple_comparison_clause | |||
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause | |||
from synapse.util.caches.descriptors import Cache | |||
logger = logging.getLogger(__name__) | |||
@@ -31,40 +31,40 @@ LAST_SEEN_GRANULARITY = 120 * 1000 | |||
class ClientIpBackgroundUpdateStore(SQLBaseStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"user_ips_device_index", | |||
index_name="user_ips_device_id", | |||
table="user_ips", | |||
columns=["user_id", "device_id", "last_seen"], | |||
) | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"user_ips_last_seen_index", | |||
index_name="user_ips_last_seen", | |||
table="user_ips", | |||
columns=["user_id", "last_seen"], | |||
) | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"user_ips_last_seen_only_index", | |||
index_name="user_ips_last_seen_only", | |||
table="user_ips", | |||
columns=["last_seen"], | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
"user_ips_analyze", self._analyze_user_ip | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
"user_ips_remove_dupes", self._remove_user_ip_dupes | |||
) | |||
# Register a unique index | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"user_ips_device_unique_index", | |||
index_name="user_ips_user_token_ip_unique_index", | |||
table="user_ips", | |||
@@ -73,12 +73,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): | |||
) | |||
# Drop the old non-unique index | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
"user_ips_drop_nonunique_index", self._remove_user_ip_nonunique | |||
) | |||
# Update the last seen info in devices. | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
"devices_last_seen", self._devices_last_seen_update | |||
) | |||
@@ -89,8 +89,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): | |||
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") | |||
txn.close() | |||
yield self.db.runWithConnection(f) | |||
yield self.db.updates._end_background_update("user_ips_drop_nonunique_index") | |||
yield self.db_pool.runWithConnection(f) | |||
yield self.db_pool.updates._end_background_update( | |||
"user_ips_drop_nonunique_index" | |||
) | |||
return 1 | |||
@defer.inlineCallbacks | |||
@@ -104,9 +106,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): | |||
def user_ips_analyze(txn): | |||
txn.execute("ANALYZE user_ips") | |||
yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) | |||
yield self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) | |||
yield self.db.updates._end_background_update("user_ips_analyze") | |||
yield self.db_pool.updates._end_background_update("user_ips_analyze") | |||
return 1 | |||
@@ -138,7 +140,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): | |||
return None | |||
# Get a last seen that has roughly `batch_size` since `begin_last_seen` | |||
end_last_seen = yield self.db.runInteraction( | |||
end_last_seen = yield self.db_pool.runInteraction( | |||
"user_ips_dups_get_last_seen", get_last_seen | |||
) | |||
@@ -269,14 +271,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): | |||
(user_id, access_token, ip, device_id, user_agent, last_seen), | |||
) | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} | |||
) | |||
yield self.db.runInteraction("user_ips_dups_remove", remove) | |||
yield self.db_pool.runInteraction("user_ips_dups_remove", remove) | |||
if last: | |||
yield self.db.updates._end_background_update("user_ips_remove_dupes") | |||
yield self.db_pool.updates._end_background_update("user_ips_remove_dupes") | |||
return batch_size | |||
@@ -336,7 +338,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): | |||
txn.execute_batch(sql, rows) | |||
_, _, _, user_id, device_id = rows[-1] | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, | |||
"devices_last_seen", | |||
{"last_user_id": user_id, "last_device_id": device_id}, | |||
@@ -344,18 +346,18 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): | |||
return len(rows) | |||
updated = yield self.db.runInteraction( | |||
updated = yield self.db_pool.runInteraction( | |||
"_devices_last_seen_update", _devices_last_seen_update_txn | |||
) | |||
if not updated: | |||
yield self.db.updates._end_background_update("devices_last_seen") | |||
yield self.db_pool.updates._end_background_update("devices_last_seen") | |||
return updated | |||
class ClientIpStore(ClientIpBackgroundUpdateStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
self.client_ip_last_seen = Cache( | |||
name="client_ip_last_seen", keylen=4, max_entries=50000 | |||
@@ -403,18 +405,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): | |||
def _update_client_ips_batch(self): | |||
# If the DB pool has already terminated, don't try updating | |||
if not self.db.is_running(): | |||
if not self.db_pool.is_running(): | |||
return | |||
to_update = self._batch_row_update | |||
self._batch_row_update = {} | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update | |||
) | |||
def _update_client_ips_batch_txn(self, txn, to_update): | |||
if "user_ips" in self.db._unsafe_to_upsert_tables or ( | |||
if "user_ips" in self.db_pool._unsafe_to_upsert_tables or ( | |||
not self.database_engine.can_native_upsert | |||
): | |||
self.database_engine.lock_table(txn, "user_ips") | |||
@@ -423,7 +425,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): | |||
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry | |||
try: | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="user_ips", | |||
keyvalues={ | |||
@@ -445,7 +447,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): | |||
# this is always an update rather than an upsert: the row should | |||
# already exist, and if it doesn't, that may be because it has been | |||
# deleted, and we don't want to re-create it. | |||
self.db.simple_update_txn( | |||
self.db_pool.simple_update_txn( | |||
txn, | |||
table="devices", | |||
keyvalues={"user_id": user_id, "device_id": device_id}, | |||
@@ -477,7 +479,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): | |||
if device_id is not None: | |||
keyvalues["device_id"] = device_id | |||
res = yield self.db.simple_select_list( | |||
res = yield self.db_pool.simple_select_list( | |||
table="devices", | |||
keyvalues=keyvalues, | |||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), | |||
@@ -510,7 +512,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): | |||
user_agent, _, last_seen = self._batch_row_update[key] | |||
results[(access_token, ip)] = (user_agent, last_seen) | |||
rows = yield self.db.simple_select_list( | |||
rows = yield self.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={"user_id": user_id}, | |||
retcols=["access_token", "ip", "user_agent", "last_seen"], | |||
@@ -540,7 +542,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): | |||
# Nothing to do | |||
return | |||
if not await self.db.updates.has_completed_background_update( | |||
if not await self.db_pool.updates.has_completed_background_update( | |||
"devices_last_seen" | |||
): | |||
# Only start pruning if we have finished populating the devices | |||
@@ -573,4 +575,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): | |||
def _prune_old_user_ips_txn(txn): | |||
txn.execute(sql, (timestamp,)) | |||
await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) | |||
await self.db_pool.runInteraction( | |||
"_prune_old_user_ips", _prune_old_user_ips_txn | |||
) |
@@ -22,7 +22,7 @@ from twisted.internet import defer | |||
from synapse.logging.opentracing import log_kv, set_tag, trace | |||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.util.caches.expiringcache import ExpiringCache | |||
logger = logging.getLogger(__name__) | |||
@@ -70,7 +70,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
stream_pos = current_stream_id | |||
return messages, stream_pos | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_new_messages_for_device", get_new_messages_for_device_txn | |||
) | |||
@@ -110,7 +110,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (user_id, device_id, up_to_stream_id)) | |||
return txn.rowcount | |||
count = yield self.db.runInteraction( | |||
count = yield self.db_pool.runInteraction( | |||
"delete_messages_for_device", delete_messages_for_device_txn | |||
) | |||
@@ -179,7 +179,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
stream_pos = current_stream_id | |||
return messages, stream_pos | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_new_device_msgs_for_remote", | |||
get_new_messages_for_remote_destination_txn, | |||
) | |||
@@ -204,7 +204,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
) | |||
txn.execute(sql, (destination, up_to_stream_id)) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn | |||
) | |||
@@ -269,7 +269,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
return updates, upto_token, limited | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_all_new_device_messages", get_all_new_device_messages_txn | |||
) | |||
@@ -277,17 +277,17 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
class DeviceInboxBackgroundUpdateStore(SQLBaseStore): | |||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"device_inbox_stream_index", | |||
index_name="device_inbox_stream_id_user_id", | |||
table="device_inbox", | |||
columns=["stream_id", "user_id"], | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox | |||
) | |||
@@ -298,9 +298,9 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): | |||
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") | |||
txn.close() | |||
yield self.db.runWithConnection(reindex_txn) | |||
yield self.db_pool.runWithConnection(reindex_txn) | |||
yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) | |||
yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) | |||
return 1 | |||
@@ -308,7 +308,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): | |||
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): | |||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(DeviceInboxStore, self).__init__(database, db_conn, hs) | |||
# Map of (user_id, device_id) to the last stream_id that has been | |||
@@ -360,7 +360,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) | |||
with self._device_inbox_id_gen.get_next() as stream_id: | |||
now_ms = self.clock.time_msec() | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id | |||
) | |||
for user_id in local_messages_by_user_then_device.keys(): | |||
@@ -380,7 +380,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) | |||
# Check if we've already inserted a matching message_id for that | |||
# origin. This can happen if the origin doesn't receive our | |||
# acknowledgement from the first time we received the message. | |||
already_inserted = self.db.simple_select_one_txn( | |||
already_inserted = self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="device_federation_inbox", | |||
keyvalues={"origin": origin, "message_id": message_id}, | |||
@@ -392,7 +392,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) | |||
# Add an entry for this message_id so that we know we've processed | |||
# it. | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="device_federation_inbox", | |||
values={ | |||
@@ -410,7 +410,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) | |||
with self._device_inbox_id_gen.get_next() as stream_id: | |||
now_ms = self.clock.time_msec() | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"add_messages_from_remote_to_device_inbox", | |||
add_messages_txn, | |||
now_ms, |
@@ -31,7 +31,7 @@ from synapse.logging.opentracing import ( | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause | |||
from synapse.storage.database import ( | |||
Database, | |||
DatabasePool, | |||
LoggingTransaction, | |||
make_tuple_comparison_clause, | |||
) | |||
@@ -67,7 +67,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
Raises: | |||
StoreError: if the device is not found | |||
""" | |||
return self.db.simple_select_one( | |||
return self.db_pool.simple_select_one( | |||
table="devices", | |||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, | |||
retcols=("user_id", "device_id", "display_name"), | |||
@@ -86,7 +86,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
containing "device_id", "user_id" and "display_name" for each | |||
device. | |||
""" | |||
devices = yield self.db.simple_select_list( | |||
devices = yield self.db_pool.simple_select_list( | |||
table="devices", | |||
keyvalues={"user_id": user_id, "hidden": False}, | |||
retcols=("user_id", "device_id", "display_name"), | |||
@@ -118,7 +118,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
if not has_changed: | |||
return now_stream_id, [] | |||
updates = yield self.db.runInteraction( | |||
updates = yield self.db_pool.runInteraction( | |||
"get_device_updates_by_remote", | |||
self._get_device_updates_by_remote_txn, | |||
destination, | |||
@@ -255,7 +255,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
""" | |||
devices = ( | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"_get_e2e_device_keys_txn", | |||
self._get_e2e_device_keys_txn, | |||
query_map.keys(), | |||
@@ -326,12 +326,12 @@ class DeviceWorkerStore(SQLBaseStore): | |||
rows = txn.fetchall() | |||
return rows[0][0] | |||
return self.db.runInteraction("get_last_device_update_for_remote_user", f) | |||
return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f) | |||
def mark_as_sent_devices_by_remote(self, destination, stream_id): | |||
"""Mark that updates have successfully been sent to the destination. | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"mark_as_sent_devices_by_remote", | |||
self._mark_as_sent_devices_by_remote_txn, | |||
destination, | |||
@@ -350,7 +350,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (destination, stream_id)) | |||
rows = txn.fetchall() | |||
self.db.simple_upsert_many_txn( | |||
self.db_pool.simple_upsert_many_txn( | |||
txn=txn, | |||
table="device_lists_outbound_last_success", | |||
key_names=("destination", "user_id"), | |||
@@ -376,7 +376,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
""" | |||
with self._device_list_id_gen.get_next() as stream_id: | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"add_user_sig_change_to_streams", | |||
self._add_user_signature_change_txn, | |||
from_user_id, | |||
@@ -391,7 +391,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
from_user_id, | |||
stream_id, | |||
) | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
"user_signature_stream", | |||
values={ | |||
@@ -449,7 +449,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
@cachedInlineCallbacks(num_args=2, tree=True) | |||
def _get_cached_user_device(self, user_id, device_id): | |||
content = yield self.db.simple_select_one_onecol( | |||
content = yield self.db_pool.simple_select_one_onecol( | |||
table="device_lists_remote_cache", | |||
keyvalues={"user_id": user_id, "device_id": device_id}, | |||
retcol="content", | |||
@@ -459,7 +459,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
@cachedInlineCallbacks() | |||
def get_cached_devices_for_user(self, user_id): | |||
devices = yield self.db.simple_select_list( | |||
devices = yield self.db_pool.simple_select_list( | |||
table="device_lists_remote_cache", | |||
keyvalues={"user_id": user_id}, | |||
retcols=("device_id", "content"), | |||
@@ -475,7 +475,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
Returns: | |||
(stream_id, devices) | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_devices_with_keys_by_user", | |||
self._get_devices_with_keys_by_user_txn, | |||
user_id, | |||
@@ -555,7 +555,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
return changes | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn | |||
) | |||
@@ -574,7 +574,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
SELECT DISTINCT user_ids FROM user_signature_stream | |||
WHERE from_user_id = ? AND stream_id > ? | |||
""" | |||
rows = yield self.db.execute( | |||
rows = yield self.db_pool.execute( | |||
"get_users_whose_signatures_changed", None, sql, user_id, from_key | |||
) | |||
return {user for row in rows for user in db_to_json(row[0])} | |||
@@ -631,7 +631,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
return updates, upto_token, limited | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_all_device_list_changes_for_remotes", | |||
_get_all_device_list_changes_for_remotes, | |||
) | |||
@@ -641,7 +641,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
"""Get the last stream_id we got for a user. May be None if we haven't | |||
got any information for them. | |||
""" | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="device_lists_remote_extremeties", | |||
keyvalues={"user_id": user_id}, | |||
retcol="stream_id", | |||
@@ -655,7 +655,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
inlineCallbacks=True, | |||
) | |||
def get_device_list_last_stream_id_for_remotes(self, user_ids): | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = yield self.db_pool.simple_select_many_batch( | |||
table="device_lists_remote_extremeties", | |||
column="user_id", | |||
iterable=user_ids, | |||
@@ -680,7 +680,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
The IDs of users whose device lists need resync. | |||
""" | |||
if user_ids: | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = yield self.db_pool.simple_select_many_batch( | |||
table="device_lists_remote_resync", | |||
column="user_id", | |||
iterable=user_ids, | |||
@@ -688,7 +688,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
desc="get_user_ids_requiring_device_list_resync_with_iterable", | |||
) | |||
else: | |||
rows = yield self.db.simple_select_list( | |||
rows = yield self.db_pool.simple_select_list( | |||
table="device_lists_remote_resync", | |||
keyvalues=None, | |||
retcols=("user_id",), | |||
@@ -701,7 +701,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
"""Records that the server has reason to believe the cache of the devices | |||
for the remote users is out of date. | |||
""" | |||
return self.db.simple_upsert( | |||
return self.db_pool.simple_upsert( | |||
table="device_lists_remote_resync", | |||
keyvalues={"user_id": user_id}, | |||
values={}, | |||
@@ -714,7 +714,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
""" | |||
def _mark_remote_user_device_list_as_unsubscribed_txn(txn): | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="device_lists_remote_extremeties", | |||
keyvalues={"user_id": user_id}, | |||
@@ -723,17 +723,17 @@ class DeviceWorkerStore(SQLBaseStore): | |||
txn, self.get_device_list_last_stream_id_for_remote, (user_id,) | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"mark_remote_user_device_list_as_unsubscribed", | |||
_mark_remote_user_device_list_as_unsubscribed_txn, | |||
) | |||
class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"device_lists_stream_idx", | |||
index_name="device_lists_stream_user_id", | |||
table="device_lists_stream", | |||
@@ -741,7 +741,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
) | |||
# create a unique index on device_lists_remote_cache | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"device_lists_remote_cache_unique_idx", | |||
index_name="device_lists_remote_cache_unique_id", | |||
table="device_lists_remote_cache", | |||
@@ -750,7 +750,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
) | |||
# And one on device_lists_remote_extremeties | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"device_lists_remote_extremeties_unique_idx", | |||
index_name="device_lists_remote_extremeties_unique_idx", | |||
table="device_lists_remote_extremeties", | |||
@@ -759,22 +759,22 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
) | |||
# once they complete, we can remove the old non-unique indexes. | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, | |||
self._drop_device_list_streams_non_unique_indexes, | |||
) | |||
# clear out duplicate device list outbound pokes | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, | |||
) | |||
# a pair of background updates that were added during the 1.14 release cycle, | |||
# but replaced with 58/06dlols_unique_idx.py | |||
self.db.updates.register_noop_background_update( | |||
self.db_pool.updates.register_noop_background_update( | |||
"device_lists_outbound_last_success_unique_idx", | |||
) | |||
self.db.updates.register_noop_background_update( | |||
self.db_pool.updates.register_noop_background_update( | |||
"drop_device_lists_outbound_last_success_non_unique_idx", | |||
) | |||
@@ -786,8 +786,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") | |||
txn.close() | |||
yield self.db.runWithConnection(f) | |||
yield self.db.updates._end_background_update( | |||
yield self.db_pool.runWithConnection(f) | |||
yield self.db_pool.updates._end_background_update( | |||
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES | |||
) | |||
return 1 | |||
@@ -807,7 +807,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
def _txn(txn): | |||
clause, args = make_tuple_comparison_clause( | |||
self.db.engine, [(x, last_row[x]) for x in KEY_COLS] | |||
self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS] | |||
) | |||
sql = """ | |||
SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts | |||
@@ -823,30 +823,32 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
",".join(KEY_COLS), # ORDER BY | |||
) | |||
txn.execute(sql, args + [batch_size]) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
row = None | |||
for row in rows: | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS}, | |||
) | |||
row["sent"] = False | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, "device_lists_outbound_pokes", row, | |||
) | |||
if row: | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row}, | |||
) | |||
return len(rows) | |||
rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn) | |||
rows = await self.db_pool.runInteraction( | |||
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn | |||
) | |||
if not rows: | |||
await self.db.updates._end_background_update( | |||
await self.db_pool.updates._end_background_update( | |||
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES | |||
) | |||
@@ -854,7 +856,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(DeviceStore, self).__init__(database, db_conn, hs) | |||
# Map of (user_id, device_id) -> bool. If there is an entry that implies | |||
@@ -885,7 +887,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
return False | |||
try: | |||
inserted = yield self.db.simple_insert( | |||
inserted = yield self.db_pool.simple_insert( | |||
"devices", | |||
values={ | |||
"user_id": user_id, | |||
@@ -899,7 +901,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
if not inserted: | |||
# if the device already exists, check if it's a real device, or | |||
# if the device ID is reserved by something else | |||
hidden = yield self.db.simple_select_one_onecol( | |||
hidden = yield self.db_pool.simple_select_one_onecol( | |||
"devices", | |||
keyvalues={"user_id": user_id, "device_id": device_id}, | |||
retcol="hidden", | |||
@@ -934,7 +936,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
Returns: | |||
defer.Deferred | |||
""" | |||
yield self.db.simple_delete_one( | |||
yield self.db_pool.simple_delete_one( | |||
table="devices", | |||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, | |||
desc="delete_device", | |||
@@ -952,7 +954,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
Returns: | |||
defer.Deferred | |||
""" | |||
yield self.db.simple_delete_many( | |||
yield self.db_pool.simple_delete_many( | |||
table="devices", | |||
column="device_id", | |||
iterable=device_ids, | |||
@@ -981,7 +983,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
updates["display_name"] = new_display_name | |||
if not updates: | |||
return defer.succeed(None) | |||
return self.db.simple_update_one( | |||
return self.db_pool.simple_update_one( | |||
table="devices", | |||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, | |||
updatevalues=updates, | |||
@@ -1005,7 +1007,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
Returns: | |||
Deferred[None] | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"update_remote_device_list_cache_entry", | |||
self._update_remote_device_list_cache_entry_txn, | |||
user_id, | |||
@@ -1018,7 +1020,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
self, txn, user_id, device_id, content, stream_id | |||
): | |||
if content.get("deleted"): | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="device_lists_remote_cache", | |||
keyvalues={"user_id": user_id, "device_id": device_id}, | |||
@@ -1026,7 +1028,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) | |||
else: | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="device_lists_remote_cache", | |||
keyvalues={"user_id": user_id, "device_id": device_id}, | |||
@@ -1042,7 +1044,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) | |||
) | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="device_lists_remote_extremeties", | |||
keyvalues={"user_id": user_id}, | |||
@@ -1066,7 +1068,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
Returns: | |||
Deferred[None] | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"update_remote_device_list_cache", | |||
self._update_remote_device_list_cache_txn, | |||
user_id, | |||
@@ -1075,11 +1077,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
) | |||
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} | |||
) | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="device_lists_remote_cache", | |||
values=[ | |||
@@ -1098,7 +1100,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) | |||
) | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="device_lists_remote_extremeties", | |||
keyvalues={"user_id": user_id}, | |||
@@ -1111,7 +1113,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
# If we're replacing the remote user's device list cache presumably | |||
# we've done a full resync, so we remove the entry that says we need | |||
# to resync | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, | |||
) | |||
@@ -1124,7 +1126,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
return | |||
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"add_device_change_to_stream", | |||
self._add_device_change_to_stream_txn, | |||
user_id, | |||
@@ -1139,7 +1141,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
with self._device_list_id_gen.get_next_mult( | |||
len(hosts) * len(device_ids) | |||
) as stream_ids: | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"add_device_outbound_poke_to_stream", | |||
self._add_device_outbound_poke_to_stream_txn, | |||
user_id, | |||
@@ -1174,7 +1176,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
[(user_id, device_id, min_stream_id) for device_id in device_ids], | |||
) | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="device_lists_stream", | |||
values=[ | |||
@@ -1196,7 +1198,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
now = self._clock.time_msec() | |||
next_stream_id = iter(stream_ids) | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="device_lists_outbound_pokes", | |||
values=[ | |||
@@ -1303,7 +1305,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
return run_as_background_process( | |||
"prune_old_outbound_device_pokes", | |||
self.db.runInteraction, | |||
self.db_pool.runInteraction, | |||
"_prune_old_outbound_device_pokes", | |||
_prune_txn, | |||
) |
@@ -37,7 +37,7 @@ class DirectoryWorkerStore(SQLBaseStore): | |||
Deferred: results in namedtuple with keys "room_id" and | |||
"servers" or None if no association can be found | |||
""" | |||
room_id = yield self.db.simple_select_one_onecol( | |||
room_id = yield self.db_pool.simple_select_one_onecol( | |||
"room_aliases", | |||
{"room_alias": room_alias.to_string()}, | |||
"room_id", | |||
@@ -48,7 +48,7 @@ class DirectoryWorkerStore(SQLBaseStore): | |||
if not room_id: | |||
return None | |||
servers = yield self.db.simple_select_onecol( | |||
servers = yield self.db_pool.simple_select_onecol( | |||
"room_alias_servers", | |||
{"room_alias": room_alias.to_string()}, | |||
"server", | |||
@@ -61,7 +61,7 @@ class DirectoryWorkerStore(SQLBaseStore): | |||
return RoomAliasMapping(room_id, room_alias.to_string(), servers) | |||
def get_room_alias_creator(self, room_alias): | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="room_aliases", | |||
keyvalues={"room_alias": room_alias}, | |||
retcol="creator", | |||
@@ -70,7 +70,7 @@ class DirectoryWorkerStore(SQLBaseStore): | |||
@cached(max_entries=5000) | |||
def get_aliases_for_room(self, room_id): | |||
return self.db.simple_select_onecol( | |||
return self.db_pool.simple_select_onecol( | |||
"room_aliases", | |||
{"room_id": room_id}, | |||
"room_alias", | |||
@@ -94,7 +94,7 @@ class DirectoryStore(DirectoryWorkerStore): | |||
""" | |||
def alias_txn(txn): | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
"room_aliases", | |||
{ | |||
@@ -104,7 +104,7 @@ class DirectoryStore(DirectoryWorkerStore): | |||
}, | |||
) | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="room_alias_servers", | |||
values=[ | |||
@@ -118,7 +118,7 @@ class DirectoryStore(DirectoryWorkerStore): | |||
) | |||
try: | |||
ret = yield self.db.runInteraction( | |||
ret = yield self.db_pool.runInteraction( | |||
"create_room_alias_association", alias_txn | |||
) | |||
except self.database_engine.module.IntegrityError: | |||
@@ -129,7 +129,7 @@ class DirectoryStore(DirectoryWorkerStore): | |||
@defer.inlineCallbacks | |||
def delete_room_alias(self, room_alias): | |||
room_id = yield self.db.runInteraction( | |||
room_id = yield self.db_pool.runInteraction( | |||
"delete_room_alias", self._delete_room_alias_txn, room_alias | |||
) | |||
@@ -190,6 +190,6 @@ class DirectoryStore(DirectoryWorkerStore): | |||
txn, self.get_aliases_for_room, (new_room_id,) | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"_update_aliases_for_room_txn", _update_aliases_for_room_txn | |||
) |
@@ -38,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): | |||
StoreError | |||
""" | |||
yield self.db.simple_update_one( | |||
yield self.db_pool.simple_update_one( | |||
table="e2e_room_keys", | |||
keyvalues={ | |||
"user_id": user_id, | |||
@@ -89,7 +89,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): | |||
} | |||
) | |||
yield self.db.simple_insert_many( | |||
yield self.db_pool.simple_insert_many( | |||
table="e2e_room_keys", values=values, desc="add_e2e_room_keys" | |||
) | |||
@@ -125,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): | |||
if session_id: | |||
keyvalues["session_id"] = session_id | |||
rows = yield self.db.simple_select_list( | |||
rows = yield self.db_pool.simple_select_list( | |||
table="e2e_room_keys", | |||
keyvalues=keyvalues, | |||
retcols=( | |||
@@ -171,7 +171,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): | |||
Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_e2e_room_keys_multi", | |||
self._get_e2e_room_keys_multi_txn, | |||
user_id, | |||
@@ -235,7 +235,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): | |||
version (str): the version ID of the backup we're querying about | |||
""" | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="e2e_room_keys", | |||
keyvalues={"user_id": user_id, "version": version}, | |||
retcol="COUNT(*)", | |||
@@ -268,7 +268,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): | |||
if session_id: | |||
keyvalues["session_id"] = session_id | |||
yield self.db.simple_delete( | |||
yield self.db_pool.simple_delete( | |||
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" | |||
) | |||
@@ -313,7 +313,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): | |||
# it isn't there. | |||
raise StoreError(404, "No row found") | |||
result = self.db.simple_select_one_txn( | |||
result = self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="e2e_room_keys_versions", | |||
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, | |||
@@ -325,7 +325,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): | |||
result["etag"] = 0 | |||
return result | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn | |||
) | |||
@@ -353,7 +353,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): | |||
new_version = str(int(current_version) + 1) | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="e2e_room_keys_versions", | |||
values={ | |||
@@ -366,7 +366,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): | |||
return new_version | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn | |||
) | |||
@@ -392,7 +392,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): | |||
updatevalues["etag"] = version_etag | |||
if updatevalues: | |||
return self.db.simple_update( | |||
return self.db_pool.simple_update( | |||
table="e2e_room_keys_versions", | |||
keyvalues={"user_id": user_id, "version": version}, | |||
updatevalues=updatevalues, | |||
@@ -421,19 +421,19 @@ class EndToEndRoomKeyStore(SQLBaseStore): | |||
else: | |||
this_version = version | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="e2e_room_keys", | |||
keyvalues={"user_id": user_id, "version": this_version}, | |||
) | |||
return self.db.simple_update_one_txn( | |||
return self.db_pool.simple_update_one_txn( | |||
txn, | |||
table="e2e_room_keys_versions", | |||
keyvalues={"user_id": user_id, "version": this_version}, | |||
updatevalues={"deleted": 1}, | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn | |||
) |
@@ -51,7 +51,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): | |||
if not query_list: | |||
return {} | |||
results = yield self.db.runInteraction( | |||
results = yield self.db_pool.runInteraction( | |||
"get_e2e_device_keys", | |||
self._get_e2e_device_keys_txn, | |||
query_list, | |||
@@ -128,7 +128,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): | |||
) | |||
txn.execute(sql, query_params) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
result = {} | |||
for row in rows: | |||
@@ -146,7 +146,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): | |||
) | |||
txn.execute(signature_sql, signature_query_params) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
# add each cross-signing signature to the correct device in the result dict. | |||
for row in rows: | |||
@@ -189,7 +189,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): | |||
key_id) to json string for key | |||
""" | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = yield self.db_pool.simple_select_many_batch( | |||
table="e2e_one_time_keys_json", | |||
column="key_id", | |||
iterable=key_ids, | |||
@@ -222,7 +222,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): | |||
# a unique constraint. If there is a race of two calls to | |||
# `add_e2e_one_time_keys` then they'll conflict and we will only | |||
# insert one set. | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="e2e_one_time_keys_json", | |||
values=[ | |||
@@ -241,7 +241,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): | |||
txn, self.count_e2e_one_time_keys, (user_id, device_id) | |||
) | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys | |||
) | |||
@@ -264,7 +264,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): | |||
result[algorithm] = key_count | |||
return result | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"count_e2e_one_time_keys", _count_e2e_one_time_keys | |||
) | |||
@@ -318,7 +318,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): | |||
to None. | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_bare_e2e_cross_signing_keys_bulk", | |||
self._get_bare_e2e_cross_signing_keys_bulk_txn, | |||
user_ids, | |||
@@ -361,7 +361,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): | |||
) | |||
txn.execute(sql, params) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
for row in rows: | |||
user_id = row["user_id"] | |||
@@ -420,7 +420,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): | |||
query_params.extend(item) | |||
txn.execute(sql, query_params) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
# and add the signatures to the appropriate keys | |||
for row in rows: | |||
@@ -470,7 +470,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): | |||
result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) | |||
if from_user_id: | |||
result = yield self.db.runInteraction( | |||
result = yield self.db_pool.runInteraction( | |||
"get_e2e_cross_signing_signatures", | |||
self._get_e2e_cross_signing_signatures_txn, | |||
result, | |||
@@ -531,7 +531,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): | |||
return updates, upto_token, limited | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_all_user_signature_changes_for_remotes", | |||
_get_all_user_signature_changes_for_remotes_txn, | |||
) | |||
@@ -549,7 +549,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): | |||
set_tag("time_now", time_now) | |||
set_tag("device_keys", device_keys) | |||
old_key_json = self.db.simple_select_one_onecol_txn( | |||
old_key_json = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="e2e_device_keys_json", | |||
keyvalues={"user_id": user_id, "device_id": device_id}, | |||
@@ -565,7 +565,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): | |||
log_kv({"Message": "Device key already stored."}) | |||
return False | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="e2e_device_keys_json", | |||
keyvalues={"user_id": user_id, "device_id": device_id}, | |||
@@ -574,7 +574,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): | |||
log_kv({"message": "Device keys stored."}) | |||
return True | |||
return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) | |||
return self.db_pool.runInteraction( | |||
"set_e2e_device_keys", _set_e2e_device_keys_txn | |||
) | |||
def claim_e2e_one_time_keys(self, query_list): | |||
"""Take a list of one time keys out of the database""" | |||
@@ -613,7 +615,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): | |||
) | |||
return result | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys | |||
) | |||
@@ -626,12 +628,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): | |||
"user_id": user_id, | |||
} | |||
) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="e2e_device_keys_json", | |||
keyvalues={"user_id": user_id, "device_id": device_id}, | |||
) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="e2e_one_time_keys_json", | |||
keyvalues={"user_id": user_id, "device_id": device_id}, | |||
@@ -640,7 +642,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): | |||
txn, self.count_e2e_one_time_keys, (user_id, device_id) | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn | |||
) | |||
@@ -679,7 +681,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): | |||
# We only need to do this for local users, since remote servers should be | |||
# responsible for checking this for their own users. | |||
if self.hs.is_mine_id(user_id): | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
"devices", | |||
values={ | |||
@@ -692,7 +694,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): | |||
# and finally, store the key itself | |||
with self._cross_signing_id_gen.get_next() as stream_id: | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
"e2e_cross_signing_keys", | |||
values={ | |||
@@ -715,7 +717,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): | |||
key_type (str): the type of cross-signing key to set | |||
key (dict): the key data | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"add_e2e_cross_signing_key", | |||
self._set_e2e_cross_signing_key_txn, | |||
user_id, | |||
@@ -730,7 +732,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): | |||
user_id (str): the user who made the signatures | |||
signatures (iterable[SignatureListItem]): signatures to add | |||
""" | |||
return self.db.simple_insert_many( | |||
return self.db_pool.simple_insert_many( | |||
"e2e_cross_signing_signatures", | |||
[ | |||
{ |
@@ -22,9 +22,9 @@ from twisted.internet import defer | |||
from synapse.api.errors import StoreError | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause | |||
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore | |||
from synapse.storage.data_stores.main.signatures import SignatureWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.events_worker import EventsWorkerStore | |||
from synapse.storage.databases.main.signatures import SignatureWorkerStore | |||
from synapse.util.caches.descriptors import cached | |||
from synapse.util.iterutils import batch_iter | |||
@@ -65,7 +65,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
Returns: | |||
list of event_ids | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_auth_chain_ids", | |||
self._get_auth_chain_ids_txn, | |||
event_ids, | |||
@@ -114,7 +114,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
Deferred[Set[str]] | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_auth_chain_difference", | |||
self._get_auth_chain_difference_txn, | |||
state_sets, | |||
@@ -260,12 +260,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
return {eid for eid, n in event_to_missing_sets.items() if n} | |||
def get_oldest_events_in_room(self, room_id): | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id | |||
) | |||
def get_oldest_events_with_depth_in_room(self, room_id): | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_oldest_events_with_depth_in_room", | |||
self.get_oldest_events_with_depth_in_room_txn, | |||
room_id, | |||
@@ -296,7 +296,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
Returns | |||
Deferred[int] | |||
""" | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = yield self.db_pool.simple_select_many_batch( | |||
table="events", | |||
column="event_id", | |||
iterable=event_ids, | |||
@@ -310,7 +310,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
return max(row["depth"] for row in rows) | |||
def _get_oldest_events_in_room_txn(self, txn, room_id): | |||
return self.db.simple_select_onecol_txn( | |||
return self.db_pool.simple_select_onecol_txn( | |||
txn, | |||
table="event_backward_extremities", | |||
keyvalues={"room_id": room_id}, | |||
@@ -332,7 +332,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id | |||
) | |||
@@ -387,13 +387,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
txn.execute(sql, query_args) | |||
return [room_id for room_id, in txn] | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn | |||
) | |||
@cached(max_entries=5000, iterable=True) | |||
def get_latest_event_ids_in_room(self, room_id): | |||
return self.db.simple_select_onecol( | |||
return self.db_pool.simple_select_onecol( | |||
table="event_forward_extremities", | |||
keyvalues={"room_id": room_id}, | |||
retcol="event_id", | |||
@@ -403,12 +403,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
def get_min_depth(self, room_id): | |||
""" For hte given room, get the minimum depth we have seen for it. | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_min_depth", self._get_min_depth_interaction, room_id | |||
) | |||
def _get_min_depth_interaction(self, txn, room_id): | |||
min_depth = self.db.simple_select_one_onecol_txn( | |||
min_depth = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="room_depth", | |||
keyvalues={"room_id": room_id}, | |||
@@ -474,7 +474,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
txn.execute(sql, (stream_ordering, room_id)) | |||
return [event_id for event_id, in txn] | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn | |||
) | |||
@@ -489,7 +489,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
limit (int) | |||
""" | |||
return ( | |||
self.db.runInteraction( | |||
self.db_pool.runInteraction( | |||
"get_backfill_events", | |||
self._get_backfill_events, | |||
room_id, | |||
@@ -520,7 +520,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
queue = PriorityQueue() | |||
for event_id in event_list: | |||
depth = self.db.simple_select_one_onecol_txn( | |||
depth = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="events", | |||
keyvalues={"event_id": event_id, "room_id": room_id}, | |||
@@ -552,7 +552,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
@defer.inlineCallbacks | |||
def get_missing_events(self, room_id, earliest_events, latest_events, limit): | |||
ids = yield self.db.runInteraction( | |||
ids = yield self.db_pool.runInteraction( | |||
"get_missing_events", | |||
self._get_missing_events, | |||
room_id, | |||
@@ -605,7 +605,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas | |||
Returns: | |||
Deferred[list[str]] | |||
""" | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = yield self.db_pool.simple_select_many_batch( | |||
table="event_edges", | |||
column="prev_event_id", | |||
iterable=event_ids, | |||
@@ -628,10 +628,10 @@ class EventFederationStore(EventFederationWorkerStore): | |||
EVENT_AUTH_STATE_ONLY = "event_auth_state_only" | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(EventFederationStore, self).__init__(database, db_conn, hs) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth | |||
) | |||
@@ -658,13 +658,13 @@ class EventFederationStore(EventFederationWorkerStore): | |||
return run_as_background_process( | |||
"delete_old_forward_extrem_cache", | |||
self.db.runInteraction, | |||
self.db_pool.runInteraction, | |||
"_delete_old_forward_extrem_cache", | |||
_delete_old_forward_extrem_cache_txn, | |||
) | |||
def clean_room_for_join(self, room_id): | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"clean_room_for_join", self._clean_room_for_join_txn, room_id | |||
) | |||
@@ -708,17 +708,19 @@ class EventFederationStore(EventFederationWorkerStore): | |||
"max_stream_id_exclusive": min_stream_id, | |||
} | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, self.EVENT_AUTH_STATE_ONLY, new_progress | |||
) | |||
return min_stream_id >= target_min_stream_id | |||
result = yield self.db.runInteraction( | |||
result = yield self.db_pool.runInteraction( | |||
self.EVENT_AUTH_STATE_ONLY, delete_event_auth | |||
) | |||
if not result: | |||
yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY) | |||
yield self.db_pool.updates._end_background_update( | |||
self.EVENT_AUTH_STATE_ONLY | |||
) | |||
return batch_size |
@@ -21,7 +21,7 @@ from canonicaljson import json | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.util.caches.descriptors import cachedInlineCallbacks | |||
logger = logging.getLogger(__name__) | |||
@@ -66,7 +66,7 @@ def _deserialize_action(actions, is_highlight): | |||
class EventPushActionsWorkerStore(SQLBaseStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) | |||
# These get correctly set by _find_stream_orderings_for_times_txn | |||
@@ -91,7 +91,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
def get_unread_event_push_actions_by_room_for_user( | |||
self, room_id, user_id, last_read_event_id | |||
): | |||
ret = yield self.db.runInteraction( | |||
ret = yield self.db_pool.runInteraction( | |||
"get_unread_event_push_actions_by_room", | |||
self._get_unread_counts_by_receipt_txn, | |||
room_id, | |||
@@ -176,7 +176,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (min_stream_ordering, max_stream_ordering)) | |||
return [r[0] for r in txn] | |||
ret = await self.db.runInteraction("get_push_action_users_in_range", f) | |||
ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f) | |||
return ret | |||
async def get_unread_push_actions_for_user_in_range_for_http( | |||
@@ -230,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, args) | |||
return txn.fetchall() | |||
after_read_receipt = await self.db.runInteraction( | |||
after_read_receipt = await self.db_pool.runInteraction( | |||
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt | |||
) | |||
@@ -258,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, args) | |||
return txn.fetchall() | |||
no_read_receipt = await self.db.runInteraction( | |||
no_read_receipt = await self.db_pool.runInteraction( | |||
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt | |||
) | |||
@@ -332,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, args) | |||
return txn.fetchall() | |||
after_read_receipt = await self.db.runInteraction( | |||
after_read_receipt = await self.db_pool.runInteraction( | |||
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt | |||
) | |||
@@ -360,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, args) | |||
return txn.fetchall() | |||
no_read_receipt = await self.db.runInteraction( | |||
no_read_receipt = await self.db_pool.runInteraction( | |||
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt | |||
) | |||
@@ -410,7 +410,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (user_id, min_stream_ordering)) | |||
return bool(txn.fetchone()) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_if_maybe_push_in_range_for_user", | |||
_get_if_maybe_push_in_range_for_user_txn, | |||
) | |||
@@ -461,7 +461,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
), | |||
) | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"add_push_actions_to_staging", _add_push_actions_to_staging_txn | |||
) | |||
@@ -471,7 +471,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
""" | |||
try: | |||
res = await self.db.simple_delete( | |||
res = await self.db_pool.simple_delete( | |||
table="event_push_actions_staging", | |||
keyvalues={"event_id": event_id}, | |||
desc="remove_push_actions_from_staging", | |||
@@ -488,7 +488,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
def _find_stream_orderings_for_times(self): | |||
return run_as_background_process( | |||
"event_push_action_stream_orderings", | |||
self.db.runInteraction, | |||
self.db_pool.runInteraction, | |||
"_find_stream_orderings_for_times", | |||
self._find_stream_orderings_for_times_txn, | |||
) | |||
@@ -524,7 +524,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
Deferred[int]: stream ordering of the first event received on/after | |||
the timestamp | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"_find_first_stream_ordering_after_ts_txn", | |||
self._find_first_stream_ordering_after_ts_txn, | |||
ts, | |||
@@ -619,24 +619,26 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (stream_ordering,)) | |||
return txn.fetchone() | |||
result = await self.db.runInteraction("get_time_of_last_push_action_before", f) | |||
result = await self.db_pool.runInteraction( | |||
"get_time_of_last_push_action_before", f | |||
) | |||
return result[0] if result else None | |||
class EventPushActionsStore(EventPushActionsWorkerStore): | |||
EPA_HIGHLIGHT_INDEX = "epa_highlight_index" | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(EventPushActionsStore, self).__init__(database, db_conn, hs) | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
self.EPA_HIGHLIGHT_INDEX, | |||
index_name="event_push_actions_u_highlight", | |||
table="event_push_actions", | |||
columns=["user_id", "stream_ordering"], | |||
) | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"event_push_actions_highlights_index", | |||
index_name="event_push_actions_highlights_index", | |||
table="event_push_actions", | |||
@@ -678,9 +680,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore): | |||
" LIMIT ?" % (before_clause,) | |||
) | |||
txn.execute(sql, args) | |||
return self.db.cursor_to_dict(txn) | |||
return self.db_pool.cursor_to_dict(txn) | |||
push_actions = await self.db.runInteraction("get_push_actions_for_user", f) | |||
push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) | |||
for pa in push_actions: | |||
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) | |||
return push_actions | |||
@@ -690,7 +692,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): | |||
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") | |||
return txn.fetchone() | |||
result = await self.db.runInteraction( | |||
result = await self.db_pool.runInteraction( | |||
"get_latest_push_action_stream_ordering", f | |||
) | |||
return result[0] or 0 | |||
@@ -753,7 +755,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): | |||
while True: | |||
logger.info("Rotating notifications") | |||
caught_up = await self.db.runInteraction( | |||
caught_up = await self.db_pool.runInteraction( | |||
"_rotate_notifs", self._rotate_notifs_txn | |||
) | |||
if caught_up: | |||
@@ -767,7 +769,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): | |||
the archiving process has caught up or not. | |||
""" | |||
old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( | |||
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="event_push_summary_stream_ordering", | |||
keyvalues={}, | |||
@@ -803,7 +805,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): | |||
return caught_up | |||
def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): | |||
old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( | |||
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="event_push_summary_stream_ordering", | |||
keyvalues={}, | |||
@@ -835,7 +837,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): | |||
# If the `old.user_id` above is NULL then we know there isn't already an | |||
# entry in the table, so we simply insert it. Otherwise we update the | |||
# existing table. | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="event_push_summary", | |||
values=[ |
@@ -32,8 +32,8 @@ from synapse.events import EventBase # noqa: F401 | |||
from synapse.events.snapshot import EventContext # noqa: F401 | |||
from synapse.logging.utils import log_function | |||
from synapse.storage._base import db_to_json, make_in_list_sql_clause | |||
from synapse.storage.data_stores.main.search import SearchEntry | |||
from synapse.storage.database import Database, LoggingTransaction | |||
from synapse.storage.database import DatabasePool, LoggingTransaction | |||
from synapse.storage.databases.main.search import SearchEntry | |||
from synapse.storage.util.id_generators import StreamIdGenerator | |||
from synapse.types import StateMap, get_domain_from_id | |||
from synapse.util.frozenutils import frozendict_json_encoder | |||
@@ -41,7 +41,7 @@ from synapse.util.iterutils import batch_iter | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
from synapse.storage.data_stores.main import DataStore | |||
from synapse.storage.databases.main import DataStore | |||
logger = logging.getLogger(__name__) | |||
@@ -132,9 +132,11 @@ class PersistEventsStore: | |||
Note: This is not part of the `DataStore` mixin. | |||
""" | |||
def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"): | |||
def __init__( | |||
self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore" | |||
): | |||
self.hs = hs | |||
self.db = db | |||
self.db_pool = db | |||
self.store = main_data_store | |||
self.database_engine = db.engine | |||
self._clock = hs.get_clock() | |||
@@ -207,7 +209,7 @@ class PersistEventsStore: | |||
for (event, context), stream in zip(events_and_contexts, stream_orderings): | |||
event.internal_metadata.stream_ordering = stream | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"persist_events", | |||
self._persist_events_txn, | |||
events_and_contexts=events_and_contexts, | |||
@@ -283,7 +285,7 @@ class PersistEventsStore: | |||
results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed")) | |||
for chunk in batch_iter(event_ids, 100): | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk | |||
) | |||
@@ -347,7 +349,7 @@ class PersistEventsStore: | |||
existing_prevs.add(prev_event_id) | |||
for chunk in batch_iter(event_ids, 100): | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk | |||
) | |||
@@ -421,7 +423,7 @@ class PersistEventsStore: | |||
# event's auth chain, but its easier for now just to store them (and | |||
# it doesn't take much storage compared to storing the entire event | |||
# anyway). | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="event_auth", | |||
values=[ | |||
@@ -484,7 +486,7 @@ class PersistEventsStore: | |||
""" | |||
txn.execute(sql, (stream_id, room_id)) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, table="current_state_events", keyvalues={"room_id": room_id}, | |||
) | |||
else: | |||
@@ -632,7 +634,7 @@ class PersistEventsStore: | |||
creator = content.get("creator") | |||
room_version_id = content.get("room_version", RoomVersions.V1.identifier) | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="rooms", | |||
keyvalues={"room_id": room_id}, | |||
@@ -644,14 +646,14 @@ class PersistEventsStore: | |||
self, txn, new_forward_extremities, max_stream_order | |||
): | |||
for room_id, new_extrem in new_forward_extremities.items(): | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, table="event_forward_extremities", keyvalues={"room_id": room_id} | |||
) | |||
txn.call_after( | |||
self.store.get_latest_event_ids_in_room.invalidate, (room_id,) | |||
) | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="event_forward_extremities", | |||
values=[ | |||
@@ -664,7 +666,7 @@ class PersistEventsStore: | |||
# new stream_ordering to new forward extremeties in the room. | |||
# This allows us to later efficiently look up the forward extremeties | |||
# for a room before a given stream_ordering | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="stream_ordering_to_exterm", | |||
values=[ | |||
@@ -788,7 +790,7 @@ class PersistEventsStore: | |||
# change in outlier status to our workers. | |||
stream_order = event.internal_metadata.stream_ordering | |||
state_group_id = context.state_group | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="ex_outlier_stream", | |||
values={ | |||
@@ -826,7 +828,7 @@ class PersistEventsStore: | |||
d.pop("redacted_because", None) | |||
return d | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="event_json", | |||
values=[ | |||
@@ -843,7 +845,7 @@ class PersistEventsStore: | |||
], | |||
) | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="events", | |||
values=[ | |||
@@ -873,7 +875,7 @@ class PersistEventsStore: | |||
# If we're persisting an unredacted event we go and ensure | |||
# that we mark any redactions that reference this event as | |||
# requiring censoring. | |||
self.db.simple_update_txn( | |||
self.db_pool.simple_update_txn( | |||
txn, | |||
table="redactions", | |||
keyvalues={"redacts": event.event_id}, | |||
@@ -1015,7 +1017,9 @@ class PersistEventsStore: | |||
state_values.append(vals) | |||
self.db.simple_insert_many_txn(txn, table="state_events", values=state_values) | |||
self.db_pool.simple_insert_many_txn( | |||
txn, table="state_events", values=state_values | |||
) | |||
# Prefill the event cache | |||
self._add_to_cache(txn, events_and_contexts) | |||
@@ -1046,7 +1050,7 @@ class PersistEventsStore: | |||
) | |||
txn.execute(sql + clause, args) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
for row in rows: | |||
event = ev_map[row["event_id"]] | |||
if not row["rejects"] and not row["redacts"]: | |||
@@ -1066,7 +1070,7 @@ class PersistEventsStore: | |||
# invalidate the cache for the redacted event | |||
txn.call_after(self.store._invalidate_get_event_cache, event.redacts) | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="redactions", | |||
values={ | |||
@@ -1089,7 +1093,7 @@ class PersistEventsStore: | |||
room_id (str): The ID of the room the event was sent to. | |||
topological_ordering (int): The position of the event in the room's topology. | |||
""" | |||
return self.db.simple_insert_many_txn( | |||
return self.db_pool.simple_insert_many_txn( | |||
txn=txn, | |||
table="event_labels", | |||
values=[ | |||
@@ -1111,7 +1115,7 @@ class PersistEventsStore: | |||
event_id (str): The event ID the expiry timestamp is associated with. | |||
expiry_ts (int): The timestamp at which to expire (delete) the event. | |||
""" | |||
return self.db.simple_insert_txn( | |||
return self.db_pool.simple_insert_txn( | |||
txn=txn, | |||
table="event_expiry", | |||
values={"event_id": event_id, "expiry_ts": expiry_ts}, | |||
@@ -1135,12 +1139,14 @@ class PersistEventsStore: | |||
} | |||
) | |||
self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) | |||
self.db_pool.simple_insert_many_txn( | |||
txn, table="event_reference_hashes", values=vals | |||
) | |||
def _store_room_members_txn(self, txn, events, backfilled): | |||
"""Store a room member in the database. | |||
""" | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="room_memberships", | |||
values=[ | |||
@@ -1180,7 +1186,7 @@ class PersistEventsStore: | |||
and event.internal_metadata.is_outlier() | |||
and event.internal_metadata.is_out_of_band_membership() | |||
): | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="local_current_membership", | |||
keyvalues={"room_id": event.room_id, "user_id": event.state_key}, | |||
@@ -1218,7 +1224,7 @@ class PersistEventsStore: | |||
aggregation_key = relation.get("key") | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="event_relations", | |||
values={ | |||
@@ -1246,7 +1252,7 @@ class PersistEventsStore: | |||
redacted_event_id (str): The event that was redacted. | |||
""" | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, table="event_relations", keyvalues={"event_id": redacted_event_id} | |||
) | |||
@@ -1282,7 +1288,7 @@ class PersistEventsStore: | |||
# Ignore the event if one of the value isn't an integer. | |||
return | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn=txn, | |||
table="room_retention", | |||
values={ | |||
@@ -1363,7 +1369,7 @@ class PersistEventsStore: | |||
) | |||
for event, _ in events_and_contexts: | |||
user_ids = self.db.simple_select_onecol_txn( | |||
user_ids = self.db_pool.simple_select_onecol_txn( | |||
txn, | |||
table="event_push_actions_staging", | |||
keyvalues={"event_id": event.event_id}, | |||
@@ -1395,7 +1401,7 @@ class PersistEventsStore: | |||
) | |||
def _store_rejections_txn(self, txn, event_id, reason): | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="rejections", | |||
values={ | |||
@@ -1421,7 +1427,7 @@ class PersistEventsStore: | |||
state_groups[event.event_id] = context.state_group | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="event_to_state_groups", | |||
values=[ | |||
@@ -1443,7 +1449,7 @@ class PersistEventsStore: | |||
if min_depth is not None and depth >= min_depth: | |||
return | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="room_depth", | |||
keyvalues={"room_id": room_id}, | |||
@@ -1455,7 +1461,7 @@ class PersistEventsStore: | |||
For the given event, update the event edges table and forward and | |||
backward extremities tables. | |||
""" | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="event_edges", | |||
values=[ |
@@ -19,7 +19,7 @@ from twisted.internet import defer | |||
from synapse.api.constants import EventContentFields | |||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
logger = logging.getLogger(__name__) | |||
@@ -30,18 +30,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" | |||
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, | |||
self._background_reindex_fields_sender, | |||
) | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"event_contains_url_index", | |||
index_name="event_contains_url_index", | |||
table="events", | |||
@@ -52,7 +52,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
# an event_id index on event_search is useful for the purge_history | |||
# api. Plus it means we get to enforce some integrity with a UNIQUE | |||
# clause | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"event_search_event_id_idx", | |||
index_name="event_search_event_id_idx", | |||
table="event_search", | |||
@@ -61,16 +61,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
psql_only=True, | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
"redactions_received_ts", self._redactions_received_ts | |||
) | |||
# This index gets deleted in `event_fix_redactions_bytes` update | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"event_fix_redactions_bytes_create_index", | |||
index_name="redactions_censored_redacts", | |||
table="redactions", | |||
@@ -78,15 +78,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
where_clause="have_censored", | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
"event_fix_redactions_bytes", self._event_fix_redactions_bytes | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
"event_store_labels", self._event_store_labels | |||
) | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"redactions_have_censored_ts_idx", | |||
index_name="redactions_have_censored_ts", | |||
table="redactions", | |||
@@ -149,18 +149,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
"rows_inserted": rows_inserted + len(rows), | |||
} | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress | |||
) | |||
return len(rows) | |||
result = yield self.db.runInteraction( | |||
result = yield self.db_pool.runInteraction( | |||
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn | |||
) | |||
if not result: | |||
yield self.db.updates._end_background_update( | |||
yield self.db_pool.updates._end_background_update( | |||
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME | |||
) | |||
@@ -195,7 +195,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] | |||
for chunk in chunks: | |||
ev_rows = self.db.simple_select_many_txn( | |||
ev_rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="event_json", | |||
column="event_id", | |||
@@ -228,18 +228,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
"rows_inserted": rows_inserted + len(rows_to_update), | |||
} | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress | |||
) | |||
return len(rows_to_update) | |||
result = yield self.db.runInteraction( | |||
result = yield self.db_pool.runInteraction( | |||
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn | |||
) | |||
if not result: | |||
yield self.db.updates._end_background_update( | |||
yield self.db_pool.updates._end_background_update( | |||
self.EVENT_ORIGIN_SERVER_TS_NAME | |||
) | |||
@@ -374,7 +374,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
to_delete.intersection_update(original_set) | |||
deleted = self.db.simple_delete_many_txn( | |||
deleted = self.db_pool.simple_delete_many_txn( | |||
txn=txn, | |||
table="event_forward_extremities", | |||
column="event_id", | |||
@@ -390,7 +390,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
if deleted: | |||
# We now need to invalidate the caches of these rooms | |||
rows = self.db.simple_select_many_txn( | |||
rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="events", | |||
column="event_id", | |||
@@ -404,7 +404,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
self.get_latest_event_ids_in_room.invalidate, (room_id,) | |||
) | |||
self.db.simple_delete_many_txn( | |||
self.db_pool.simple_delete_many_txn( | |||
txn=txn, | |||
table="_extremities_to_check", | |||
column="event_id", | |||
@@ -414,19 +414,19 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
return len(original_set) | |||
num_handled = yield self.db.runInteraction( | |||
num_handled = yield self.db_pool.runInteraction( | |||
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn | |||
) | |||
if not num_handled: | |||
yield self.db.updates._end_background_update( | |||
yield self.db_pool.updates._end_background_update( | |||
self.DELETE_SOFT_FAILED_EXTREMITIES | |||
) | |||
def _drop_table_txn(txn): | |||
txn.execute("DROP TABLE _extremities_to_check") | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn | |||
) | |||
@@ -474,18 +474,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, "redactions_received_ts", {"last_event_id": upper_event_id} | |||
) | |||
return len(rows) | |||
count = yield self.db.runInteraction( | |||
count = yield self.db_pool.runInteraction( | |||
"_redactions_received_ts", _redactions_received_ts_txn | |||
) | |||
if not count: | |||
yield self.db.updates._end_background_update("redactions_received_ts") | |||
yield self.db_pool.updates._end_background_update("redactions_received_ts") | |||
return count | |||
@@ -511,11 +511,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
txn.execute("DROP INDEX redactions_censored_redacts") | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn | |||
) | |||
yield self.db.updates._end_background_update("event_fix_redactions_bytes") | |||
yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes") | |||
return 1 | |||
@@ -543,7 +543,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
try: | |||
event_json = db_to_json(event_json_raw) | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn=txn, | |||
table="event_labels", | |||
values=[ | |||
@@ -569,17 +569,17 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
nbrows += 1 | |||
last_row_event_id = event_id | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, "event_store_labels", {"last_event_id": last_row_event_id} | |||
) | |||
return nbrows | |||
num_rows = yield self.db.runInteraction( | |||
num_rows = yield self.db_pool.runInteraction( | |||
desc="event_store_labels", func=_event_store_labels_txn | |||
) | |||
if not num_rows: | |||
yield self.db.updates._end_background_update("event_store_labels") | |||
yield self.db_pool.updates._end_background_update("event_store_labels") | |||
return num_rows |
@@ -40,7 +40,7 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |||
from synapse.replication.tcp.streams import BackfillStream | |||
from synapse.replication.tcp.streams.events import EventsStream | |||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.types import Cursor | |||
from synapse.storage.util.id_generators import StreamIdGenerator | |||
from synapse.types import get_domain_from_id | |||
@@ -80,7 +80,7 @@ class EventRedactBehaviour(Names): | |||
class EventsWorkerStore(SQLBaseStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(EventsWorkerStore, self).__init__(database, db_conn, hs) | |||
if hs.config.worker.writers.events == hs.get_instance_name(): | |||
@@ -136,7 +136,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
Deferred[int|None]: Timestamp in milliseconds, or None for events | |||
that were persisted before received_ts was implemented. | |||
""" | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="events", | |||
keyvalues={"event_id": event_id}, | |||
retcol="received_ts", | |||
@@ -175,7 +175,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
return ts | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_approximate_received_ts", _get_approximate_received_ts_txn | |||
) | |||
@@ -543,7 +543,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
event_id for events, _ in event_list for event_id in events | |||
} | |||
row_dict = self.db.new_transaction( | |||
row_dict = self.db_pool.new_transaction( | |||
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch | |||
) | |||
@@ -720,7 +720,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
if should_start: | |||
run_as_background_process( | |||
"fetch_events", self.db.runWithConnection, self._do_fetch | |||
"fetch_events", self.db_pool.runWithConnection, self._do_fetch | |||
) | |||
logger.debug("Loading %d events: %s", len(events), events) | |||
@@ -889,7 +889,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
"""Given a list of event ids, check if we have already processed and | |||
stored them as non outliers. | |||
""" | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = yield self.db_pool.simple_select_many_batch( | |||
table="events", | |||
retcols=("event_id",), | |||
column="event_id", | |||
@@ -924,7 +924,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
# break the input up into chunks of 100 | |||
input_iterator = iter(event_ids) | |||
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"have_seen_events", have_seen_events_txn, chunk | |||
) | |||
return results | |||
@@ -953,7 +953,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
Returns: | |||
Deferred[int] | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_total_state_event_counts", | |||
self._get_total_state_event_counts_txn, | |||
room_id, | |||
@@ -978,7 +978,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
Returns: | |||
Deferred[int] | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_current_state_event_counts", | |||
self._get_current_state_event_counts_txn, | |||
room_id, | |||
@@ -1043,7 +1043,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (last_id, current_id, limit)) | |||
return txn.fetchall() | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_all_new_forward_event_rows", get_all_new_forward_event_rows | |||
) | |||
@@ -1077,7 +1077,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (last_id, current_id)) | |||
return txn.fetchall() | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn | |||
) | |||
@@ -1151,7 +1151,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
return new_event_updates, upper_bound, limited | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows | |||
) | |||
@@ -1199,7 +1199,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
# we need to make sure that, for every stream id in the results, we get *all* | |||
# the rows with that stream id. | |||
rows = await self.db.runInteraction( | |||
rows = await self.db_pool.runInteraction( | |||
"get_all_updated_current_state_deltas", | |||
get_all_updated_current_state_deltas_txn, | |||
) # type: List[Tuple] | |||
@@ -1222,7 +1222,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
# stream id. let's run the query again, without a row limit, but for | |||
# just one stream id. | |||
to_token += 1 | |||
rows = await self.db.runInteraction( | |||
rows = await self.db_pool.runInteraction( | |||
"get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token | |||
) | |||
@@ -1317,7 +1317,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
backward_ex_outliers, | |||
) | |||
return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) | |||
return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn) | |||
async def is_event_after(self, event_id1, event_id2): | |||
"""Returns True if event_id1 is after event_id2 in the stream | |||
@@ -1328,7 +1328,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
@cachedInlineCallbacks(max_entries=5000) | |||
def get_event_ordering(self, event_id): | |||
res = yield self.db.simple_select_one( | |||
res = yield self.db_pool.simple_select_one( | |||
table="events", | |||
retcols=["topological_ordering", "stream_ordering"], | |||
keyvalues={"event_id": event_id}, | |||
@@ -1360,7 +1360,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
return txn.fetchone() | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn | |||
) | |||
@@ -1385,7 +1385,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
on_invalidate=cache_context.invalidate, | |||
) | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_unread_message_count_for_user", | |||
self._get_unread_message_count_for_user_txn, | |||
user_id, | |||
@@ -1402,7 +1402,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
) -> int: | |||
if last_read_event_id: | |||
# Get the stream ordering for the last read event. | |||
stream_ordering = self.db.simple_select_one_onecol_txn( | |||
stream_ordering = self.db_pool.simple_select_one_onecol_txn( | |||
txn=txn, | |||
table="events", | |||
keyvalues={"room_id": room_id, "event_id": last_read_event_id}, |
@@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore): | |||
except ValueError: | |||
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) | |||
def_json = yield self.db.simple_select_one_onecol( | |||
def_json = yield self.db_pool.simple_select_one_onecol( | |||
table="user_filters", | |||
keyvalues={"user_id": user_localpart, "filter_id": filter_id}, | |||
retcol="filter_json", | |||
@@ -71,4 +71,4 @@ class FilteringStore(SQLBaseStore): | |||
return filter_id | |||
return self.db.runInteraction("add_user_filter", _do_txn) | |||
return self.db_pool.runInteraction("add_user_filter", _do_txn) |
@@ -31,7 +31,7 @@ _DEFAULT_ROLE_ID = "" | |||
class GroupServerWorkerStore(SQLBaseStore): | |||
def get_group(self, group_id): | |||
return self.db.simple_select_one( | |||
return self.db_pool.simple_select_one( | |||
table="groups", | |||
keyvalues={"group_id": group_id}, | |||
retcols=( | |||
@@ -53,7 +53,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
if not include_private: | |||
keyvalues["is_public"] = True | |||
return self.db.simple_select_list( | |||
return self.db_pool.simple_select_list( | |||
table="group_users", | |||
keyvalues=keyvalues, | |||
retcols=("user_id", "is_public", "is_admin"), | |||
@@ -63,7 +63,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
def get_invited_users_in_group(self, group_id): | |||
# TODO: Pagination | |||
return self.db.simple_select_onecol( | |||
return self.db_pool.simple_select_onecol( | |||
table="group_invites", | |||
keyvalues={"group_id": group_id}, | |||
retcol="user_id", | |||
@@ -117,7 +117,9 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
for room_id, is_public in txn | |||
] | |||
return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn) | |||
return self.db_pool.runInteraction( | |||
"get_rooms_in_group", _get_rooms_in_group_txn | |||
) | |||
def get_rooms_for_summary_by_category( | |||
self, group_id: str, include_private: bool = False, | |||
@@ -205,13 +207,13 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
return rooms, categories | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_rooms_for_summary", _get_rooms_for_summary_txn | |||
) | |||
@defer.inlineCallbacks | |||
def get_group_categories(self, group_id): | |||
rows = yield self.db.simple_select_list( | |||
rows = yield self.db_pool.simple_select_list( | |||
table="group_room_categories", | |||
keyvalues={"group_id": group_id}, | |||
retcols=("category_id", "is_public", "profile"), | |||
@@ -228,7 +230,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
@defer.inlineCallbacks | |||
def get_group_category(self, group_id, category_id): | |||
category = yield self.db.simple_select_one( | |||
category = yield self.db_pool.simple_select_one( | |||
table="group_room_categories", | |||
keyvalues={"group_id": group_id, "category_id": category_id}, | |||
retcols=("is_public", "profile"), | |||
@@ -241,7 +243,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
@defer.inlineCallbacks | |||
def get_group_roles(self, group_id): | |||
rows = yield self.db.simple_select_list( | |||
rows = yield self.db_pool.simple_select_list( | |||
table="group_roles", | |||
keyvalues={"group_id": group_id}, | |||
retcols=("role_id", "is_public", "profile"), | |||
@@ -258,7 +260,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
@defer.inlineCallbacks | |||
def get_group_role(self, group_id, role_id): | |||
role = yield self.db.simple_select_one( | |||
role = yield self.db_pool.simple_select_one( | |||
table="group_roles", | |||
keyvalues={"group_id": group_id, "role_id": role_id}, | |||
retcols=("is_public", "profile"), | |||
@@ -277,7 +279,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
Deferred[list[str]]: A twisted.Deferred containing a list of group ids | |||
containing this room | |||
""" | |||
return self.db.simple_select_onecol( | |||
return self.db_pool.simple_select_onecol( | |||
table="group_rooms", | |||
keyvalues={"room_id": room_id}, | |||
retcol="group_id", | |||
@@ -341,12 +343,12 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
return users, roles | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_users_for_summary_by_role", _get_users_for_summary_txn | |||
) | |||
def is_user_in_group(self, user_id, group_id): | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="group_users", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
retcol="user_id", | |||
@@ -355,7 +357,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
).addCallback(lambda r: bool(r)) | |||
def is_user_admin_in_group(self, group_id, user_id): | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="group_users", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
retcol="is_admin", | |||
@@ -366,7 +368,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
def is_user_invited_to_local_group(self, group_id, user_id): | |||
"""Has the group server invited a user? | |||
""" | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="group_invites", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
retcol="user_id", | |||
@@ -389,7 +391,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
""" | |||
def _get_users_membership_in_group_txn(txn): | |||
row = self.db.simple_select_one_txn( | |||
row = self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="group_users", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
@@ -404,7 +406,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
"is_privileged": row["is_admin"], | |||
} | |||
row = self.db.simple_select_one_onecol_txn( | |||
row = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="group_invites", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
@@ -417,14 +419,14 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
return {} | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_users_membership_info_in_group", _get_users_membership_in_group_txn | |||
) | |||
def get_publicised_groups_for_user(self, user_id): | |||
"""Get all groups a user is publicising | |||
""" | |||
return self.db.simple_select_onecol( | |||
return self.db_pool.simple_select_onecol( | |||
table="local_group_membership", | |||
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, | |||
retcol="group_id", | |||
@@ -441,9 +443,9 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
WHERE valid_until_ms <= ? | |||
""" | |||
txn.execute(sql, (valid_until_ms,)) | |||
return self.db.cursor_to_dict(txn) | |||
return self.db_pool.cursor_to_dict(txn) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_attestations_need_renewals", _get_attestations_need_renewals_txn | |||
) | |||
@@ -452,7 +454,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
"""Get the attestation that proves the remote agrees that the user is | |||
in the group. | |||
""" | |||
row = yield self.db.simple_select_one( | |||
row = yield self.db_pool.simple_select_one( | |||
table="group_attestations_remote", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
retcols=("valid_until_ms", "attestation_json"), | |||
@@ -467,7 +469,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
return None | |||
def get_joined_groups(self, user_id): | |||
return self.db.simple_select_onecol( | |||
return self.db_pool.simple_select_onecol( | |||
table="local_group_membership", | |||
keyvalues={"user_id": user_id, "membership": "join"}, | |||
retcol="group_id", | |||
@@ -494,7 +496,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
for row in txn | |||
] | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_all_groups_for_user", _get_all_groups_for_user_txn | |||
) | |||
@@ -524,7 +526,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
for group_id, membership, gtype, content_json in txn | |||
] | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_groups_changes_for_user", _get_groups_changes_for_user_txn | |||
) | |||
@@ -579,7 +581,7 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
return updates, upto_token, limited | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_all_groups_changes", _get_all_groups_changes_txn | |||
) | |||
@@ -592,7 +594,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
* "invite" | |||
* "open" | |||
""" | |||
return self.db.simple_update_one( | |||
return self.db_pool.simple_update_one( | |||
table="groups", | |||
keyvalues={"group_id": group_id}, | |||
updatevalues={"join_policy": join_policy}, | |||
@@ -600,7 +602,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
) | |||
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"add_room_to_summary", | |||
self._add_room_to_summary_txn, | |||
group_id, | |||
@@ -624,7 +626,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
an order of 1 will put the room first. Otherwise, the room gets | |||
added to the end. | |||
""" | |||
room_in_group = self.db.simple_select_one_onecol_txn( | |||
room_in_group = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="group_rooms", | |||
keyvalues={"group_id": group_id, "room_id": room_id}, | |||
@@ -637,7 +639,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
if category_id is None: | |||
category_id = _DEFAULT_CATEGORY_ID | |||
else: | |||
cat_exists = self.db.simple_select_one_onecol_txn( | |||
cat_exists = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="group_room_categories", | |||
keyvalues={"group_id": group_id, "category_id": category_id}, | |||
@@ -648,7 +650,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
raise SynapseError(400, "Category doesn't exist") | |||
# TODO: Check category is part of summary already | |||
cat_exists = self.db.simple_select_one_onecol_txn( | |||
cat_exists = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="group_summary_room_categories", | |||
keyvalues={"group_id": group_id, "category_id": category_id}, | |||
@@ -668,7 +670,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
(group_id, category_id, group_id, category_id), | |||
) | |||
existing = self.db.simple_select_one_txn( | |||
existing = self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="group_summary_rooms", | |||
keyvalues={ | |||
@@ -701,7 +703,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
to_update["room_order"] = order | |||
if is_public is not None: | |||
to_update["is_public"] = is_public | |||
self.db.simple_update_txn( | |||
self.db_pool.simple_update_txn( | |||
txn, | |||
table="group_summary_rooms", | |||
keyvalues={ | |||
@@ -715,7 +717,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
if is_public is None: | |||
is_public = True | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="group_summary_rooms", | |||
values={ | |||
@@ -731,7 +733,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
if category_id is None: | |||
category_id = _DEFAULT_CATEGORY_ID | |||
return self.db.simple_delete( | |||
return self.db_pool.simple_delete( | |||
table="group_summary_rooms", | |||
keyvalues={ | |||
"group_id": group_id, | |||
@@ -757,7 +759,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
else: | |||
update_values["is_public"] = is_public | |||
return self.db.simple_upsert( | |||
return self.db_pool.simple_upsert( | |||
table="group_room_categories", | |||
keyvalues={"group_id": group_id, "category_id": category_id}, | |||
values=update_values, | |||
@@ -766,7 +768,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
) | |||
def remove_group_category(self, group_id, category_id): | |||
return self.db.simple_delete( | |||
return self.db_pool.simple_delete( | |||
table="group_room_categories", | |||
keyvalues={"group_id": group_id, "category_id": category_id}, | |||
desc="remove_group_category", | |||
@@ -788,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
else: | |||
update_values["is_public"] = is_public | |||
return self.db.simple_upsert( | |||
return self.db_pool.simple_upsert( | |||
table="group_roles", | |||
keyvalues={"group_id": group_id, "role_id": role_id}, | |||
values=update_values, | |||
@@ -797,14 +799,14 @@ class GroupServerStore(GroupServerWorkerStore): | |||
) | |||
def remove_group_role(self, group_id, role_id): | |||
return self.db.simple_delete( | |||
return self.db_pool.simple_delete( | |||
table="group_roles", | |||
keyvalues={"group_id": group_id, "role_id": role_id}, | |||
desc="remove_group_role", | |||
) | |||
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"add_user_to_summary", | |||
self._add_user_to_summary_txn, | |||
group_id, | |||
@@ -828,7 +830,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
an order of 1 will put the user first. Otherwise, the user gets | |||
added to the end. | |||
""" | |||
user_in_group = self.db.simple_select_one_onecol_txn( | |||
user_in_group = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="group_users", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
@@ -841,7 +843,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
if role_id is None: | |||
role_id = _DEFAULT_ROLE_ID | |||
else: | |||
role_exists = self.db.simple_select_one_onecol_txn( | |||
role_exists = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="group_roles", | |||
keyvalues={"group_id": group_id, "role_id": role_id}, | |||
@@ -852,7 +854,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
raise SynapseError(400, "Role doesn't exist") | |||
# TODO: Check role is part of the summary already | |||
role_exists = self.db.simple_select_one_onecol_txn( | |||
role_exists = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="group_summary_roles", | |||
keyvalues={"group_id": group_id, "role_id": role_id}, | |||
@@ -872,7 +874,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
(group_id, role_id, group_id, role_id), | |||
) | |||
existing = self.db.simple_select_one_txn( | |||
existing = self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="group_summary_users", | |||
keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, | |||
@@ -901,7 +903,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
to_update["user_order"] = order | |||
if is_public is not None: | |||
to_update["is_public"] = is_public | |||
self.db.simple_update_txn( | |||
self.db_pool.simple_update_txn( | |||
txn, | |||
table="group_summary_users", | |||
keyvalues={ | |||
@@ -915,7 +917,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
if is_public is None: | |||
is_public = True | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="group_summary_users", | |||
values={ | |||
@@ -931,7 +933,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
if role_id is None: | |||
role_id = _DEFAULT_ROLE_ID | |||
return self.db.simple_delete( | |||
return self.db_pool.simple_delete( | |||
table="group_summary_users", | |||
keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, | |||
desc="remove_user_from_summary", | |||
@@ -940,7 +942,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
def add_group_invite(self, group_id, user_id): | |||
"""Record that the group server has invited a user | |||
""" | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
table="group_invites", | |||
values={"group_id": group_id, "user_id": user_id}, | |||
desc="add_group_invite", | |||
@@ -970,7 +972,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
""" | |||
def _add_user_to_group_txn(txn): | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="group_users", | |||
values={ | |||
@@ -981,14 +983,14 @@ class GroupServerStore(GroupServerWorkerStore): | |||
}, | |||
) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="group_invites", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
) | |||
if local_attestation: | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="group_attestations_renewals", | |||
values={ | |||
@@ -998,7 +1000,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
}, | |||
) | |||
if remote_attestation: | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="group_attestations_remote", | |||
values={ | |||
@@ -1009,49 +1011,49 @@ class GroupServerStore(GroupServerWorkerStore): | |||
}, | |||
) | |||
return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn) | |||
return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) | |||
def remove_user_from_group(self, group_id, user_id): | |||
def _remove_user_from_group_txn(txn): | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="group_users", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="group_invites", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="group_attestations_renewals", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="group_attestations_remote", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="group_summary_users", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"remove_user_from_group", _remove_user_from_group_txn | |||
) | |||
def add_room_to_group(self, group_id, room_id, is_public): | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
table="group_rooms", | |||
values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, | |||
desc="add_room_to_group", | |||
) | |||
def update_room_in_group_visibility(self, group_id, room_id, is_public): | |||
return self.db.simple_update( | |||
return self.db_pool.simple_update( | |||
table="group_rooms", | |||
keyvalues={"group_id": group_id, "room_id": room_id}, | |||
updatevalues={"is_public": is_public}, | |||
@@ -1060,26 +1062,26 @@ class GroupServerStore(GroupServerWorkerStore): | |||
def remove_room_from_group(self, group_id, room_id): | |||
def _remove_room_from_group_txn(txn): | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="group_rooms", | |||
keyvalues={"group_id": group_id, "room_id": room_id}, | |||
) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="group_summary_rooms", | |||
keyvalues={"group_id": group_id, "room_id": room_id}, | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"remove_room_from_group", _remove_room_from_group_txn | |||
) | |||
def update_group_publicity(self, group_id, user_id, publicise): | |||
"""Update whether the user is publicising their membership of the group | |||
""" | |||
return self.db.simple_update_one( | |||
return self.db_pool.simple_update_one( | |||
table="local_group_membership", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
updatevalues={"is_publicised": publicise}, | |||
@@ -1115,12 +1117,12 @@ class GroupServerStore(GroupServerWorkerStore): | |||
def _register_user_group_membership_txn(txn, next_id): | |||
# TODO: Upsert? | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="local_group_membership", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
) | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="local_group_membership", | |||
values={ | |||
@@ -1133,7 +1135,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
}, | |||
) | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="local_group_updates", | |||
values={ | |||
@@ -1152,7 +1154,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
if membership == "join": | |||
if local_attestation: | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="group_attestations_renewals", | |||
values={ | |||
@@ -1162,7 +1164,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
}, | |||
) | |||
if remote_attestation: | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="group_attestations_remote", | |||
values={ | |||
@@ -1173,12 +1175,12 @@ class GroupServerStore(GroupServerWorkerStore): | |||
}, | |||
) | |||
else: | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="group_attestations_renewals", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="group_attestations_remote", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
@@ -1187,7 +1189,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
return next_id | |||
with self._group_updates_id_gen.get_next() as next_id: | |||
res = yield self.db.runInteraction( | |||
res = yield self.db_pool.runInteraction( | |||
"register_user_group_membership", | |||
_register_user_group_membership_txn, | |||
next_id, | |||
@@ -1198,7 +1200,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
def create_group( | |||
self, group_id, user_id, name, avatar_url, short_description, long_description | |||
): | |||
yield self.db.simple_insert( | |||
yield self.db_pool.simple_insert( | |||
table="groups", | |||
values={ | |||
"group_id": group_id, | |||
@@ -1213,7 +1215,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
@defer.inlineCallbacks | |||
def update_group_profile(self, group_id, profile): | |||
yield self.db.simple_update_one( | |||
yield self.db_pool.simple_update_one( | |||
table="groups", | |||
keyvalues={"group_id": group_id}, | |||
updatevalues=profile, | |||
@@ -1223,7 +1225,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
def update_attestation_renewal(self, group_id, user_id, attestation): | |||
"""Update an attestation that we have renewed | |||
""" | |||
return self.db.simple_update_one( | |||
return self.db_pool.simple_update_one( | |||
table="group_attestations_renewals", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, | |||
@@ -1233,7 +1235,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
def update_remote_attestion(self, group_id, user_id, attestation): | |||
"""Update an attestation that a remote has renewed | |||
""" | |||
return self.db.simple_update_one( | |||
return self.db_pool.simple_update_one( | |||
table="group_attestations_remote", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
updatevalues={ | |||
@@ -1252,7 +1254,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
group_id (str) | |||
user_id (str) | |||
""" | |||
return self.db.simple_delete( | |||
return self.db_pool.simple_delete( | |||
table="group_attestations_renewals", | |||
keyvalues={"group_id": group_id, "user_id": user_id}, | |||
desc="remove_attestation_renewal", | |||
@@ -1288,8 +1290,8 @@ class GroupServerStore(GroupServerWorkerStore): | |||
] | |||
for table in tables: | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, table=table, keyvalues={"group_id": group_id} | |||
) | |||
return self.db.runInteraction("delete_group", _delete_group_txn) | |||
return self.db_pool.runInteraction("delete_group", _delete_group_txn) |
@@ -86,7 +86,7 @@ class KeyStore(SQLBaseStore): | |||
_get_keys(txn, batch) | |||
return keys | |||
return self.db.runInteraction("get_server_verify_keys", _txn) | |||
return self.db_pool.runInteraction("get_server_verify_keys", _txn) | |||
def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): | |||
"""Stores NACL verification keys for remote servers. | |||
@@ -121,9 +121,9 @@ class KeyStore(SQLBaseStore): | |||
f((i,)) | |||
return res | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"store_server_verify_keys", | |||
self.db.simple_upsert_many_txn, | |||
self.db_pool.simple_upsert_many_txn, | |||
table="server_signature_keys", | |||
key_names=("server_name", "key_id"), | |||
key_values=key_values, | |||
@@ -151,7 +151,7 @@ class KeyStore(SQLBaseStore): | |||
ts_valid_until_ms (int): The time when this json stops being valid. | |||
key_json (bytes): The encoded JSON. | |||
""" | |||
return self.db.simple_upsert( | |||
return self.db_pool.simple_upsert( | |||
table="server_keys_json", | |||
keyvalues={ | |||
"server_name": server_name, | |||
@@ -190,7 +190,7 @@ class KeyStore(SQLBaseStore): | |||
keyvalues["key_id"] = key_id | |||
if from_server is not None: | |||
keyvalues["from_server"] = from_server | |||
rows = self.db.simple_select_list_txn( | |||
rows = self.db_pool.simple_select_list_txn( | |||
txn, | |||
"server_keys_json", | |||
keyvalues=keyvalues, | |||
@@ -205,4 +205,6 @@ class KeyStore(SQLBaseStore): | |||
results[(server_name, key_id, from_server)] = rows | |||
return results | |||
return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn) | |||
return self.db_pool.runInteraction( | |||
"get_server_keys_json", _get_server_keys_json_txn | |||
) |
@@ -13,16 +13,16 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(MediaRepositoryBackgroundUpdateStore, self).__init__( | |||
database, db_conn, hs | |||
) | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
update_name="local_media_repository_url_idx", | |||
index_name="local_media_repository_url_idx", | |||
table="local_media_repository", | |||
@@ -34,7 +34,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): | |||
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
"""Persistence for attachments and avatars""" | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(MediaRepositoryStore, self).__init__(database, db_conn, hs) | |||
def get_local_media(self, media_id): | |||
@@ -42,7 +42,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
Returns: | |||
None if the media_id doesn't exist. | |||
""" | |||
return self.db.simple_select_one( | |||
return self.db_pool.simple_select_one( | |||
"local_media_repository", | |||
{"media_id": media_id}, | |||
( | |||
@@ -67,7 +67,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
user_id, | |||
url_cache=None, | |||
): | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
"local_media_repository", | |||
{ | |||
"media_id": media_id, | |||
@@ -83,7 +83,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
def mark_local_media_as_safe(self, media_id: str): | |||
"""Mark a local media as safe from quarantining.""" | |||
return self.db.simple_update_one( | |||
return self.db_pool.simple_update_one( | |||
table="local_media_repository", | |||
keyvalues={"media_id": media_id}, | |||
updatevalues={"safe_from_quarantine": True}, | |||
@@ -136,12 +136,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
) | |||
) | |||
return self.db.runInteraction("get_url_cache", get_url_cache_txn) | |||
return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) | |||
def store_url_cache( | |||
self, url, response_code, etag, expires_ts, og, media_id, download_ts | |||
): | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
"local_media_repository_url_cache", | |||
{ | |||
"url": url, | |||
@@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
) | |||
def get_local_media_thumbnails(self, media_id): | |||
return self.db.simple_select_list( | |||
return self.db_pool.simple_select_list( | |||
"local_media_repository_thumbnails", | |||
{"media_id": media_id}, | |||
( | |||
@@ -178,7 +178,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
thumbnail_method, | |||
thumbnail_length, | |||
): | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
"local_media_repository_thumbnails", | |||
{ | |||
"media_id": media_id, | |||
@@ -192,7 +192,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
) | |||
def get_cached_remote_media(self, origin, media_id): | |||
return self.db.simple_select_one( | |||
return self.db_pool.simple_select_one( | |||
"remote_media_cache", | |||
{"media_origin": origin, "media_id": media_id}, | |||
( | |||
@@ -217,7 +217,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
upload_name, | |||
filesystem_id, | |||
): | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
"remote_media_cache", | |||
{ | |||
"media_origin": origin, | |||
@@ -262,12 +262,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"update_cached_last_access_time", update_cache_txn | |||
) | |||
def get_remote_media_thumbnails(self, origin, media_id): | |||
return self.db.simple_select_list( | |||
return self.db_pool.simple_select_list( | |||
"remote_media_cache_thumbnails", | |||
{"media_origin": origin, "media_id": media_id}, | |||
( | |||
@@ -292,7 +292,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
thumbnail_method, | |||
thumbnail_length, | |||
): | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
"remote_media_cache_thumbnails", | |||
{ | |||
"media_origin": origin, | |||
@@ -314,24 +314,26 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
" WHERE last_access_ts < ?" | |||
) | |||
return self.db.execute( | |||
"get_remote_media_before", self.db.cursor_to_dict, sql, before_ts | |||
return self.db_pool.execute( | |||
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts | |||
) | |||
def delete_remote_media(self, media_origin, media_id): | |||
def delete_remote_media_txn(txn): | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
"remote_media_cache", | |||
keyvalues={"media_origin": media_origin, "media_id": media_id}, | |||
) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
"remote_media_cache_thumbnails", | |||
keyvalues={"media_origin": media_origin, "media_id": media_id}, | |||
) | |||
return self.db.runInteraction("delete_remote_media", delete_remote_media_txn) | |||
return self.db_pool.runInteraction( | |||
"delete_remote_media", delete_remote_media_txn | |||
) | |||
def get_expired_url_cache(self, now_ts): | |||
sql = ( | |||
@@ -345,7 +347,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
txn.execute(sql, (now_ts,)) | |||
return [row[0] for row in txn] | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_expired_url_cache", _get_expired_url_cache_txn | |||
) | |||
@@ -358,7 +360,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
def _delete_url_cache_txn(txn): | |||
txn.executemany(sql, [(media_id,) for media_id in media_ids]) | |||
return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) | |||
return await self.db_pool.runInteraction( | |||
"delete_url_cache", _delete_url_cache_txn | |||
) | |||
def get_url_cache_media_before(self, before_ts): | |||
sql = ( | |||
@@ -372,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
txn.execute(sql, (before_ts,)) | |||
return [row[0] for row in txn] | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_url_cache_media_before", _get_url_cache_media_before_txn | |||
) | |||
@@ -389,6 +393,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
txn.executemany(sql, [(media_id,) for media_id in media_ids]) | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"delete_url_cache_media", _delete_url_cache_media_txn | |||
) |
@@ -20,10 +20,10 @@ from twisted.internet import defer | |||
from synapse.metrics import BucketCollector | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.data_stores.main.event_push_actions import ( | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.event_push_actions import ( | |||
EventPushActionsWorkerStore, | |||
) | |||
from synapse.storage.database import Database | |||
class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
@@ -31,7 +31,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
stats and prometheus metrics. | |||
""" | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super().__init__(database, db_conn, hs) | |||
# Collect metrics on the number of forward extremities that exist. | |||
@@ -66,7 +66,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
) | |||
return txn.fetchall() | |||
res = await self.db.runInteraction("read_forward_extremities", fetch) | |||
res = await self.db_pool.runInteraction("read_forward_extremities", fetch) | |||
self._current_forward_extremities_amount = Counter([x[0] for x in res]) | |||
@defer.inlineCallbacks | |||
@@ -88,7 +88,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
(count,) = txn.fetchone() | |||
return count | |||
ret = yield self.db.runInteraction("count_messages", _count_messages) | |||
ret = yield self.db_pool.runInteraction("count_messages", _count_messages) | |||
return ret | |||
@defer.inlineCallbacks | |||
@@ -109,7 +109,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
(count,) = txn.fetchone() | |||
return count | |||
ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) | |||
ret = yield self.db_pool.runInteraction( | |||
"count_daily_sent_messages", _count_messages | |||
) | |||
return ret | |||
@defer.inlineCallbacks | |||
@@ -124,5 +126,5 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
(count,) = txn.fetchone() | |||
return count | |||
ret = yield self.db.runInteraction("count_daily_active_rooms", _count) | |||
ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count) | |||
return ret |
@@ -18,7 +18,7 @@ from typing import List | |||
from twisted.internet import defer | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.database import Database, make_in_list_sql_clause | |||
from synapse.storage.database import DatabasePool, make_in_list_sql_clause | |||
from synapse.util.caches.descriptors import cached | |||
logger = logging.getLogger(__name__) | |||
@@ -29,7 +29,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 | |||
class MonthlyActiveUsersWorkerStore(SQLBaseStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs) | |||
self._clock = hs.get_clock() | |||
self.hs = hs | |||
@@ -48,7 +48,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): | |||
(count,) = txn.fetchone() | |||
return count | |||
return self.db.runInteraction("count_users", _count_users) | |||
return self.db_pool.runInteraction("count_users", _count_users) | |||
@cached(num_args=0) | |||
def get_monthly_active_count_by_service(self): | |||
@@ -76,7 +76,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): | |||
result = txn.fetchall() | |||
return dict(result) | |||
return self.db.runInteraction("count_users_by_service", _count_users_by_service) | |||
return self.db_pool.runInteraction( | |||
"count_users_by_service", _count_users_by_service | |||
) | |||
async def get_registered_reserved_users(self) -> List[str]: | |||
"""Of the reserved threepids defined in config, retrieve those that are associated | |||
@@ -109,7 +111,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): | |||
""" | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="monthly_active_users", | |||
keyvalues={"user_id": user_id}, | |||
retcol="timestamp", | |||
@@ -119,7 +121,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): | |||
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) | |||
self._limit_usage_by_mau = hs.config.limit_usage_by_mau | |||
@@ -128,7 +130,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): | |||
# Do not add more reserved users than the total allowable number | |||
# cur = LoggingTransaction( | |||
self.db.new_transaction( | |||
self.db_pool.new_transaction( | |||
db_conn, | |||
"initialise_mau_threepids", | |||
[], | |||
@@ -162,7 +164,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): | |||
is_support = self.is_support_user_txn(txn, user_id) | |||
if not is_support: | |||
# We do this manually here to avoid hitting #6791 | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="monthly_active_users", | |||
keyvalues={"user_id": user_id}, | |||
@@ -246,7 +248,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): | |||
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) | |||
reserved_users = await self.get_registered_reserved_users() | |||
await self.db.runInteraction( | |||
await self.db_pool.runInteraction( | |||
"reap_monthly_active_users", _reap_users, reserved_users | |||
) | |||
@@ -273,7 +275,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): | |||
if is_support: | |||
return | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id | |||
) | |||
@@ -303,7 +305,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): | |||
# never be a big table and alternative approaches (batching multiple | |||
# upserts into a single txn) introduced a lot of extra complexity. | |||
# See https://github.com/matrix-org/synapse/issues/3854 for more | |||
is_insert = self.db.simple_upsert_txn( | |||
is_insert = self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="monthly_active_users", | |||
keyvalues={"user_id": user_id}, |
@@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore | |||
class OpenIdStore(SQLBaseStore): | |||
def insert_open_id_token(self, token, ts_valid_until_ms, user_id): | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
table="open_id_tokens", | |||
values={ | |||
"token": token, | |||
@@ -28,6 +28,6 @@ class OpenIdStore(SQLBaseStore): | |||
else: | |||
return rows[0][0] | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_user_id_for_token", get_user_id_for_token_txn | |||
) |
@@ -31,7 +31,7 @@ class PresenceStore(SQLBaseStore): | |||
) | |||
with stream_ordering_manager as stream_orderings: | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"update_presence", | |||
self._update_presence_txn, | |||
stream_orderings, | |||
@@ -48,7 +48,7 @@ class PresenceStore(SQLBaseStore): | |||
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) | |||
# Actually insert new rows | |||
self.db.simple_insert_many_txn( | |||
self.db_pool.simple_insert_many_txn( | |||
txn, | |||
table="presence_stream", | |||
values=[ | |||
@@ -124,7 +124,7 @@ class PresenceStore(SQLBaseStore): | |||
return updates, upper_bound, limited | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_all_presence_updates", get_all_presence_updates_txn | |||
) | |||
@@ -139,7 +139,7 @@ class PresenceStore(SQLBaseStore): | |||
inlineCallbacks=True, | |||
) | |||
def get_presence_for_users(self, user_ids): | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = yield self.db_pool.simple_select_many_batch( | |||
table="presence_stream", | |||
column="user_id", | |||
iterable=user_ids, | |||
@@ -165,7 +165,7 @@ class PresenceStore(SQLBaseStore): | |||
return self._presence_id_gen.get_current_token() | |||
def allow_presence_visible(self, observed_localpart, observer_userid): | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
table="presence_allow_inbound", | |||
values={ | |||
"observed_user_id": observed_localpart, | |||
@@ -176,7 +176,7 @@ class PresenceStore(SQLBaseStore): | |||
) | |||
def disallow_presence_visible(self, observed_localpart, observer_userid): | |||
return self.db.simple_delete_one( | |||
return self.db_pool.simple_delete_one( | |||
table="presence_allow_inbound", | |||
keyvalues={ | |||
"observed_user_id": observed_localpart, |
@@ -17,14 +17,14 @@ from twisted.internet import defer | |||
from synapse.api.errors import StoreError | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.data_stores.main.roommember import ProfileInfo | |||
from synapse.storage.databases.main.roommember import ProfileInfo | |||
class ProfileWorkerStore(SQLBaseStore): | |||
@defer.inlineCallbacks | |||
def get_profileinfo(self, user_localpart): | |||
try: | |||
profile = yield self.db.simple_select_one( | |||
profile = yield self.db_pool.simple_select_one( | |||
table="profiles", | |||
keyvalues={"user_id": user_localpart}, | |||
retcols=("displayname", "avatar_url"), | |||
@@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore): | |||
) | |||
def get_profile_displayname(self, user_localpart): | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="profiles", | |||
keyvalues={"user_id": user_localpart}, | |||
retcol="displayname", | |||
@@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore): | |||
) | |||
def get_profile_avatar_url(self, user_localpart): | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="profiles", | |||
keyvalues={"user_id": user_localpart}, | |||
retcol="avatar_url", | |||
@@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore): | |||
) | |||
def get_from_remote_profile_cache(self, user_id): | |||
return self.db.simple_select_one( | |||
return self.db_pool.simple_select_one( | |||
table="remote_profile_cache", | |||
keyvalues={"user_id": user_id}, | |||
retcols=("displayname", "avatar_url"), | |||
@@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore): | |||
) | |||
def create_profile(self, user_localpart): | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
table="profiles", values={"user_id": user_localpart}, desc="create_profile" | |||
) | |||
def set_profile_displayname(self, user_localpart, new_displayname): | |||
return self.db.simple_update_one( | |||
return self.db_pool.simple_update_one( | |||
table="profiles", | |||
keyvalues={"user_id": user_localpart}, | |||
updatevalues={"displayname": new_displayname}, | |||
@@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore): | |||
) | |||
def set_profile_avatar_url(self, user_localpart, new_avatar_url): | |||
return self.db.simple_update_one( | |||
return self.db_pool.simple_update_one( | |||
table="profiles", | |||
keyvalues={"user_id": user_localpart}, | |||
updatevalues={"avatar_url": new_avatar_url}, | |||
@@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore): | |||
This should only be called when `is_subscribed_remote_profile_for_user` | |||
would return true for the user. | |||
""" | |||
return self.db.simple_upsert( | |||
return self.db_pool.simple_upsert( | |||
table="remote_profile_cache", | |||
keyvalues={"user_id": user_id}, | |||
values={ | |||
@@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore): | |||
) | |||
def update_remote_profile_cache(self, user_id, displayname, avatar_url): | |||
return self.db.simple_update( | |||
return self.db_pool.simple_update( | |||
table="remote_profile_cache", | |||
keyvalues={"user_id": user_id}, | |||
updatevalues={ | |||
@@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore): | |||
""" | |||
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) | |||
if not subscribed: | |||
yield self.db.simple_delete( | |||
yield self.db_pool.simple_delete( | |||
table="remote_profile_cache", | |||
keyvalues={"user_id": user_id}, | |||
desc="delete_remote_profile_cache", | |||
@@ -144,9 +144,9 @@ class ProfileStore(ProfileWorkerStore): | |||
txn.execute(sql, (last_checked,)) | |||
return self.db.cursor_to_dict(txn) | |||
return self.db_pool.cursor_to_dict(txn) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_remote_profile_cache_entries_that_expire", | |||
_get_remote_profile_cache_entries_that_expire_txn, | |||
) | |||
@@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): | |||
def is_subscribed_remote_profile_for_user(self, user_id): | |||
"""Check whether we are interested in a remote user's profile. | |||
""" | |||
res = yield self.db.simple_select_one_onecol( | |||
res = yield self.db_pool.simple_select_one_onecol( | |||
table="group_users", | |||
keyvalues={"user_id": user_id}, | |||
retcol="user_id", | |||
@@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore): | |||
if res: | |||
return True | |||
res = yield self.db.simple_select_one_onecol( | |||
res = yield self.db_pool.simple_select_one_onecol( | |||
table="group_invites", | |||
keyvalues={"user_id": user_id}, | |||
retcol="user_id", |
@@ -18,7 +18,7 @@ from typing import Any, Tuple | |||
from synapse.api.errors import SynapseError | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.data_stores.main.state import StateGroupWorkerStore | |||
from synapse.storage.databases.main.state import StateGroupWorkerStore | |||
from synapse.types import RoomStreamToken | |||
logger = logging.getLogger(__name__) | |||
@@ -43,7 +43,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): | |||
deleted events. | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"purge_history", | |||
self._purge_history_txn, | |||
room_id, | |||
@@ -293,7 +293,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): | |||
Deferred[List[int]]: The list of state groups to delete. | |||
""" | |||
return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) | |||
return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id) | |||
def _purge_room_txn(self, txn, room_id): | |||
# First we fetch all the state groups that should be deleted, before |
@@ -25,12 +25,12 @@ from twisted.internet import defer | |||
from synapse.push.baserules import list_with_base_rules | |||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |||
from synapse.storage._base import SQLBaseStore, db_to_json | |||
from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore | |||
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore | |||
from synapse.storage.data_stores.main.pusher import PusherWorkerStore | |||
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore | |||
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore | |||
from synapse.storage.databases.main.events_worker import EventsWorkerStore | |||
from synapse.storage.databases.main.pusher import PusherWorkerStore | |||
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore | |||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore | |||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException | |||
from synapse.storage.util.id_generators import ChainedIdGenerator | |||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList | |||
@@ -79,7 +79,7 @@ class PushRulesWorkerStore( | |||
# the abstract methods being implemented. | |||
__metaclass__ = abc.ABCMeta | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) | |||
if hs.config.worker.worker_app is None: | |||
@@ -91,7 +91,7 @@ class PushRulesWorkerStore( | |||
db_conn, "push_rules_stream", "stream_id" | |||
) | |||
push_rules_prefill, push_rules_id = self.db.get_cache_dict( | |||
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( | |||
db_conn, | |||
"push_rules_stream", | |||
entity_column="user_id", | |||
@@ -116,7 +116,7 @@ class PushRulesWorkerStore( | |||
@cachedInlineCallbacks(max_entries=5000) | |||
def get_push_rules_for_user(self, user_id): | |||
rows = yield self.db.simple_select_list( | |||
rows = yield self.db_pool.simple_select_list( | |||
table="push_rules", | |||
keyvalues={"user_name": user_id}, | |||
retcols=( | |||
@@ -140,7 +140,7 @@ class PushRulesWorkerStore( | |||
@cachedInlineCallbacks(max_entries=5000) | |||
def get_push_rules_enabled_for_user(self, user_id): | |||
results = yield self.db.simple_select_list( | |||
results = yield self.db_pool.simple_select_list( | |||
table="push_rules_enable", | |||
keyvalues={"user_name": user_id}, | |||
retcols=("user_name", "rule_id", "enabled"), | |||
@@ -162,7 +162,7 @@ class PushRulesWorkerStore( | |||
(count,) = txn.fetchone() | |||
return bool(count) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"have_push_rules_changed", have_push_rules_changed_txn | |||
) | |||
@@ -178,7 +178,7 @@ class PushRulesWorkerStore( | |||
results = {user_id: [] for user_id in user_ids} | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = yield self.db_pool.simple_select_many_batch( | |||
table="push_rules", | |||
column="user_name", | |||
iterable=user_ids, | |||
@@ -336,7 +336,7 @@ class PushRulesWorkerStore( | |||
results = {user_id: {} for user_id in user_ids} | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = yield self.db_pool.simple_select_many_batch( | |||
table="push_rules_enable", | |||
column="user_name", | |||
iterable=user_ids, | |||
@@ -394,7 +394,7 @@ class PushRulesWorkerStore( | |||
return updates, upper_bound, limited | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_all_push_rule_updates", get_all_push_rule_updates_txn | |||
) | |||
@@ -416,7 +416,7 @@ class PushRuleStore(PushRulesWorkerStore): | |||
with self._push_rules_stream_id_gen.get_next() as ids: | |||
stream_id, event_stream_ordering = ids | |||
if before or after: | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"_add_push_rule_relative_txn", | |||
self._add_push_rule_relative_txn, | |||
stream_id, | |||
@@ -430,7 +430,7 @@ class PushRuleStore(PushRulesWorkerStore): | |||
after, | |||
) | |||
else: | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"_add_push_rule_highest_priority_txn", | |||
self._add_push_rule_highest_priority_txn, | |||
stream_id, | |||
@@ -461,7 +461,7 @@ class PushRuleStore(PushRulesWorkerStore): | |||
relative_to_rule = before or after | |||
res = self.db.simple_select_one_txn( | |||
res = self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="push_rules", | |||
keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, | |||
@@ -584,7 +584,7 @@ class PushRuleStore(PushRulesWorkerStore): | |||
# We didn't update a row with the given rule_id so insert one | |||
push_rule_id = self._push_rule_id_gen.get_next() | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="push_rules", | |||
values={ | |||
@@ -627,7 +627,7 @@ class PushRuleStore(PushRulesWorkerStore): | |||
""" | |||
def delete_push_rule_txn(txn, stream_id, event_stream_ordering): | |||
self.db.simple_delete_one_txn( | |||
self.db_pool.simple_delete_one_txn( | |||
txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} | |||
) | |||
@@ -637,7 +637,7 @@ class PushRuleStore(PushRulesWorkerStore): | |||
with self._push_rules_stream_id_gen.get_next() as ids: | |||
stream_id, event_stream_ordering = ids | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"delete_push_rule", | |||
delete_push_rule_txn, | |||
stream_id, | |||
@@ -648,7 +648,7 @@ class PushRuleStore(PushRulesWorkerStore): | |||
def set_push_rule_enabled(self, user_id, rule_id, enabled): | |||
with self._push_rules_stream_id_gen.get_next() as ids: | |||
stream_id, event_stream_ordering = ids | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"_set_push_rule_enabled_txn", | |||
self._set_push_rule_enabled_txn, | |||
stream_id, | |||
@@ -662,7 +662,7 @@ class PushRuleStore(PushRulesWorkerStore): | |||
self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled | |||
): | |||
new_id = self._push_rules_enable_id_gen.get_next() | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
"push_rules_enable", | |||
{"user_name": user_id, "rule_id": rule_id}, | |||
@@ -702,7 +702,7 @@ class PushRuleStore(PushRulesWorkerStore): | |||
update_stream=False, | |||
) | |||
else: | |||
self.db.simple_update_one_txn( | |||
self.db_pool.simple_update_one_txn( | |||
txn, | |||
"push_rules", | |||
{"user_name": user_id, "rule_id": rule_id}, | |||
@@ -721,7 +721,7 @@ class PushRuleStore(PushRulesWorkerStore): | |||
with self._push_rules_stream_id_gen.get_next() as ids: | |||
stream_id, event_stream_ordering = ids | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"set_push_rule_actions", | |||
set_push_rule_actions_txn, | |||
stream_id, | |||
@@ -741,7 +741,7 @@ class PushRuleStore(PushRulesWorkerStore): | |||
if data is not None: | |||
values.update(data) | |||
self.db.simple_insert_txn(txn, "push_rules_stream", values=values) | |||
self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values) | |||
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) | |||
txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) |
@@ -50,7 +50,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
@defer.inlineCallbacks | |||
def user_has_pusher(self, user_id): | |||
ret = yield self.db.simple_select_one_onecol( | |||
ret = yield self.db_pool.simple_select_one_onecol( | |||
"pushers", {"user_name": user_id}, "id", allow_none=True | |||
) | |||
return ret is not None | |||
@@ -63,7 +63,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
@defer.inlineCallbacks | |||
def get_pushers_by(self, keyvalues): | |||
ret = yield self.db.simple_select_list( | |||
ret = yield self.db_pool.simple_select_list( | |||
"pushers", | |||
keyvalues, | |||
[ | |||
@@ -91,11 +91,11 @@ class PusherWorkerStore(SQLBaseStore): | |||
def get_all_pushers(self): | |||
def get_pushers(txn): | |||
txn.execute("SELECT * FROM pushers") | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
return self._decode_pushers_rows(rows) | |||
rows = yield self.db.runInteraction("get_all_pushers", get_pushers) | |||
rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers) | |||
return rows | |||
async def get_all_updated_pushers_rows( | |||
@@ -160,7 +160,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
return updates, upper_bound, limited | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn | |||
) | |||
@@ -176,7 +176,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
inlineCallbacks=True, | |||
) | |||
def get_if_users_have_pushers(self, user_ids): | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = yield self.db_pool.simple_select_many_batch( | |||
table="pushers", | |||
column="user_name", | |||
iterable=user_ids, | |||
@@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
def update_pusher_last_stream_ordering( | |||
self, app_id, pushkey, user_id, last_stream_ordering | |||
): | |||
yield self.db.simple_update_one( | |||
yield self.db_pool.simple_update_one( | |||
"pushers", | |||
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, | |||
{"last_stream_ordering": last_stream_ordering}, | |||
@@ -216,7 +216,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
Returns: | |||
Deferred[bool]: True if the pusher still exists; False if it has been deleted. | |||
""" | |||
updated = yield self.db.simple_update( | |||
updated = yield self.db_pool.simple_update( | |||
table="pushers", | |||
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, | |||
updatevalues={ | |||
@@ -230,7 +230,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
@defer.inlineCallbacks | |||
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): | |||
yield self.db.simple_update( | |||
yield self.db_pool.simple_update( | |||
table="pushers", | |||
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, | |||
updatevalues={"failing_since": failing_since}, | |||
@@ -239,7 +239,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
@defer.inlineCallbacks | |||
def get_throttle_params_by_room(self, pusher_id): | |||
res = yield self.db.simple_select_list( | |||
res = yield self.db_pool.simple_select_list( | |||
"pusher_throttle", | |||
{"pusher": pusher_id}, | |||
["room_id", "last_sent_ts", "throttle_ms"], | |||
@@ -259,7 +259,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
def set_throttle_params(self, pusher_id, room_id, params): | |||
# no need to lock because `pusher_throttle` has a primary key on | |||
# (pusher, room_id) so simple_upsert will retry | |||
yield self.db.simple_upsert( | |||
yield self.db_pool.simple_upsert( | |||
"pusher_throttle", | |||
{"pusher": pusher_id, "room_id": room_id}, | |||
params, | |||
@@ -291,7 +291,7 @@ class PusherStore(PusherWorkerStore): | |||
with self._pushers_id_gen.get_next() as stream_id: | |||
# no need to lock because `pushers` has a unique key on | |||
# (app_id, pushkey, user_name) so simple_upsert will retry | |||
yield self.db.simple_upsert( | |||
yield self.db_pool.simple_upsert( | |||
table="pushers", | |||
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, | |||
values={ | |||
@@ -316,7 +316,7 @@ class PusherStore(PusherWorkerStore): | |||
if user_has_pusher is not True: | |||
# invalidate, since we the user might not have had a pusher before | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"add_pusher", | |||
self._invalidate_cache_and_stream, | |||
self.get_if_user_has_pusher, | |||
@@ -330,7 +330,7 @@ class PusherStore(PusherWorkerStore): | |||
txn, self.get_if_user_has_pusher, (user_id,) | |||
) | |||
self.db.simple_delete_one_txn( | |||
self.db_pool.simple_delete_one_txn( | |||
txn, | |||
"pushers", | |||
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, | |||
@@ -339,7 +339,7 @@ class PusherStore(PusherWorkerStore): | |||
# it's possible for us to end up with duplicate rows for | |||
# (app_id, pushkey, user_id) at different stream_ids, but that | |||
# doesn't really matter. | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="deleted_pushers", | |||
values={ | |||
@@ -351,4 +351,6 @@ class PusherStore(PusherWorkerStore): | |||
) | |||
with self._pushers_id_gen.get_next() as stream_id: | |||
yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id) | |||
yield self.db_pool.runInteraction( | |||
"delete_pusher", delete_pusher_txn, stream_id | |||
) |
@@ -23,7 +23,7 @@ from canonicaljson import json | |||
from twisted.internet import defer | |||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.util.id_generators import StreamIdGenerator | |||
from synapse.util.async_helpers import ObservableDeferred | |||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList | |||
@@ -41,7 +41,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
# the abstract methods being implemented. | |||
__metaclass__ = abc.ABCMeta | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) | |||
self._receipts_stream_cache = StreamChangeCache( | |||
@@ -64,7 +64,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
@cached(num_args=2) | |||
def get_receipts_for_room(self, room_id, receipt_type): | |||
return self.db.simple_select_list( | |||
return self.db_pool.simple_select_list( | |||
table="receipts_linearized", | |||
keyvalues={"room_id": room_id, "receipt_type": receipt_type}, | |||
retcols=("user_id", "event_id"), | |||
@@ -73,7 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
@cached(num_args=3) | |||
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="receipts_linearized", | |||
keyvalues={ | |||
"room_id": room_id, | |||
@@ -87,7 +87,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
@cachedInlineCallbacks(num_args=2) | |||
def get_receipts_for_user(self, user_id, receipt_type): | |||
rows = yield self.db.simple_select_list( | |||
rows = yield self.db_pool.simple_select_list( | |||
table="receipts_linearized", | |||
keyvalues={"user_id": user_id, "receipt_type": receipt_type}, | |||
retcols=("room_id", "event_id"), | |||
@@ -111,7 +111,9 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (user_id,)) | |||
return txn.fetchall() | |||
rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f) | |||
rows = yield self.db_pool.runInteraction( | |||
"get_receipts_for_user_with_orderings", f | |||
) | |||
return { | |||
row[0]: { | |||
"event_id": row[1], | |||
@@ -190,11 +192,11 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (room_id, to_key)) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
return rows | |||
rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f) | |||
rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f) | |||
if not rows: | |||
return [] | |||
@@ -240,9 +242,9 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
txn.execute(sql + clause, [to_key] + list(args)) | |||
return self.db.cursor_to_dict(txn) | |||
return self.db_pool.cursor_to_dict(txn) | |||
txn_results = yield self.db.runInteraction( | |||
txn_results = yield self.db_pool.runInteraction( | |||
"_get_linearized_receipts_for_rooms", f | |||
) | |||
@@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
return [r[0] for r in txn] | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_users_sent_receipts_between", _get_users_sent_receipts_between_txn | |||
) | |||
@@ -340,7 +342,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
return updates, upper_bound, limited | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_all_updated_receipts", get_all_updated_receipts_txn | |||
) | |||
@@ -371,7 +373,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
class ReceiptsStore(ReceiptsWorkerStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
# We instantiate this first as the ReceiptsWorkerStore constructor | |||
# needs to be able to call get_max_receipt_stream_id | |||
self._receipts_id_gen = StreamIdGenerator( | |||
@@ -393,7 +395,7 @@ class ReceiptsStore(ReceiptsWorkerStore): | |||
otherwise, the rx timestamp of the event that the RR corresponds to | |||
(or 0 if the event is unknown) | |||
""" | |||
res = self.db.simple_select_one_txn( | |||
res = self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="events", | |||
retcols=["stream_ordering", "received_ts"], | |||
@@ -446,7 +448,7 @@ class ReceiptsStore(ReceiptsWorkerStore): | |||
(user_id, room_id, receipt_type), | |||
) | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="receipts_linearized", | |||
keyvalues={ | |||
@@ -506,13 +508,13 @@ class ReceiptsStore(ReceiptsWorkerStore): | |||
else: | |||
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) | |||
linearized_event_id = yield self.db.runInteraction( | |||
linearized_event_id = yield self.db_pool.runInteraction( | |||
"insert_receipt_conv", graph_to_linear | |||
) | |||
stream_id_manager = self._receipts_id_gen.get_next() | |||
with stream_id_manager as stream_id: | |||
event_ts = yield self.db.runInteraction( | |||
event_ts = yield self.db_pool.runInteraction( | |||
"insert_linearized_receipt", | |||
self.insert_linearized_receipt_txn, | |||
room_id, | |||
@@ -541,7 +543,7 @@ class ReceiptsStore(ReceiptsWorkerStore): | |||
return stream_id, max_persisted_id | |||
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"insert_graph_receipt", | |||
self.insert_graph_receipt_txn, | |||
room_id, | |||
@@ -567,7 +569,7 @@ class ReceiptsStore(ReceiptsWorkerStore): | |||
self._get_linearized_receipts_for_room.invalidate_many, (room_id,) | |||
) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="receipts_graph", | |||
keyvalues={ | |||
@@ -576,7 +578,7 @@ class ReceiptsStore(ReceiptsWorkerStore): | |||
"user_id": user_id, | |||
}, | |||
) | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="receipts_graph", | |||
values={ |
@@ -26,7 +26,7 @@ from synapse.api.constants import UserTypes | |||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.types import Cursor | |||
from synapse.storage.util.sequence import build_sequence_generator | |||
from synapse.types import UserID | |||
@@ -38,7 +38,7 @@ logger = logging.getLogger(__name__) | |||
class RegistrationWorkerStore(SQLBaseStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) | |||
self.config = hs.config | |||
@@ -50,7 +50,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
@cached() | |||
def get_user_by_id(self, user_id): | |||
return self.db.simple_select_one( | |||
return self.db_pool.simple_select_one( | |||
table="users", | |||
keyvalues={"name": user_id}, | |||
retcols=[ | |||
@@ -101,7 +101,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
including the keys `name`, `is_guest`, `device_id`, `token_id`, | |||
`valid_until_ms`. | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_user_by_access_token", self._query_for_auth, token | |||
) | |||
@@ -116,7 +116,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
otherwise int representation of the timestamp (as a number of | |||
milliseconds since epoch). | |||
""" | |||
res = yield self.db.simple_select_one_onecol( | |||
res = yield self.db_pool.simple_select_one_onecol( | |||
table="account_validity", | |||
keyvalues={"user_id": user_id}, | |||
retcol="expiration_ts_ms", | |||
@@ -144,7 +144,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
""" | |||
def set_account_validity_for_user_txn(txn): | |||
self.db.simple_update_txn( | |||
self.db_pool.simple_update_txn( | |||
txn=txn, | |||
table="account_validity", | |||
keyvalues={"user_id": user_id}, | |||
@@ -158,7 +158,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
txn, self.get_expiration_ts_for_user, (user_id,) | |||
) | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"set_account_validity_for_user", set_account_validity_for_user_txn | |||
) | |||
@@ -174,7 +174,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Raises: | |||
StoreError: The provided token is already set for another user. | |||
""" | |||
yield self.db.simple_update_one( | |||
yield self.db_pool.simple_update_one( | |||
table="account_validity", | |||
keyvalues={"user_id": user_id}, | |||
updatevalues={"renewal_token": renewal_token}, | |||
@@ -191,7 +191,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Returns: | |||
defer.Deferred[str]: The ID of the user to which the token belongs. | |||
""" | |||
res = yield self.db.simple_select_one_onecol( | |||
res = yield self.db_pool.simple_select_one_onecol( | |||
table="account_validity", | |||
keyvalues={"renewal_token": renewal_token}, | |||
retcol="user_id", | |||
@@ -210,7 +210,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Returns: | |||
defer.Deferred[str]: The renewal token associated with this user ID. | |||
""" | |||
res = yield self.db.simple_select_one_onecol( | |||
res = yield self.db_pool.simple_select_one_onecol( | |||
table="account_validity", | |||
keyvalues={"user_id": user_id}, | |||
retcol="renewal_token", | |||
@@ -236,9 +236,9 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
) | |||
values = [False, now_ms, renew_at] | |||
txn.execute(sql, values) | |||
return self.db.cursor_to_dict(txn) | |||
return self.db_pool.cursor_to_dict(txn) | |||
res = yield self.db.runInteraction( | |||
res = yield self.db_pool.runInteraction( | |||
"get_users_expiring_soon", | |||
select_users_txn, | |||
self.clock.time_msec(), | |||
@@ -257,7 +257,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
email_sent (bool): Flag which indicates whether a renewal email has been sent | |||
to this user. | |||
""" | |||
yield self.db.simple_update_one( | |||
yield self.db_pool.simple_update_one( | |||
table="account_validity", | |||
keyvalues={"user_id": user_id}, | |||
updatevalues={"email_sent": email_sent}, | |||
@@ -272,7 +272,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Args: | |||
user_id (str): ID of the user to remove from the account validity table. | |||
""" | |||
yield self.db.simple_delete_one( | |||
yield self.db_pool.simple_delete_one( | |||
table="account_validity", | |||
keyvalues={"user_id": user_id}, | |||
desc="delete_account_validity_for_user", | |||
@@ -287,7 +287,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Returns (bool): | |||
true iff the user is a server admin, false otherwise. | |||
""" | |||
res = await self.db.simple_select_one_onecol( | |||
res = await self.db_pool.simple_select_one_onecol( | |||
table="users", | |||
keyvalues={"name": user.to_string()}, | |||
retcol="admin", | |||
@@ -307,14 +307,14 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
""" | |||
def set_server_admin_txn(txn): | |||
self.db.simple_update_one_txn( | |||
self.db_pool.simple_update_one_txn( | |||
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} | |||
) | |||
self._invalidate_cache_and_stream( | |||
txn, self.get_user_by_id, (user.to_string(),) | |||
) | |||
return self.db.runInteraction("set_server_admin", set_server_admin_txn) | |||
return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) | |||
def _query_for_auth(self, txn, token): | |||
sql = ( | |||
@@ -326,7 +326,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
) | |||
txn.execute(sql, (token,)) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
if rows: | |||
return rows[0] | |||
@@ -342,7 +342,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Returns: | |||
Deferred[bool]: True if user 'user_type' is null or empty string | |||
""" | |||
res = yield self.db.runInteraction( | |||
res = yield self.db_pool.runInteraction( | |||
"is_real_user", self.is_real_user_txn, user_id | |||
) | |||
return res | |||
@@ -357,12 +357,12 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Returns: | |||
Deferred[bool]: True if user is of type UserTypes.SUPPORT | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"is_support_user", self.is_support_user_txn, user_id | |||
) | |||
def is_real_user_txn(self, txn, user_id): | |||
res = self.db.simple_select_one_onecol_txn( | |||
res = self.db_pool.simple_select_one_onecol_txn( | |||
txn=txn, | |||
table="users", | |||
keyvalues={"name": user_id}, | |||
@@ -372,7 +372,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
return res is None | |||
def is_support_user_txn(self, txn, user_id): | |||
res = self.db.simple_select_one_onecol_txn( | |||
res = self.db_pool.simple_select_one_onecol_txn( | |||
txn=txn, | |||
table="users", | |||
keyvalues={"name": user_id}, | |||
@@ -391,7 +391,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (user_id,)) | |||
return dict(txn) | |||
return self.db.runInteraction("get_users_by_id_case_insensitive", f) | |||
return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) | |||
async def get_user_by_external_id( | |||
self, auth_provider: str, external_id: str | |||
@@ -405,7 +405,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Returns: | |||
str|None: the mxid of the user, or None if they are not known | |||
""" | |||
return await self.db.simple_select_one_onecol( | |||
return await self.db_pool.simple_select_one_onecol( | |||
table="user_external_ids", | |||
keyvalues={"auth_provider": auth_provider, "external_id": external_id}, | |||
retcol="user_id", | |||
@@ -419,12 +419,12 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
def _count_users(txn): | |||
txn.execute("SELECT COUNT(*) AS users FROM users") | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
if rows: | |||
return rows[0]["users"] | |||
return 0 | |||
ret = yield self.db.runInteraction("count_users", _count_users) | |||
ret = yield self.db_pool.runInteraction("count_users", _count_users) | |||
return ret | |||
def count_daily_user_type(self): | |||
@@ -456,7 +456,9 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
results[row[0]] = row[1] | |||
return results | |||
return self.db.runInteraction("count_daily_user_type", _count_daily_user_type) | |||
return self.db_pool.runInteraction( | |||
"count_daily_user_type", _count_daily_user_type | |||
) | |||
@defer.inlineCallbacks | |||
def count_nonbridged_users(self): | |||
@@ -470,7 +472,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
(count,) = txn.fetchone() | |||
return count | |||
ret = yield self.db.runInteraction("count_users", _count_users) | |||
ret = yield self.db_pool.runInteraction("count_users", _count_users) | |||
return ret | |||
@defer.inlineCallbacks | |||
@@ -479,12 +481,12 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
def _count_users(txn): | |||
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
if rows: | |||
return rows[0]["users"] | |||
return 0 | |||
ret = yield self.db.runInteraction("count_real_users", _count_users) | |||
ret = yield self.db_pool.runInteraction("count_real_users", _count_users) | |||
return ret | |||
async def generate_user_id(self) -> str: | |||
@@ -492,7 +494,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Returns: a (hopefully) free localpart | |||
""" | |||
next_id = await self.db.runInteraction( | |||
next_id = await self.db_pool.runInteraction( | |||
"generate_user_id", self._user_id_seq.get_next_id_txn | |||
) | |||
@@ -508,7 +510,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Returns: | |||
The user ID or None if no user id/threepid mapping exists | |||
""" | |||
user_id = await self.db.runInteraction( | |||
user_id = await self.db_pool.runInteraction( | |||
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address | |||
) | |||
return user_id | |||
@@ -524,7 +526,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Returns: | |||
str|None: user id or None if no user id/threepid mapping exists | |||
""" | |||
ret = self.db.simple_select_one_txn( | |||
ret = self.db_pool.simple_select_one_txn( | |||
txn, | |||
"user_threepids", | |||
{"medium": medium, "address": address}, | |||
@@ -537,7 +539,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
@defer.inlineCallbacks | |||
def user_add_threepid(self, user_id, medium, address, validated_at, added_at): | |||
yield self.db.simple_upsert( | |||
yield self.db_pool.simple_upsert( | |||
"user_threepids", | |||
{"medium": medium, "address": address}, | |||
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, | |||
@@ -545,7 +547,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
@defer.inlineCallbacks | |||
def user_get_threepids(self, user_id): | |||
ret = yield self.db.simple_select_list( | |||
ret = yield self.db_pool.simple_select_list( | |||
"user_threepids", | |||
{"user_id": user_id}, | |||
["medium", "address", "validated_at", "added_at"], | |||
@@ -554,7 +556,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
return ret | |||
def user_delete_threepid(self, user_id, medium, address): | |||
return self.db.simple_delete( | |||
return self.db_pool.simple_delete( | |||
"user_threepids", | |||
keyvalues={"user_id": user_id, "medium": medium, "address": address}, | |||
desc="user_delete_threepid", | |||
@@ -567,7 +569,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
user_id: The user id to delete all threepids of | |||
""" | |||
return self.db.simple_delete( | |||
return self.db_pool.simple_delete( | |||
"user_threepids", | |||
keyvalues={"user_id": user_id}, | |||
desc="user_delete_threepids", | |||
@@ -589,7 +591,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
""" | |||
# We need to use an upsert, in case they user had already bound the | |||
# threepid | |||
return self.db.simple_upsert( | |||
return self.db_pool.simple_upsert( | |||
table="user_threepid_id_server", | |||
keyvalues={ | |||
"user_id": user_id, | |||
@@ -615,7 +617,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
medium (str): The medium of the threepid (e.g "email") | |||
address (str): The address of the threepid (e.g "bob@example.com") | |||
""" | |||
return self.db.simple_select_list( | |||
return self.db_pool.simple_select_list( | |||
table="user_threepid_id_server", | |||
keyvalues={"user_id": user_id}, | |||
retcols=["medium", "address"], | |||
@@ -636,7 +638,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Returns: | |||
Deferred | |||
""" | |||
return self.db.simple_delete( | |||
return self.db_pool.simple_delete( | |||
table="user_threepid_id_server", | |||
keyvalues={ | |||
"user_id": user_id, | |||
@@ -659,7 +661,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
Returns: | |||
Deferred[list[str]]: Resolves to a list of identity servers | |||
""" | |||
return self.db.simple_select_onecol( | |||
return self.db_pool.simple_select_onecol( | |||
table="user_threepid_id_server", | |||
keyvalues={"user_id": user_id, "medium": medium, "address": address}, | |||
retcol="id_server", | |||
@@ -677,7 +679,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
defer.Deferred(bool): The requested value. | |||
""" | |||
res = yield self.db.simple_select_one_onecol( | |||
res = yield self.db_pool.simple_select_one_onecol( | |||
table="users", | |||
keyvalues={"name": user_id}, | |||
retcol="deactivated", | |||
@@ -744,13 +746,13 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
sql += " LIMIT 1" | |||
txn.execute(sql, list(keyvalues.values())) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
if not rows: | |||
return None | |||
return rows[0] | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_threepid_validation_session", get_threepid_validation_session_txn | |||
) | |||
@@ -764,37 +766,37 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
""" | |||
def delete_threepid_session_txn(txn): | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="threepid_validation_token", | |||
keyvalues={"session_id": session_id}, | |||
) | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="threepid_validation_session", | |||
keyvalues={"session_id": session_id}, | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"delete_threepid_session", delete_threepid_session_txn | |||
) | |||
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) | |||
self.clock = hs.get_clock() | |||
self.config = hs.config | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"access_tokens_device_index", | |||
index_name="access_tokens_device_id", | |||
table="access_tokens", | |||
columns=["user_id", "device_id"], | |||
) | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"users_creation_ts", | |||
index_name="users_creation_ts", | |||
table="users", | |||
@@ -804,13 +806,15 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): | |||
# we no longer use refresh tokens, but it's possible that some people | |||
# might have a background update queued to build this index. Just | |||
# clear the background update. | |||
self.db.updates.register_noop_background_update("refresh_tokens_device_index") | |||
self.db_pool.updates.register_noop_background_update( | |||
"refresh_tokens_device_index" | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
"user_threepids_grandfather", self._bg_user_threepids_grandfather | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
"users_set_deactivated_flag", self._background_update_set_deactivated_flag | |||
) | |||
@@ -843,7 +847,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): | |||
(last_user, batch_size), | |||
) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
if not rows: | |||
return True, 0 | |||
@@ -857,7 +861,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): | |||
logger.info("Marked %d rows as deactivated", rows_processed_nb) | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} | |||
) | |||
@@ -866,12 +870,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): | |||
else: | |||
return False, len(rows) | |||
end, nb_processed = yield self.db.runInteraction( | |||
end, nb_processed = yield self.db_pool.runInteraction( | |||
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn | |||
) | |||
if end: | |||
yield self.db.updates._end_background_update("users_set_deactivated_flag") | |||
yield self.db_pool.updates._end_background_update( | |||
"users_set_deactivated_flag" | |||
) | |||
return nb_processed | |||
@@ -897,17 +903,17 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): | |||
txn.executemany(sql, [(id_server,) for id_server in id_servers]) | |||
if id_servers: | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn | |||
) | |||
yield self.db.updates._end_background_update("user_threepids_grandfather") | |||
yield self.db_pool.updates._end_background_update("user_threepids_grandfather") | |||
return 1 | |||
class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(RegistrationStore, self).__init__(database, db_conn, hs) | |||
self._account_validity = hs.config.account_validity | |||
@@ -947,7 +953,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
""" | |||
next_id = self._access_tokens_id_gen.get_next() | |||
yield self.db.simple_insert( | |||
yield self.db_pool.simple_insert( | |||
"access_tokens", | |||
{ | |||
"id": next_id, | |||
@@ -992,7 +998,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
Returns: | |||
Deferred | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"register_user", | |||
self._register_user, | |||
user_id, | |||
@@ -1026,7 +1032,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
# Ensure that the guest user actually exists | |||
# ``allow_none=False`` makes this raise an exception | |||
# if the row isn't in the database. | |||
self.db.simple_select_one_txn( | |||
self.db_pool.simple_select_one_txn( | |||
txn, | |||
"users", | |||
keyvalues={"name": user_id, "is_guest": 1}, | |||
@@ -1034,7 +1040,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
allow_none=False, | |||
) | |||
self.db.simple_update_one_txn( | |||
self.db_pool.simple_update_one_txn( | |||
txn, | |||
"users", | |||
keyvalues={"name": user_id, "is_guest": 1}, | |||
@@ -1048,7 +1054,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
}, | |||
) | |||
else: | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
"users", | |||
values={ | |||
@@ -1103,7 +1109,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
external_id: id on that system | |||
user_id: complete mxid that it is mapped to | |||
""" | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
table="user_external_ids", | |||
values={ | |||
"auth_provider": auth_provider, | |||
@@ -1121,12 +1127,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
""" | |||
def user_set_password_hash_txn(txn): | |||
self.db.simple_update_one_txn( | |||
self.db_pool.simple_update_one_txn( | |||
txn, "users", {"name": user_id}, {"password_hash": password_hash} | |||
) | |||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"user_set_password_hash", user_set_password_hash_txn | |||
) | |||
@@ -1143,7 +1149,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
""" | |||
def f(txn): | |||
self.db.simple_update_one_txn( | |||
self.db_pool.simple_update_one_txn( | |||
txn, | |||
table="users", | |||
keyvalues={"name": user_id}, | |||
@@ -1151,7 +1157,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
) | |||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) | |||
return self.db.runInteraction("user_set_consent_version", f) | |||
return self.db_pool.runInteraction("user_set_consent_version", f) | |||
def user_set_consent_server_notice_sent(self, user_id, consent_version): | |||
"""Updates the user table to record that we have sent the user a server | |||
@@ -1167,7 +1173,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
""" | |||
def f(txn): | |||
self.db.simple_update_one_txn( | |||
self.db_pool.simple_update_one_txn( | |||
txn, | |||
table="users", | |||
keyvalues={"name": user_id}, | |||
@@ -1175,7 +1181,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
) | |||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) | |||
return self.db.runInteraction("user_set_consent_server_notice_sent", f) | |||
return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) | |||
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): | |||
""" | |||
@@ -1221,11 +1227,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
return tokens_and_devices | |||
return self.db.runInteraction("user_delete_access_tokens", f) | |||
return self.db_pool.runInteraction("user_delete_access_tokens", f) | |||
def delete_access_token(self, access_token): | |||
def f(txn): | |||
self.db.simple_delete_one_txn( | |||
self.db_pool.simple_delete_one_txn( | |||
txn, table="access_tokens", keyvalues={"token": access_token} | |||
) | |||
@@ -1233,11 +1239,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
txn, self.get_user_by_access_token, (access_token,) | |||
) | |||
return self.db.runInteraction("delete_access_token", f) | |||
return self.db_pool.runInteraction("delete_access_token", f) | |||
@cachedInlineCallbacks() | |||
def is_guest(self, user_id): | |||
res = yield self.db.simple_select_one_onecol( | |||
res = yield self.db_pool.simple_select_one_onecol( | |||
table="users", | |||
keyvalues={"name": user_id}, | |||
retcol="is_guest", | |||
@@ -1252,7 +1258,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
Adds a user to the table of users who need to be parted from all the rooms they're | |||
in | |||
""" | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
"users_pending_deactivation", | |||
values={"user_id": user_id}, | |||
desc="add_user_pending_deactivation", | |||
@@ -1265,7 +1271,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
""" | |||
# XXX: This should be simple_delete_one but we failed to put a unique index on | |||
# the table, so somehow duplicate entries have ended up in it. | |||
return self.db.simple_delete( | |||
return self.db_pool.simple_delete( | |||
"users_pending_deactivation", | |||
keyvalues={"user_id": user_id}, | |||
desc="del_user_pending_deactivation", | |||
@@ -1276,7 +1282,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
Gets one user from the table of users waiting to be parted from all the rooms | |||
they're in. | |||
""" | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
"users_pending_deactivation", | |||
keyvalues={}, | |||
retcol="user_id", | |||
@@ -1306,7 +1312,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
# Insert everything into a transaction in order to run atomically | |||
def validate_threepid_session_txn(txn): | |||
row = self.db.simple_select_one_txn( | |||
row = self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="threepid_validation_session", | |||
keyvalues={"session_id": session_id}, | |||
@@ -1324,7 +1330,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
400, "This client_secret does not match the provided session_id" | |||
) | |||
row = self.db.simple_select_one_txn( | |||
row = self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="threepid_validation_token", | |||
keyvalues={"session_id": session_id, "token": token}, | |||
@@ -1349,7 +1355,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
) | |||
# Looks good. Validate the session | |||
self.db.simple_update_txn( | |||
self.db_pool.simple_update_txn( | |||
txn, | |||
table="threepid_validation_session", | |||
keyvalues={"session_id": session_id}, | |||
@@ -1359,7 +1365,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
return next_link | |||
# Return next_link if it exists | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"validate_threepid_session_txn", validate_threepid_session_txn | |||
) | |||
@@ -1392,7 +1398,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
if validated_at: | |||
insertion_values["validated_at"] = validated_at | |||
return self.db.simple_upsert( | |||
return self.db_pool.simple_upsert( | |||
table="threepid_validation_session", | |||
keyvalues={"session_id": session_id}, | |||
values={"last_send_attempt": send_attempt}, | |||
@@ -1430,7 +1436,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
def start_or_continue_validation_session_txn(txn): | |||
# Create or update a validation session | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="threepid_validation_session", | |||
keyvalues={"session_id": session_id}, | |||
@@ -1443,7 +1449,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
) | |||
# Create a new validation token with this session ID | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="threepid_validation_token", | |||
values={ | |||
@@ -1454,7 +1460,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
}, | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"start_or_continue_validation_session", | |||
start_or_continue_validation_session_txn, | |||
) | |||
@@ -1469,7 +1475,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
""" | |||
return txn.execute(sql, (ts,)) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"cull_expired_threepid_validation_tokens", | |||
cull_expired_threepid_validation_tokens_txn, | |||
self.clock.time_msec(), | |||
@@ -1484,7 +1490,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
deactivated (bool): The value to set for `deactivated`. | |||
""" | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"set_user_deactivated_status", | |||
self.set_user_deactivated_status_txn, | |||
user_id, | |||
@@ -1492,7 +1498,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
) | |||
def set_user_deactivated_status_txn(self, txn, user_id, deactivated): | |||
self.db.simple_update_one_txn( | |||
self.db_pool.simple_update_one_txn( | |||
txn=txn, | |||
table="users", | |||
keyvalues={"name": user_id}, | |||
@@ -1520,14 +1526,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
) | |||
txn.execute(sql, []) | |||
res = self.db.cursor_to_dict(txn) | |||
res = self.db_pool.cursor_to_dict(txn) | |||
if res: | |||
for user in res: | |||
self.set_expiration_date_for_user_txn( | |||
txn, user["name"], use_delta=True | |||
) | |||
yield self.db.runInteraction( | |||
yield self.db_pool.runInteraction( | |||
"get_users_with_no_expiration_date", | |||
select_users_with_no_expiration_date_txn, | |||
) | |||
@@ -1551,7 +1557,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
expiration_ts, | |||
) | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
"account_validity", | |||
keyvalues={"user_id": user_id}, |
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) | |||
class RejectionsStore(SQLBaseStore): | |||
def get_rejection_reason(self, event_id): | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="rejections", | |||
retcol="reason", | |||
keyvalues={"event_id": event_id}, |
@@ -19,7 +19,7 @@ import attr | |||
from synapse.api.constants import RelationTypes | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.data_stores.main.stream import generate_pagination_where_clause | |||
from synapse.storage.databases.main.stream import generate_pagination_where_clause | |||
from synapse.storage.relations import ( | |||
AggregationPaginationToken, | |||
PaginationChunk, | |||
@@ -129,7 +129,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_recent_references_for_event", _get_recent_references_for_event_txn | |||
) | |||
@@ -223,7 +223,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn | |||
) | |||
@@ -268,7 +268,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||
if row: | |||
return row[0] | |||
edit_id = yield self.db.runInteraction( | |||
edit_id = yield self.db_pool.runInteraction( | |||
"get_applicable_edit", _get_applicable_edit_txn | |||
) | |||
@@ -318,7 +318,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||
return bool(txn.fetchone()) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event | |||
) | |||
@@ -27,8 +27,8 @@ from synapse.api.constants import EventTypes | |||
from synapse.api.errors import StoreError | |||
from synapse.api.room_versions import RoomVersion, RoomVersions | |||
from synapse.storage._base import SQLBaseStore, db_to_json | |||
from synapse.storage.data_stores.main.search import SearchStore | |||
from synapse.storage.database import Database, LoggingTransaction | |||
from synapse.storage.database import DatabasePool, LoggingTransaction | |||
from synapse.storage.databases.main.search import SearchStore | |||
from synapse.types import ThirdPartyInstanceID | |||
from synapse.util.caches.descriptors import cached | |||
@@ -73,7 +73,7 @@ class RoomSortOrder(Enum): | |||
class RoomWorkerStore(SQLBaseStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(RoomWorkerStore, self).__init__(database, db_conn, hs) | |||
self.config = hs.config | |||
@@ -86,7 +86,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
Returns: | |||
A dict containing the room information, or None if the room is unknown. | |||
""" | |||
return self.db.simple_select_one( | |||
return self.db_pool.simple_select_one( | |||
table="rooms", | |||
keyvalues={"room_id": room_id}, | |||
retcols=("room_id", "is_public", "creator"), | |||
@@ -118,7 +118,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
txn.execute(sql, [room_id]) | |||
# Catch error if sql returns empty result to return "None" instead of an error | |||
try: | |||
res = self.db.cursor_to_dict(txn)[0] | |||
res = self.db_pool.cursor_to_dict(txn)[0] | |||
except IndexError: | |||
return None | |||
@@ -126,12 +126,12 @@ class RoomWorkerStore(SQLBaseStore): | |||
res["public"] = bool(res["public"]) | |||
return res | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_room_with_stats", get_room_with_stats_txn, room_id | |||
) | |||
def get_public_room_ids(self): | |||
return self.db.simple_select_onecol( | |||
return self.db_pool.simple_select_onecol( | |||
table="rooms", | |||
keyvalues={"is_public": True}, | |||
retcol="room_id", | |||
@@ -188,7 +188,9 @@ class RoomWorkerStore(SQLBaseStore): | |||
txn.execute(sql, query_args) | |||
return txn.fetchone()[0] | |||
return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) | |||
return self.db_pool.runInteraction( | |||
"count_public_rooms", _count_public_rooms_txn | |||
) | |||
async def get_largest_public_rooms( | |||
self, | |||
@@ -320,21 +322,21 @@ class RoomWorkerStore(SQLBaseStore): | |||
def _get_largest_public_rooms_txn(txn): | |||
txn.execute(sql, query_args) | |||
results = self.db.cursor_to_dict(txn) | |||
results = self.db_pool.cursor_to_dict(txn) | |||
if not forwards: | |||
results.reverse() | |||
return results | |||
ret_val = await self.db.runInteraction( | |||
ret_val = await self.db_pool.runInteraction( | |||
"get_largest_public_rooms", _get_largest_public_rooms_txn | |||
) | |||
return ret_val | |||
@cached(max_entries=10000) | |||
def is_room_blocked(self, room_id): | |||
return self.db.simple_select_one_onecol( | |||
return self.db_pool.simple_select_one_onecol( | |||
table="blocked_rooms", | |||
keyvalues={"room_id": room_id}, | |||
retcol="1", | |||
@@ -502,7 +504,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
room_count = txn.fetchone() | |||
return rooms, room_count[0] | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_rooms_paginate", _get_rooms_paginate_txn, | |||
) | |||
@@ -519,7 +521,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
of RatelimitOverride are None or 0 then ratelimitng has been | |||
disabled for that user entirely. | |||
""" | |||
row = await self.db.simple_select_one( | |||
row = await self.db_pool.simple_select_one( | |||
table="ratelimit_override", | |||
keyvalues={"user_id": user_id}, | |||
retcols=("messages_per_second", "burst_count"), | |||
@@ -561,9 +563,9 @@ class RoomWorkerStore(SQLBaseStore): | |||
(room_id,), | |||
) | |||
return self.db.cursor_to_dict(txn) | |||
return self.db_pool.cursor_to_dict(txn) | |||
ret = await self.db.runInteraction( | |||
ret = await self.db_pool.runInteraction( | |||
"get_retention_policy_for_room", get_retention_policy_for_room_txn, | |||
) | |||
@@ -613,7 +615,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
return local_media_mxcs, remote_media_mxcs | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_media_ids_in_room", _get_media_mxcs_in_room_txn | |||
) | |||
@@ -630,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
txn, local_mxcs, remote_mxcs, quarantined_by | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"quarantine_media_in_room", _quarantine_media_in_room_txn | |||
) | |||
@@ -714,7 +716,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
txn, local_mxcs, remote_mxcs, quarantined_by | |||
) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"quarantine_media_by_user", _quarantine_media_by_id_txn | |||
) | |||
@@ -730,7 +732,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) | |||
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"quarantine_media_by_user", _quarantine_media_by_user_txn | |||
) | |||
@@ -848,7 +850,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
return updates, upto_token, limited | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_all_new_public_rooms", get_all_new_public_rooms | |||
) | |||
@@ -857,21 +859,21 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" | |||
ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) | |||
self.config = hs.config | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
"insert_room_retention", self._background_insert_retention, | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, | |||
self._remove_tombstoned_rooms_from_directory, | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
self.ADD_ROOMS_ROOM_VERSION_COLUMN, | |||
self._background_add_rooms_room_version_column, | |||
) | |||
@@ -900,7 +902,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
(last_room, batch_size), | |||
) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
if not rows: | |||
return True | |||
@@ -912,7 +914,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
ev = db_to_json(row["json"]) | |||
retention_policy = ev["content"] | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn=txn, | |||
table="room_retention", | |||
values={ | |||
@@ -925,7 +927,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
logger.info("Inserted %d rows into room_retention", len(rows)) | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} | |||
) | |||
@@ -934,12 +936,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
else: | |||
return False | |||
end = await self.db.runInteraction( | |||
end = await self.db_pool.runInteraction( | |||
"insert_room_retention", _background_insert_retention_txn, | |||
) | |||
if end: | |||
await self.db.updates._end_background_update("insert_room_retention") | |||
await self.db_pool.updates._end_background_update("insert_room_retention") | |||
return batch_size | |||
@@ -983,7 +985,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
# mainly for paranoia as much badness would happen if we don't | |||
# insert the row and then try and get the room version for the | |||
# room. | |||
self.db.simple_upsert_txn( | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="rooms", | |||
keyvalues={"room_id": room_id}, | |||
@@ -992,19 +994,19 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
) | |||
new_last_room_id = room_id | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id} | |||
) | |||
return False | |||
end = await self.db.runInteraction( | |||
end = await self.db_pool.runInteraction( | |||
"_background_add_rooms_room_version_column", | |||
_background_add_rooms_room_version_column_txn, | |||
) | |||
if end: | |||
await self.db.updates._end_background_update( | |||
await self.db_pool.updates._end_background_update( | |||
self.ADD_ROOMS_ROOM_VERSION_COLUMN | |||
) | |||
@@ -1038,12 +1040,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
return [row[0] for row in txn] | |||
rooms = await self.db.runInteraction( | |||
rooms = await self.db_pool.runInteraction( | |||
"get_tombstoned_directory_rooms", _get_rooms | |||
) | |||
if not rooms: | |||
await self.db.updates._end_background_update( | |||
await self.db_pool.updates._end_background_update( | |||
self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE | |||
) | |||
return 0 | |||
@@ -1052,7 +1054,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
logger.info("Removing tombstoned room %s from the directory", room_id) | |||
await self.set_room_is_public(room_id, False) | |||
await self.db.updates._background_update_progress( | |||
await self.db_pool.updates._background_update_progress( | |||
self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]} | |||
) | |||
@@ -1068,7 +1070,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(RoomStore, self).__init__(database, db_conn, hs) | |||
self.config = hs.config | |||
@@ -1079,7 +1081,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
Called when we join a room over federation, and overwrites any room version | |||
currently in the table. | |||
""" | |||
await self.db.simple_upsert( | |||
await self.db_pool.simple_upsert( | |||
desc="upsert_room_on_join", | |||
table="rooms", | |||
keyvalues={"room_id": room_id}, | |||
@@ -1111,7 +1113,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
try: | |||
def store_room_txn(txn, next_id): | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
"rooms", | |||
{ | |||
@@ -1122,7 +1124,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
}, | |||
) | |||
if is_public: | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="public_room_list_stream", | |||
values={ | |||
@@ -1133,7 +1135,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
) | |||
with self._public_room_id_gen.get_next() as next_id: | |||
await self.db.runInteraction("store_room_txn", store_room_txn, next_id) | |||
await self.db_pool.runInteraction( | |||
"store_room_txn", store_room_txn, next_id | |||
) | |||
except Exception as e: | |||
logger.error("store_room with room_id=%s failed: %s", room_id, e) | |||
raise StoreError(500, "Problem creating room.") | |||
@@ -1143,7 +1147,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
When we receive an invite over federation, store the version of the room if we | |||
don't already know the room version. | |||
""" | |||
await self.db.simple_upsert( | |||
await self.db_pool.simple_upsert( | |||
desc="maybe_store_room_on_invite", | |||
table="rooms", | |||
keyvalues={"room_id": room_id}, | |||
@@ -1160,14 +1164,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
async def set_room_is_public(self, room_id, is_public): | |||
def set_room_is_public_txn(txn, next_id): | |||
self.db.simple_update_one_txn( | |||
self.db_pool.simple_update_one_txn( | |||
txn, | |||
table="rooms", | |||
keyvalues={"room_id": room_id}, | |||
updatevalues={"is_public": is_public}, | |||
) | |||
entries = self.db.simple_select_list_txn( | |||
entries = self.db_pool.simple_select_list_txn( | |||
txn, | |||
table="public_room_list_stream", | |||
keyvalues={ | |||
@@ -1185,7 +1189,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
add_to_stream = bool(entries[-1]["visibility"]) != is_public | |||
if add_to_stream: | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="public_room_list_stream", | |||
values={ | |||
@@ -1198,7 +1202,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
) | |||
with self._public_room_id_gen.get_next() as next_id: | |||
await self.db.runInteraction( | |||
await self.db_pool.runInteraction( | |||
"set_room_is_public", set_room_is_public_txn, next_id | |||
) | |||
self.hs.get_notifier().on_new_replication_data() | |||
@@ -1224,7 +1228,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
def set_room_is_public_appservice_txn(txn, next_id): | |||
if is_public: | |||
try: | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="appservice_room_list", | |||
values={ | |||
@@ -1237,7 +1241,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
# We've already inserted, nothing to do. | |||
return | |||
else: | |||
self.db.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="appservice_room_list", | |||
keyvalues={ | |||
@@ -1247,7 +1251,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
}, | |||
) | |||
entries = self.db.simple_select_list_txn( | |||
entries = self.db_pool.simple_select_list_txn( | |||
txn, | |||
table="public_room_list_stream", | |||
keyvalues={ | |||
@@ -1265,7 +1269,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
add_to_stream = bool(entries[-1]["visibility"]) != is_public | |||
if add_to_stream: | |||
self.db.simple_insert_txn( | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
table="public_room_list_stream", | |||
values={ | |||
@@ -1278,7 +1282,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
) | |||
with self._public_room_id_gen.get_next() as next_id: | |||
await self.db.runInteraction( | |||
await self.db_pool.runInteraction( | |||
"set_room_is_public_appservice", | |||
set_room_is_public_appservice_txn, | |||
next_id, | |||
@@ -1295,13 +1299,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
row = txn.fetchone() | |||
return row[0] or 0 | |||
return self.db.runInteraction("get_rooms", f) | |||
return self.db_pool.runInteraction("get_rooms", f) | |||
def add_event_report( | |||
self, room_id, event_id, user_id, reason, content, received_ts | |||
): | |||
next_id = self._event_reports_id_gen.get_next() | |||
return self.db.simple_insert( | |||
return self.db_pool.simple_insert( | |||
table="event_reports", | |||
values={ | |||
"id": next_id, | |||
@@ -1325,14 +1329,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
room_id: Room to block | |||
user_id: Who blocked it | |||
""" | |||
await self.db.simple_upsert( | |||
await self.db_pool.simple_upsert( | |||
table="blocked_rooms", | |||
keyvalues={"room_id": room_id}, | |||
values={}, | |||
insertion_values={"user_id": user_id}, | |||
desc="block_room", | |||
) | |||
await self.db.runInteraction( | |||
await self.db_pool.runInteraction( | |||
"block_room_invalidation", | |||
self._invalidate_cache_and_stream, | |||
self.is_room_blocked, | |||
@@ -1388,7 +1392,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
txn.execute(sql, args) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
rooms_dict = {} | |||
for row in rows: | |||
@@ -1404,7 +1408,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
txn.execute(sql) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
# If a room isn't already in the dict (i.e. it doesn't have a retention | |||
# policy in its state), add it with a null policy. | |||
@@ -1417,7 +1421,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
return rooms_dict | |||
rooms = await self.db.runInteraction( | |||
rooms = await self.db_pool.runInteraction( | |||
"get_rooms_for_retention_period_in_range", | |||
get_rooms_for_retention_period_in_range_txn, | |||
) |
@@ -28,8 +28,8 @@ from synapse.storage._base import ( | |||
db_to_json, | |||
make_in_list_sql_clause, | |||
) | |||
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore | |||
from synapse.storage.database import Database | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.events_worker import EventsWorkerStore | |||
from synapse.storage.engines import Sqlite3Engine | |||
from synapse.storage.roommember import ( | |||
GetRoomsForUserWithStreamOrdering, | |||
@@ -51,7 +51,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" | |||
class RoomMemberWorkerStore(EventsWorkerStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) | |||
# Is the current_state_events.membership up to date? Or is the | |||
@@ -116,7 +116,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
txn.execute(query) | |||
return list(txn)[0][0] | |||
count = yield self.db.runInteraction("get_known_servers", _transact) | |||
count = yield self.db_pool.runInteraction("get_known_servers", _transact) | |||
# We always know about ourselves, even if we have nothing in | |||
# room_memberships (for example, the server is new). | |||
@@ -128,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
membership column is up to date | |||
""" | |||
pending_update = self.db.simple_select_one_txn( | |||
pending_update = self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="background_updates", | |||
keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, | |||
@@ -144,14 +144,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
15.0, | |||
run_as_background_process, | |||
"_check_safe_current_state_events_membership_updated", | |||
self.db.runInteraction, | |||
self.db_pool.runInteraction, | |||
"_check_safe_current_state_events_membership_updated", | |||
self._check_safe_current_state_events_membership_updated_txn, | |||
) | |||
@cached(max_entries=100000, iterable=True) | |||
def get_users_in_room(self, room_id): | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_users_in_room", self.get_users_in_room_txn, room_id | |||
) | |||
@@ -259,7 +259,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
return res | |||
return self.db.runInteraction("get_room_summary", _get_room_summary_txn) | |||
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn) | |||
def _get_user_counts_in_room_txn(self, txn, room_id): | |||
""" | |||
@@ -332,7 +332,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
if not membership_list: | |||
return defer.succeed(None) | |||
rooms = yield self.db.runInteraction( | |||
rooms = yield self.db_pool.runInteraction( | |||
"get_rooms_for_local_user_where_membership_is", | |||
self._get_rooms_for_local_user_where_membership_is_txn, | |||
user_id, | |||
@@ -369,7 +369,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
) | |||
txn.execute(sql, (user_id, *args)) | |||
results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] | |||
results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)] | |||
return results | |||
@@ -388,7 +388,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
the rooms the user is in currently, along with the stream ordering | |||
of the most recent join for that user and room. | |||
""" | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_rooms_for_user_with_stream_ordering", | |||
self._get_rooms_for_user_with_stream_ordering_txn, | |||
user_id, | |||
@@ -453,7 +453,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
return {row[0] for row in txn} | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_users_server_still_shares_room_with", | |||
_get_users_server_still_shares_room_with_txn, | |||
) | |||
@@ -624,7 +624,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
to `user_id` and ProfileInfo (or None if not join event). | |||
""" | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = yield self.db_pool.simple_select_many_batch( | |||
table="room_memberships", | |||
column="event_id", | |||
iterable=event_ids, | |||
@@ -664,7 +664,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
# the returned user actually has the correct domain. | |||
like_clause = "%:" + host | |||
rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause) | |||
rows = yield self.db_pool.execute( | |||
"is_host_joined", None, sql, room_id, like_clause | |||
) | |||
if not rows: | |||
return False | |||
@@ -704,7 +706,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
# the returned user actually has the correct domain. | |||
like_clause = "%:" + host | |||
rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause) | |||
rows = yield self.db_pool.execute( | |||
"was_host_joined", None, sql, room_id, like_clause | |||
) | |||
if not rows: | |||
return False | |||
@@ -774,7 +778,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
rows = txn.fetchall() | |||
return rows[0][0] | |||
count = yield self.db.runInteraction("did_forget_membership", f) | |||
count = yield self.db_pool.runInteraction("did_forget_membership", f) | |||
return count == 0 | |||
@cached() | |||
@@ -811,7 +815,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
txn.execute(sql, (user_id,)) | |||
return {row[0] for row in txn if row[1] == 0} | |||
return self.db.runInteraction( | |||
return self.db_pool.runInteraction( | |||
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn | |||
) | |||
@@ -826,7 +830,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
Deferred[set[str]]: Set of room IDs. | |||
""" | |||
room_ids = yield self.db.simple_select_onecol( | |||
room_ids = yield self.db_pool.simple_select_onecol( | |||
table="room_memberships", | |||
keyvalues={"membership": Membership.JOIN, "user_id": user_id}, | |||
retcol="room_id", | |||
@@ -841,7 +845,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
"""Get user_id and membership of a set of event IDs. | |||
""" | |||
return self.db.simple_select_many_batch( | |||
return self.db_pool.simple_select_many_batch( | |||
table="room_memberships", | |||
column="event_id", | |||
iterable=member_event_ids, | |||
@@ -877,23 +881,23 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
return bool(txn.fetchone()) | |||
return await self.db.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"is_local_host_in_room_ignoring_users", | |||
_is_local_host_in_room_ignoring_users_txn, | |||
) | |||
class RoomMemberBackgroundUpdateStore(SQLBaseStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile | |||
) | |||
self.db.updates.register_background_update_handler( | |||
self.db_pool.updates.register_background_update_handler( | |||
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, | |||
self._background_current_state_membership, | |||
) | |||
self.db.updates.register_background_index_update( | |||
self.db_pool.updates.register_background_index_update( | |||
"room_membership_forgotten_idx", | |||
index_name="room_memberships_user_room_forgotten", | |||
table="room_memberships", | |||
@@ -926,7 +930,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): | |||
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) | |||
rows = self.db.cursor_to_dict(txn) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
if not rows: | |||
return 0 | |||
@@ -961,18 +965,18 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): | |||
"max_stream_id_exclusive": min_stream_id, | |||
} | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress | |||
) | |||
return len(rows) | |||
result = yield self.db.runInteraction( | |||
result = yield self.db_pool.runInteraction( | |||
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn | |||
) | |||
if not result: | |||
yield self.db.updates._end_background_update( | |||
yield self.db_pool.updates._end_background_update( | |||
_MEMBERSHIP_PROFILE_UPDATE_NAME | |||
) | |||
@@ -1013,7 +1017,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): | |||
last_processed_room = next_room | |||
self.db.updates._background_update_progress_txn( | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, | |||
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, | |||
{"last_processed_room": last_processed_room}, | |||
@@ -1025,14 +1029,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): | |||
# string, which will compare before all room IDs correctly. | |||
last_processed_room = progress.get("last_processed_room", "") | |||
row_count, finished = yield self.db.runInteraction( | |||
row_count, finished = yield self.db_pool.runInteraction( | |||
"_background_current_state_membership_update", | |||
_background_current_state_membership_txn, | |||
last_processed_room, | |||
) | |||
if finished: | |||
yield self.db.updates._end_background_update( | |||
yield self.db_pool.updates._end_background_update( | |||
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME | |||
) | |||
@@ -1040,7 +1044,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): | |||
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): | |||
def __init__(self, database: Database, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super(RoomMemberStore, self).__init__(database, db_conn, hs) | |||
def forget(self, user_id, room_id): | |||
@@ -1064,7 +1068,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): | |||
txn, self.get_forgotten_rooms_for_user, (user_id,) | |||
) | |||
return self.db.runInteraction("forget_membership", f) | |||
return self.db_pool.runInteraction("forget_membership", f) | |||
class _JoinedHostsCache(object): |