瀏覽代碼

Remove many calls to cursor_to_dict (#16431)

This avoids calling cursor_to_dict and then immediately
unpacking the values in the dict for other users. By not
creating the intermediate dictionary we can avoid allocating
the dictionary and strings for the keys, which should generally
be more performant.

Additionally, this improves type hints by avoiding Dict[str, Any]
dictionaries coming out of the database layer.
tags/v1.95.0rc1
Patrick Cloke 7 月之前
committed by GitHub
父節點
當前提交
fa907025f4
沒有發現已知的金鑰在資料庫的簽署中 GPG Key ID: 4AEE18F83AFDEB23
共有 16 個文件被更改,包括 320 次插入228 次删除
  1. +1
    -1
      changelog.d/16429.misc
  2. +1
    -0
      changelog.d/16431.misc
  3. +1
    -1
      synapse/push/__init__.py
  4. +5
    -6
      synapse/storage/databases/main/account_data.py
  5. +7
    -22
      synapse/storage/databases/main/appservice.py
  6. +6
    -6
      synapse/storage/databases/main/devices.py
  7. +6
    -16
      synapse/storage/databases/main/end_to_end_keys.py
  8. +3
    -6
      synapse/storage/databases/main/events.py
  9. +13
    -5
      synapse/storage/databases/main/presence.py
  10. +93
    -28
      synapse/storage/databases/main/pusher.py
  11. +40
    -32
      synapse/storage/databases/main/receipts.py
  12. +77
    -56
      synapse/storage/databases/main/registration.py
  13. +19
    -23
      synapse/storage/databases/main/room.py
  14. +4
    -6
      synapse/storage/databases/main/roommember.py
  15. +11
    -9
      synapse/storage/databases/main/search.py
  16. +33
    -11
      synapse/storage/databases/main/task_scheduler.py

+ 1
- 1
changelog.d/16429.misc 查看文件

@@ -1 +1 @@
Reduce the size of each replication command instance.
Reduce memory allocations.

+ 1
- 0
changelog.d/16431.misc 查看文件

@@ -0,0 +1 @@
Reduce memory allocations.

+ 1
- 1
synapse/push/__init__.py 查看文件

@@ -101,7 +101,7 @@ if TYPE_CHECKING:
class PusherConfig:
"""Parameters necessary to configure a pusher."""

id: Optional[str]
id: Optional[int]
user_name: str

profile_tag: str


+ 5
- 6
synapse/storage/databases/main/account_data.py 查看文件

@@ -151,10 +151,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'"

txn.execute(sql, (user_id,))
rows = self.db_pool.cursor_to_dict(txn)

return {
row["account_data_type"]: db_to_json(row["content"]) for row in rows
account_data_type: db_to_json(content)
for account_data_type, content in txn
}

return await self.db_pool.runInteraction(
@@ -196,13 +196,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'"

txn.execute(sql, (user_id,))
rows = self.db_pool.cursor_to_dict(txn)

by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
for room_id, account_data_type, content in txn:
room_data = by_room.setdefault(room_id, {})

room_data[row["account_data_type"]] = db_to_json(row["content"])
room_data[account_data_type] = db_to_json(content)

return by_room



+ 7
- 22
synapse/storage/databases/main/appservice.py 查看文件

@@ -14,17 +14,7 @@
# limitations under the License.
import logging
import re
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Pattern,
Sequence,
Tuple,
cast,
)
from typing import TYPE_CHECKING, List, Optional, Pattern, Sequence, Tuple, cast

from synapse.appservice import (
ApplicationService,
@@ -353,21 +343,15 @@ class ApplicationServiceTransactionWorkerStore(

def _get_oldest_unsent_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[int, str]]:
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
"SELECT * FROM application_services_txns WHERE as_id=?"
"SELECT txn_id, event_ids FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return None

entry = rows[0]

return entry
return cast(Optional[Tuple[int, str]], txn.fetchone())

entry = await self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
@@ -376,8 +360,9 @@ class ApplicationServiceTransactionWorkerStore(
if not entry:
return None

event_ids = db_to_json(entry["event_ids"])
txn_id, event_ids_str = entry

event_ids = db_to_json(event_ids_str)
events = await self.get_events_as_list(event_ids)

# TODO: to-device messages, one-time key counts, device list summaries and unused
@@ -385,7 +370,7 @@ class ApplicationServiceTransactionWorkerStore(
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
id=entry["txn_id"],
id=txn_id,
events=events,
ephemeral=[],
to_device_messages=[],


+ 6
- 6
synapse/storage/databases/main/devices.py 查看文件

@@ -1413,13 +1413,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):

def get_devices_not_accessed_since_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
) -> List[Tuple[str, str]]:
sql = """
SELECT user_id, device_id
FROM devices WHERE last_seen < ? AND hidden = FALSE
"""
txn.execute(sql, (since_ms,))
return self.db_pool.cursor_to_dict(txn)
return cast(List[Tuple[str, str]], txn.fetchall())

rows = await self.db_pool.runInteraction(
"get_devices_not_accessed_since",
@@ -1427,11 +1427,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)

devices: Dict[str, List[str]] = {}
for row in rows:
for user_id, device_id in rows:
# Remote devices are never stale from our point of view.
if self.hs.is_mine_id(row["user_id"]):
user_devices = devices.setdefault(row["user_id"], [])
user_devices.append(row["device_id"])
if self.hs.is_mine_id(user_id):
user_devices = devices.setdefault(user_id, [])
user_devices.append(device_id)

return devices



+ 6
- 16
synapse/storage/databases/main/end_to_end_keys.py 查看文件

@@ -921,14 +921,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
}

txn.execute(sql, params)
rows = self.db_pool.cursor_to_dict(txn)

for row in rows:
user_id = row["user_id"]
key_type = row["keytype"]
key = db_to_json(row["keydata"])
for user_id, key_type, key_data, _ in txn:
user_keys = result.setdefault(user_id, {})
user_keys[key_type] = key
user_keys[key_type] = db_to_json(key_data)

return result

@@ -988,13 +984,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_params.extend(item)

txn.execute(sql, query_params)
rows = self.db_pool.cursor_to_dict(txn)

# and add the signatures to the appropriate keys
for row in rows:
key_id: str = row["key_id"]
target_user_id: str = row["target_user_id"]
target_device_id: str = row["target_device_id"]
for target_user_id, target_device_id, key_id, signature in txn:
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
@@ -1012,13 +1004,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
].copy()
if from_user_id in signatures:
user_sigs = signatures[from_user_id] = signatures[from_user_id]
user_sigs[key_id] = row["signature"]
user_sigs[key_id] = signature
else:
signatures[from_user_id] = {key_id: row["signature"]}
signatures[from_user_id] = {key_id: signature}
else:
target_user_key["signatures"] = {
from_user_id: {key_id: row["signature"]}
}
target_user_key["signatures"] = {from_user_id: {key_id: signature}}

return keys



+ 3
- 6
synapse/storage/databases/main/events.py 查看文件

@@ -1654,8 +1654,6 @@ class PersistEventsStore:
) -> None:
to_prefill = []

rows = []

ev_map = {e.event_id: e for e, _ in events_and_contexts}
if not ev_map:
return
@@ -1676,10 +1674,9 @@ class PersistEventsStore:
)

txn.execute(sql + clause, args)
rows = self.db_pool.cursor_to_dict(txn)
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
for event_id, redacts, rejects in txn:
event = ev_map[event_id]
if not rejects and not redacts:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))

async def external_prefill() -> None:


+ 13
- 5
synapse/storage/databases/main/presence.py 查看文件

@@ -434,13 +434,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)

txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
rows = self.db_pool.cursor_to_dict(txn)
rows = txn.fetchall()
txn.close()

for row in rows:
row["currently_active"] = bool(row["currently_active"])

return [UserPresenceState(**row) for row in rows]
return [
UserPresenceState(
user_id=user_id,
state=state,
last_active_ts=last_active_ts,
last_federation_update_ts=last_federation_update_ts,
last_user_sync_ts=last_user_sync_ts,
status_msg=status_msg,
currently_active=bool(currently_active),
)
for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
]

def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup


+ 93
- 28
synapse/storage/databases/main/pusher.py 查看文件

@@ -47,6 +47,27 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)

# The type of a row in the pushers table.
PusherRow = Tuple[
int, # id
str, # user_name
Optional[int], # access_token
str, # profile_tag
str, # kind
str, # app_id
str, # app_display_name
str, # device_display_name
str, # pushkey
int, # ts
str, # lang
str, # data
int, # last_stream_ordering
int, # last_success
int, # failing_since
bool, # enabled
str, # device_id
]


class PusherWorkerStore(SQLBaseStore):
def __init__(
@@ -83,30 +104,66 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_deleted_email_pushers,
)

def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
def _decode_pushers_rows(
self,
rows: Iterable[PusherRow],
) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table

Drops any rows whose data cannot be decoded
"""
for r in rows:
data_json = r["data"]
for (
id,
user_name,
access_token,
profile_tag,
kind,
app_id,
app_display_name,
device_display_name,
pushkey,
ts,
lang,
data,
last_stream_ordering,
last_success,
failing_since,
enabled,
device_id,
) in rows:
try:
r["data"] = db_to_json(data_json)
data_json = db_to_json(data)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
r["id"],
data_json,
id,
data,
e.args[0],
)
continue

# If we're using SQLite, then boolean values are integers. This is
# troublesome since some code using the return value of this method might
# expect it to be a boolean, or will expose it to clients (in responses).
r["enabled"] = bool(r["enabled"])

yield PusherConfig(**r)
yield PusherConfig(
id=id,
user_name=user_name,
profile_tag=profile_tag,
kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
ts=ts,
lang=lang,
data=data_json,
last_stream_ordering=last_stream_ordering,
last_success=last_success,
failing_since=failing_since,
# If we're using SQLite, then boolean values are integers. This is
# troublesome since some code using the return value of this method might
# expect it to be a boolean, or will expose it to clients (in responses).
enabled=bool(enabled),
device_id=device_id,
access_token=access_token,
)

def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
@@ -136,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
The pushers for which the given columns have the given values.
"""

def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
def get_pushers_by_txn(txn: LoggingTransaction) -> List[PusherRow]:
# We could technically use simple_select_list here, but we need to call
# COALESCE on the 'enabled' column. While it is technically possible to give
# simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
@@ -154,7 +211,7 @@ class PusherWorkerStore(SQLBaseStore):

txn.execute(sql, list(keyvalues.values()))

return self.db_pool.cursor_to_dict(txn)
return cast(List[PusherRow], txn.fetchall())

ret = await self.db_pool.runInteraction(
desc="get_pushers_by",
@@ -164,14 +221,22 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(ret)

async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
rows = self.db_pool.cursor_to_dict(txn)

return self._decode_pushers_rows(rows)
def get_enabled_pushers_txn(txn: LoggingTransaction) -> List[PusherRow]:
txn.execute(
"""
SELECT id, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, ts, lang, data,
last_stream_ordering, last_success, failing_since,
enabled, device_id
FROM pushers WHERE COALESCE(enabled, TRUE)
"""
)
return cast(List[PusherRow], txn.fetchall())

return await self.db_pool.runInteraction(
"get_enabled_pushers", get_enabled_pushers_txn
return self._decode_pushers_rows(
await self.db_pool.runInteraction(
"get_enabled_pushers", get_enabled_pushers_txn
)
)

async def get_all_updated_pushers_rows(
@@ -304,7 +369,7 @@ class PusherWorkerStore(SQLBaseStore):
)

async def get_throttle_params_by_room(
self, pusher_id: str
self, pusher_id: int
) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list(
"pusher_throttle",
@@ -323,7 +388,7 @@ class PusherWorkerStore(SQLBaseStore):
return params_by_room

async def set_throttle_params(
self, pusher_id: str, room_id: str, params: ThrottleParams
self, pusher_id: int, room_id: str, params: ThrottleParams
) -> None:
await self.db_pool.simple_upsert(
"pusher_throttle",
@@ -534,7 +599,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
(last_pusher_id, batch_size),
)

rows = self.db_pool.cursor_to_dict(txn)
rows = txn.fetchall()
if len(rows) == 0:
return 0

@@ -550,19 +615,19 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="pushers",
key_names=("id",),
key_values=[(row["pusher_id"],) for row in rows],
key_values=[row[0] for row in rows],
value_names=("device_id", "access_token"),
# If there was already a device_id on the pusher, we only want to clear
# the access_token column, so we keep the existing device_id. Otherwise,
# we set the device_id we got from joining the access_tokens table.
value_values=[
(row["pusher_device_id"] or row["token_device_id"], None)
for row in rows
(pusher_device_id or token_device_id, None)
for _, pusher_device_id, token_device_id in rows
],
)

self.db_pool.updates._background_update_progress_txn(
txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["pusher_id"]}
txn, "set_device_id_for_pushers", {"pusher_id": rows[-1][0]}
)

return len(rows)


+ 40
- 32
synapse/storage/databases/main/receipts.py 查看文件

@@ -313,25 +313,25 @@ class ReceiptsWorkerStore(SQLBaseStore):
) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room"""

def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
"SELECT receipt_type, user_id, event_id, data"
" FROM receipts_linearized WHERE"
" room_id = ? AND stream_id > ? AND stream_id <= ?"
)

txn.execute(sql, (room_id, from_key, to_key))
else:
sql = (
"SELECT * FROM receipts_linearized WHERE"
"SELECT receipt_type, user_id, event_id, data"
" FROM receipts_linearized WHERE"
" room_id = ? AND stream_id <= ?"
)

txn.execute(sql, (room_id, to_key))

rows = self.db_pool.cursor_to_dict(txn)

return rows
return cast(List[Tuple[str, str, str, str]], txn.fetchall())

rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)

@@ -339,10 +339,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return []

content: JsonDict = {}
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
] = db_to_json(row["data"])
for receipt_type, user_id, event_id, data in rows:
content.setdefault(event_id, {}).setdefault(receipt_type, {})[
user_id
] = db_to_json(data)

return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]

@@ -357,10 +357,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
if not room_ids:
return {}

def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
def f(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, str, Optional[str], str]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
SELECT room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND
"""
clause, args = make_in_list_sql_clause(
@@ -370,7 +373,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [from_key, to_key] + list(args))
else:
sql = """
SELECT * FROM receipts_linearized WHERE
SELECT room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized WHERE
stream_id <= ? AND
"""

@@ -380,29 +384,31 @@ class ReceiptsWorkerStore(SQLBaseStore):

txn.execute(sql + clause, [to_key] + list(args))

return self.db_pool.cursor_to_dict(txn)
return cast(
List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
)

txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)

results: JsonDict = {}
for row in txn_results:
for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
row["room_id"],
{"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
room_id,
{"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)

# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
event_entry = room_event["content"].setdefault(event_id, {})
receipt_type_dict = event_entry.setdefault(receipt_type, {})

receipt_type[row["user_id"]] = db_to_json(row["data"])
if row["thread_id"]:
receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
receipt_type_dict[user_id] = db_to_json(data)
if thread_id:
receipt_type_dict[user_id]["thread_id"] = thread_id

results = {
room_id: [results[room_id]] if room_id in results else []
@@ -428,10 +434,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
A dictionary of roomids to a list of receipts.
"""

def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
SELECT room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
@@ -439,7 +446,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, [from_key, to_key])
else:
sql = """
SELECT * FROM receipts_linearized WHERE
SELECT room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized WHERE
stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
@@ -447,27 +455,27 @@ class ReceiptsWorkerStore(SQLBaseStore):

txn.execute(sql, [to_key])

return self.db_pool.cursor_to_dict(txn)
return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())

txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f
)

results: JsonDict = {}
for row in txn_results:
for room_id, receipt_type, user_id, event_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
row["room_id"],
{"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
room_id,
{"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)

# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
event_entry = room_event["content"].setdefault(event_id, {})
receipt_type_dict = event_entry.setdefault(receipt_type, {})

receipt_type[row["user_id"]] = db_to_json(row["data"])
receipt_type_dict[user_id] = db_to_json(data)

return results



+ 77
- 56
synapse/storage/databases/main/registration.py 查看文件

@@ -195,7 +195,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""

def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
# We could technically use simple_select_one here, but it would not perform
# the COALESCEs (unless hacked into the column names), which could yield
# confusing results.
@@ -213,35 +213,46 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)

rows = self.db_pool.cursor_to_dict(txn)

if len(rows) == 0:
row = txn.fetchone()
if not row:
return None

return rows[0]
(
name,
is_guest,
admin,
consent_version,
consent_ts,
consent_server_notice_sent,
appservice_id,
creation_ts,
user_type,
deactivated,
shadow_banned,
approved,
locked,
) = row

return UserInfo(
appservice_id=appservice_id,
consent_server_notice_sent=consent_server_notice_sent,
consent_version=consent_version,
consent_ts=consent_ts,
creation_ts=creation_ts,
is_admin=bool(admin),
is_deactivated=bool(deactivated),
is_guest=bool(is_guest),
is_shadow_banned=bool(shadow_banned),
user_id=UserID.from_string(name),
user_type=user_type,
approved=bool(approved),
locked=bool(locked),
)

row = await self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
desc="get_user_by_id",
func=get_user_by_id_txn,
)
if row is None:
return None

return UserInfo(
appservice_id=row["appservice_id"],
consent_server_notice_sent=row["consent_server_notice_sent"],
consent_version=row["consent_version"],
consent_ts=row["consent_ts"],
creation_ts=row["creation_ts"],
is_admin=bool(row["admin"]),
is_deactivated=bool(row["deactivated"]),
is_guest=bool(row["is_guest"]),
is_shadow_banned=bool(row["shadow_banned"]),
user_id=UserID.from_string(row["name"]),
user_type=row["user_type"],
approved=bool(row["approved"]),
locked=bool(row["locked"]),
)

async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
@@ -579,16 +590,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""

txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)

if rows:
row = rows[0]

# This field is nullable, ensure it comes out as a boolean
if row["token_used"] is None:
row["token_used"] = False
row = txn.fetchone()

return TokenLookupResult(**row)
if row:
(
user_id,
is_guest,
shadow_banned,
token_id,
device_id,
valid_until_ms,
token_owner,
token_used,
) = row

return TokenLookupResult(
user_id=user_id,
is_guest=is_guest,
shadow_banned=shadow_banned,
token_id=token_id,
device_id=device_id,
valid_until_ms=valid_until_ms,
token_owner=token_owner,
# This field is nullable, ensure it comes out as a boolean
token_used=bool(token_used),
)

return None

@@ -833,11 +859,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users registered on the homeserver."""

def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.db_pool.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
txn.execute("SELECT COUNT(*) FROM users")
row = txn.fetchone()
assert row is not None
return row[0]

return await self.db_pool.runInteraction("count_users", _count_users)

@@ -891,11 +916,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users without a special user_type registered on the homeserver."""

def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
rows = self.db_pool.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
txn.execute("SELECT COUNT(*) FROM users where user_type is null")
row = txn.fetchone()
assert row is not None
return row[0]

return await self.db_pool.runInteraction("count_real_users", _count_users)

@@ -1252,12 +1276,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
txn.execute(sql, [])

res = self.db_pool.cursor_to_dict(txn)
if res:
for user in res:
self.set_expiration_date_for_user_txn(
txn, user["name"], use_delta=True
)
for (name,) in txn.fetchall():
self.set_expiration_date_for_user_txn(txn, name, use_delta=True)

await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
@@ -1963,11 +1983,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)

rows = self.db_pool.cursor_to_dict(txn)
row = txn.fetchone()
assert row is not None

# We cast to bool because the value returned by the database engine might
# be an integer if we're using SQLite.
return bool(rows[0]["approved"])
return bool(row[0])

return await self.db_pool.runInteraction(
desc="is_user_pending_approval",
@@ -2045,22 +2066,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
(last_user, batch_size),
)

rows = self.db_pool.cursor_to_dict(txn)
rows = txn.fetchall()

if not rows:
return True, 0

rows_processed_nb = 0

for user in rows:
if not user["count_tokens"] and not user["count_threepids"]:
self.set_user_deactivated_status_txn(txn, user["name"], True)
for name, count_tokens, count_threepids in rows:
if not count_tokens and not count_threepids:
self.set_user_deactivated_status_txn(txn, name, True)
rows_processed_nb += 1

logger.info("Marked %d rows as deactivated", rows_processed_nb)

self.db_pool.updates._background_update_progress_txn(
txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]}
)

if batch_size > len(rows):


+ 19
- 23
synapse/storage/databases/main/room.py 查看文件

@@ -831,7 +831,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):

def get_retention_policy_for_room_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Optional[int]]]:
) -> Optional[Tuple[Optional[int], Optional[int]]]:
txn.execute(
"""
SELECT min_lifetime, max_lifetime FROM room_retention
@@ -841,7 +841,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
(room_id,),
)

return self.db_pool.cursor_to_dict(txn)
return cast(Optional[Tuple[Optional[int], Optional[int]]], txn.fetchone())

ret = await self.db_pool.runInteraction(
"get_retention_policy_for_room",
@@ -856,8 +856,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
max_lifetime=self.config.retention.retention_default_max_lifetime,
)

min_lifetime = ret[0]["min_lifetime"]
max_lifetime = ret[0]["max_lifetime"]
min_lifetime, max_lifetime = ret

# If one of the room's policy's attributes isn't defined, use the matching
# attribute from the default policy.
@@ -1162,14 +1161,13 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):

txn.execute(sql, args)

rows = self.db_pool.cursor_to_dict(txn)
rooms_dict = {}

for row in rows:
rooms_dict[row["room_id"]] = RetentionPolicy(
min_lifetime=row["min_lifetime"],
max_lifetime=row["max_lifetime"],
rooms_dict = {
room_id: RetentionPolicy(
min_lifetime=min_lifetime,
max_lifetime=max_lifetime,
)
for room_id, min_lifetime, max_lifetime in txn
}

if include_null:
# If required, do a second query that retrieves all of the rooms we know
@@ -1178,13 +1176,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):

txn.execute(sql)

rows = self.db_pool.cursor_to_dict(txn)

# If a room isn't already in the dict (i.e. it doesn't have a retention
# policy in its state), add it with a null policy.
for row in rows:
if row["room_id"] not in rooms_dict:
rooms_dict[row["room_id"]] = RetentionPolicy()
for (room_id,) in txn:
if room_id not in rooms_dict:
rooms_dict[room_id] = RetentionPolicy()

return rooms_dict

@@ -1703,24 +1699,24 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
(last_room, batch_size),
)

rows = self.db_pool.cursor_to_dict(txn)
rows = txn.fetchall()

if not rows:
return True

for row in rows:
if not row["json"]:
for room_id, event_id, json in rows:
if not json:
retention_policy = {}
else:
ev = db_to_json(row["json"])
ev = db_to_json(json)
retention_policy = ev["content"]

self.db_pool.simple_insert_txn(
txn=txn,
table="room_retention",
values={
"room_id": row["room_id"],
"event_id": row["event_id"],
"room_id": room_id,
"event_id": event_id,
"min_lifetime": retention_policy.get("min_lifetime"),
"max_lifetime": retention_policy.get("max_lifetime"),
},
@@ -1729,7 +1725,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Inserted %d rows into room_retention", len(rows))

self.db_pool.updates._background_update_progress_txn(
txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
txn, "insert_room_retention", {"room_id": rows[-1][0]}
)

if batch_size > len(rows):


+ 4
- 6
synapse/storage/databases/main/roommember.py 查看文件

@@ -1349,18 +1349,16 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):

txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

rows = self.db_pool.cursor_to_dict(txn)
rows = txn.fetchall()
if not rows:
return 0

min_stream_id = rows[-1]["stream_ordering"]
min_stream_id = rows[-1][0]

to_update = []
for row in rows:
event_id = row["event_id"]
room_id = row["room_id"]
for _, event_id, room_id, json in rows:
try:
event_json = db_to_json(row["json"])
event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue


+ 11
- 9
synapse/storage/databases/main/search.py 查看文件

@@ -179,22 +179,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
rows = self.db_pool.cursor_to_dict(txn)
rows = txn.fetchall()
if not rows:
return 0

min_stream_id = rows[-1]["stream_ordering"]
min_stream_id = rows[-1][0]

event_search_rows = []
for row in rows:
for (
stream_ordering,
event_id,
room_id,
etype,
json,
origin_server_ts,
) in rows:
try:
event_id = row["event_id"]
room_id = row["room_id"]
etype = row["type"]
stream_ordering = row["stream_ordering"]
origin_server_ts = row["origin_server_ts"]
try:
event_json = db_to_json(row["json"])
event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue


+ 33
- 11
synapse/storage/databases/main/task_scheduler.py 查看文件

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast

from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -27,6 +27,8 @@ from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

ScheduledTaskRow = Tuple[str, str, str, int, str, str, str, str]


class TaskSchedulerWorkerStore(SQLBaseStore):
def __init__(
@@ -38,13 +40,18 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs)

@staticmethod
def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
row["status"] = TaskStatus(row["status"])
if row["params"] is not None:
row["params"] = db_to_json(row["params"])
if row["result"] is not None:
row["result"] = db_to_json(row["result"])
return ScheduledTask(**row)
def _convert_row_to_task(row: ScheduledTaskRow) -> ScheduledTask:
task_id, action, status, timestamp, resource_id, params, result, error = row
return ScheduledTask(
id=task_id,
action=action,
status=TaskStatus(status),
timestamp=timestamp,
resource_id=resource_id,
params=db_to_json(params) if params is not None else None,
result=db_to_json(result) if result is not None else None,
error=error,
)

async def get_scheduled_tasks(
self,
@@ -68,7 +75,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: a list of `ScheduledTask`, ordered by increasing timestamps
"""

def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[ScheduledTaskRow]:
clauses: List[str] = []
args: List[Any] = []
if resource_id:
@@ -101,7 +108,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
args.append(limit)

txn.execute(sql, args)
return self.db_pool.cursor_to_dict(txn)
return cast(List[ScheduledTaskRow], txn.fetchall())

rows = await self.db_pool.runInteraction(
"get_scheduled_tasks", get_scheduled_tasks_txn
@@ -193,7 +200,22 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
desc="get_scheduled_task",
)

return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
return (
TaskSchedulerWorkerStore._convert_row_to_task(
(
row["id"],
row["action"],
row["status"],
row["timestamp"],
row["resource_id"],
row["params"],
row["result"],
row["error"],
)
)
if row
else None
)

async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id.


Loading…
取消
儲存