25'ten fazla konu seçemezsiniz. Konular bir harf veya rakamla başlamalı, kısa çizgiler ('-') içerebilir ve en fazla 35 karakter uzunluğunda olabilir.
 
 
 
 
 
 

2112 satır
71 KiB

  1. # Copyright 2014-2016 OpenMarket Ltd
  2. # Copyright 2017-2018 New Vector Ltd
  3. # Copyright 2019 The Matrix.org Foundation C.I.C.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import logging
  17. import time
  18. from sys import intern
  19. from time import monotonic as monotonic_time
  20. from typing import (
  21. Any,
  22. Callable,
  23. Collection,
  24. Dict,
  25. Iterable,
  26. Iterator,
  27. List,
  28. Optional,
  29. Tuple,
  30. TypeVar,
  31. cast,
  32. overload,
  33. )
  34. import attr
  35. from prometheus_client import Histogram
  36. from typing_extensions import Literal
  37. from twisted.enterprise import adbapi
  38. from synapse.api.errors import StoreError
  39. from synapse.config.database import DatabaseConnectionConfig
  40. from synapse.logging import opentracing
  41. from synapse.logging.context import (
  42. LoggingContext,
  43. current_context,
  44. make_deferred_yieldable,
  45. )
  46. from synapse.metrics.background_process_metrics import run_as_background_process
  47. from synapse.storage.background_updates import BackgroundUpdater
  48. from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
  49. from synapse.storage.types import Connection, Cursor
  50. # python 3 does not have a maximum int value
  51. MAX_TXN_ID = 2 ** 63 - 1
  52. logger = logging.getLogger(__name__)
  53. sql_logger = logging.getLogger("synapse.storage.SQL")
  54. transaction_logger = logging.getLogger("synapse.storage.txn")
  55. perf_logger = logging.getLogger("synapse.storage.TIME")
  56. sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
  57. sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
  58. sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
  59. # Unique indexes which have been added in background updates. Maps from table name
  60. # to the name of the background update which added the unique index to that table.
  61. #
  62. # This is used by the upsert logic to figure out which tables are safe to do a proper
  63. # UPSERT on: until the relevant background update has completed, we
  64. # have to emulate an upsert by locking the table.
  65. #
  66. UNIQUE_INDEX_BACKGROUND_UPDATES = {
  67. "user_ips": "user_ips_device_unique_index",
  68. "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
  69. "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
  70. "event_search": "event_search_event_id_idx",
  71. }
  72. def make_pool(
  73. reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
  74. ) -> adbapi.ConnectionPool:
  75. """Get the connection pool for the database."""
  76. # By default enable `cp_reconnect`. We need to fiddle with db_args in case
  77. # someone has explicitly set `cp_reconnect`.
  78. db_args = dict(db_config.config.get("args", {}))
  79. db_args.setdefault("cp_reconnect", True)
  80. def _on_new_connection(conn):
  81. # Ensure we have a logging context so we can correctly track queries,
  82. # etc.
  83. with LoggingContext("db.on_new_connection"):
  84. engine.on_new_connection(
  85. LoggingDatabaseConnection(conn, engine, "on_new_connection")
  86. )
  87. return adbapi.ConnectionPool(
  88. db_config.config["name"],
  89. cp_reactor=reactor,
  90. cp_openfun=_on_new_connection,
  91. **db_args,
  92. )
  93. def make_conn(
  94. db_config: DatabaseConnectionConfig,
  95. engine: BaseDatabaseEngine,
  96. default_txn_name: str,
  97. ) -> "LoggingDatabaseConnection":
  98. """Make a new connection to the database and return it.
  99. Returns:
  100. Connection
  101. """
  102. db_params = {
  103. k: v
  104. for k, v in db_config.config.get("args", {}).items()
  105. if not k.startswith("cp_")
  106. }
  107. native_db_conn = engine.module.connect(**db_params)
  108. db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
  109. engine.on_new_connection(db_conn)
  110. return db_conn
  111. @attr.s(slots=True)
  112. class LoggingDatabaseConnection:
  113. """A wrapper around a database connection that returns `LoggingTransaction`
  114. as its cursor class.
  115. This is mainly used on startup to ensure that queries get logged correctly
  116. """
  117. conn = attr.ib(type=Connection)
  118. engine = attr.ib(type=BaseDatabaseEngine)
  119. default_txn_name = attr.ib(type=str)
  120. def cursor(
  121. self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
  122. ) -> "LoggingTransaction":
  123. if not txn_name:
  124. txn_name = self.default_txn_name
  125. return LoggingTransaction(
  126. self.conn.cursor(),
  127. name=txn_name,
  128. database_engine=self.engine,
  129. after_callbacks=after_callbacks,
  130. exception_callbacks=exception_callbacks,
  131. )
  132. def close(self) -> None:
  133. self.conn.close()
  134. def commit(self) -> None:
  135. self.conn.commit()
  136. def rollback(self) -> None:
  137. self.conn.rollback()
  138. def __enter__(self) -> "Connection":
  139. self.conn.__enter__()
  140. return self
  141. def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
  142. return self.conn.__exit__(exc_type, exc_value, traceback)
  143. # Proxy through any unknown lookups to the DB conn class.
  144. def __getattr__(self, name):
  145. return getattr(self.conn, name)
  146. # The type of entry which goes on our after_callbacks and exception_callbacks lists.
  147. _CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]
  148. R = TypeVar("R")
  149. class LoggingTransaction:
  150. """An object that almost-transparently proxies for the 'txn' object
  151. passed to the constructor. Adds logging and metrics to the .execute()
  152. method.
  153. Args:
  154. txn: The database transaction object to wrap.
  155. name: The name of this transactions for logging.
  156. database_engine
  157. after_callbacks: A list that callbacks will be appended to
  158. that have been added by `call_after` which should be run on
  159. successful completion of the transaction. None indicates that no
  160. callbacks should be allowed to be scheduled to run.
  161. exception_callbacks: A list that callbacks will be appended
  162. to that have been added by `call_on_exception` which should be run
  163. if transaction ends with an error. None indicates that no callbacks
  164. should be allowed to be scheduled to run.
  165. """
  166. __slots__ = [
  167. "txn",
  168. "name",
  169. "database_engine",
  170. "after_callbacks",
  171. "exception_callbacks",
  172. ]
  173. def __init__(
  174. self,
  175. txn: Cursor,
  176. name: str,
  177. database_engine: BaseDatabaseEngine,
  178. after_callbacks: Optional[List[_CallbackListEntry]] = None,
  179. exception_callbacks: Optional[List[_CallbackListEntry]] = None,
  180. ):
  181. self.txn = txn
  182. self.name = name
  183. self.database_engine = database_engine
  184. self.after_callbacks = after_callbacks
  185. self.exception_callbacks = exception_callbacks
  186. def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any):
  187. """Call the given callback on the main twisted thread after the
  188. transaction has finished. Used to invalidate the caches on the
  189. correct thread.
  190. """
  191. # if self.after_callbacks is None, that means that whatever constructed the
  192. # LoggingTransaction isn't expecting there to be any callbacks; assert that
  193. # is not the case.
  194. assert self.after_callbacks is not None
  195. self.after_callbacks.append((callback, args, kwargs))
  196. def call_on_exception(
  197. self, callback: Callable[..., None], *args: Any, **kwargs: Any
  198. ):
  199. # if self.exception_callbacks is None, that means that whatever constructed the
  200. # LoggingTransaction isn't expecting there to be any callbacks; assert that
  201. # is not the case.
  202. assert self.exception_callbacks is not None
  203. self.exception_callbacks.append((callback, args, kwargs))
  204. def fetchone(self) -> Optional[Tuple]:
  205. return self.txn.fetchone()
  206. def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
  207. return self.txn.fetchmany(size=size)
  208. def fetchall(self) -> List[Tuple]:
  209. return self.txn.fetchall()
  210. def __iter__(self) -> Iterator[Tuple]:
  211. return self.txn.__iter__()
  212. @property
  213. def rowcount(self) -> int:
  214. return self.txn.rowcount
  215. @property
  216. def description(self) -> Any:
  217. return self.txn.description
  218. def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
  219. """Similar to `executemany`, except `txn.rowcount` will not be correct
  220. afterwards.
  221. More efficient than `executemany` on PostgreSQL
  222. """
  223. if isinstance(self.database_engine, PostgresEngine):
  224. from psycopg2.extras import execute_batch # type: ignore
  225. self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
  226. else:
  227. self.executemany(sql, args)
  228. def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
  229. """Corresponds to psycopg2.extras.execute_values. Only available when
  230. using postgres.
  231. Always sets fetch=True when caling `execute_values`, so will return the
  232. results.
  233. """
  234. assert isinstance(self.database_engine, PostgresEngine)
  235. from psycopg2.extras import execute_values # type: ignore
  236. return self._do_execute(
  237. lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
  238. )
  239. def execute(self, sql: str, *args: Any) -> None:
  240. self._do_execute(self.txn.execute, sql, *args)
  241. def executemany(self, sql: str, *args: Any) -> None:
  242. self._do_execute(self.txn.executemany, sql, *args)
  243. def _make_sql_one_line(self, sql: str) -> str:
  244. "Strip newlines out of SQL so that the loggers in the DB are on one line"
  245. return " ".join(line.strip() for line in sql.splitlines() if line.strip())
  246. def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
  247. sql = self._make_sql_one_line(sql)
  248. # TODO(paul): Maybe use 'info' and 'debug' for values?
  249. sql_logger.debug("[SQL] {%s} %s", self.name, sql)
  250. sql = self.database_engine.convert_param_style(sql)
  251. if args:
  252. try:
  253. sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
  254. except Exception:
  255. # Don't let logging failures stop SQL from working
  256. pass
  257. start = time.time()
  258. try:
  259. with opentracing.start_active_span(
  260. "db.query",
  261. tags={
  262. opentracing.tags.DATABASE_TYPE: "sql",
  263. opentracing.tags.DATABASE_STATEMENT: sql,
  264. },
  265. ):
  266. return func(sql, *args)
  267. except Exception as e:
  268. sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
  269. raise
  270. finally:
  271. secs = time.time() - start
  272. sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
  273. sql_query_timer.labels(sql.split()[0]).observe(secs)
  274. def close(self) -> None:
  275. self.txn.close()
  276. def __enter__(self) -> "LoggingTransaction":
  277. return self
  278. def __exit__(self, exc_type, exc_value, traceback):
  279. self.close()
  280. class PerformanceCounters:
  281. def __init__(self):
  282. self.current_counters = {}
  283. self.previous_counters = {}
  284. def update(self, key: str, duration_secs: float) -> None:
  285. count, cum_time = self.current_counters.get(key, (0, 0))
  286. count += 1
  287. cum_time += duration_secs
  288. self.current_counters[key] = (count, cum_time)
  289. def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
  290. counters = []
  291. for name, (count, cum_time) in self.current_counters.items():
  292. prev_count, prev_time = self.previous_counters.get(name, (0, 0))
  293. counters.append(
  294. (
  295. (cum_time - prev_time) / interval_duration_secs,
  296. count - prev_count,
  297. name,
  298. )
  299. )
  300. self.previous_counters = dict(self.current_counters)
  301. counters.sort(reverse=True)
  302. top_n_counters = ", ".join(
  303. "%s(%d): %.3f%%" % (name, count, 100 * ratio)
  304. for ratio, count, name in counters[:limit]
  305. )
  306. return top_n_counters
  307. class DatabasePool:
  308. """Wraps a single physical database and connection pool.
  309. A single database may be used by multiple data stores.
  310. """
  311. _TXN_ID = 0
  312. def __init__(
  313. self,
  314. hs,
  315. database_config: DatabaseConnectionConfig,
  316. engine: BaseDatabaseEngine,
  317. ):
  318. self.hs = hs
  319. self._clock = hs.get_clock()
  320. self._database_config = database_config
  321. self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
  322. self.updates = BackgroundUpdater(hs, self)
  323. self._previous_txn_total_time = 0.0
  324. self._current_txn_total_time = 0.0
  325. self._previous_loop_ts = 0.0
  326. # TODO(paul): These can eventually be removed once the metrics code
  327. # is running in mainline, and we have some nice monitoring frontends
  328. # to watch it
  329. self._txn_perf_counters = PerformanceCounters()
  330. self.engine = engine
  331. # A set of tables that are not safe to use native upserts in.
  332. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
  333. # We add the user_directory_search table to the blacklist on SQLite
  334. # because the existing search table does not have an index, making it
  335. # unsafe to use native upserts.
  336. if isinstance(self.engine, Sqlite3Engine):
  337. self._unsafe_to_upsert_tables.add("user_directory_search")
  338. if self.engine.can_native_upsert:
  339. # Check ASAP (and then later, every 1s) to see if we have finished
  340. # background updates of tables that aren't safe to update.
  341. self._clock.call_later(
  342. 0.0,
  343. run_as_background_process,
  344. "upsert_safety_check",
  345. self._check_safe_to_upsert,
  346. )
  347. def is_running(self) -> bool:
  348. """Is the database pool currently running"""
  349. return self._db_pool.running
  350. async def _check_safe_to_upsert(self) -> None:
  351. """
  352. Is it safe to use native UPSERT?
  353. If there are background updates, we will need to wait, as they may be
  354. the addition of indexes that set the UNIQUE constraint that we require.
  355. If the background updates have not completed, wait 15 sec and check again.
  356. """
  357. updates = await self.simple_select_list(
  358. "background_updates",
  359. keyvalues=None,
  360. retcols=["update_name"],
  361. desc="check_background_updates",
  362. )
  363. updates = [x["update_name"] for x in updates]
  364. for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
  365. if update_name not in updates:
  366. logger.debug("Now safe to upsert in %s", table)
  367. self._unsafe_to_upsert_tables.discard(table)
  368. # If there's any updates still running, reschedule to run.
  369. if updates:
  370. self._clock.call_later(
  371. 15.0,
  372. run_as_background_process,
  373. "upsert_safety_check",
  374. self._check_safe_to_upsert,
  375. )
  376. def start_profiling(self) -> None:
  377. self._previous_loop_ts = monotonic_time()
  378. def loop():
  379. curr = self._current_txn_total_time
  380. prev = self._previous_txn_total_time
  381. self._previous_txn_total_time = curr
  382. time_now = monotonic_time()
  383. time_then = self._previous_loop_ts
  384. self._previous_loop_ts = time_now
  385. duration = time_now - time_then
  386. ratio = (curr - prev) / duration
  387. top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
  388. perf_logger.debug(
  389. "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
  390. )
  391. self._clock.looping_call(loop, 10000)
  392. def new_transaction(
  393. self,
  394. conn: LoggingDatabaseConnection,
  395. desc: str,
  396. after_callbacks: List[_CallbackListEntry],
  397. exception_callbacks: List[_CallbackListEntry],
  398. func: Callable[..., R],
  399. *args: Any,
  400. **kwargs: Any,
  401. ) -> R:
  402. """Start a new database transaction with the given connection.
  403. Note: The given func may be called multiple times under certain
  404. failure modes. This is normally fine when in a standard transaction,
  405. but care must be taken if the connection is in `autocommit` mode that
  406. the function will correctly handle being aborted and retried half way
  407. through its execution.
  408. Args:
  409. conn
  410. desc
  411. after_callbacks
  412. exception_callbacks
  413. func
  414. *args
  415. **kwargs
  416. """
  417. start = monotonic_time()
  418. txn_id = self._TXN_ID
  419. # We don't really need these to be unique, so lets stop it from
  420. # growing really large.
  421. self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
  422. name = "%s-%x" % (desc, txn_id)
  423. transaction_logger.debug("[TXN START] {%s}", name)
  424. try:
  425. i = 0
  426. N = 5
  427. while True:
  428. cursor = conn.cursor(
  429. txn_name=name,
  430. after_callbacks=after_callbacks,
  431. exception_callbacks=exception_callbacks,
  432. )
  433. try:
  434. with opentracing.start_active_span(
  435. "db.txn",
  436. tags={
  437. opentracing.SynapseTags.DB_TXN_DESC: desc,
  438. opentracing.SynapseTags.DB_TXN_ID: name,
  439. },
  440. ):
  441. r = func(cursor, *args, **kwargs)
  442. opentracing.log_kv({"message": "commit"})
  443. conn.commit()
  444. return r
  445. except self.engine.module.OperationalError as e:
  446. # This can happen if the database disappears mid
  447. # transaction.
  448. transaction_logger.warning(
  449. "[TXN OPERROR] {%s} %s %d/%d",
  450. name,
  451. e,
  452. i,
  453. N,
  454. )
  455. if i < N:
  456. i += 1
  457. try:
  458. with opentracing.start_active_span("db.rollback"):
  459. conn.rollback()
  460. except self.engine.module.Error as e1:
  461. transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
  462. continue
  463. raise
  464. except self.engine.module.DatabaseError as e:
  465. if self.engine.is_deadlock(e):
  466. transaction_logger.warning(
  467. "[TXN DEADLOCK] {%s} %d/%d", name, i, N
  468. )
  469. if i < N:
  470. i += 1
  471. try:
  472. with opentracing.start_active_span("db.rollback"):
  473. conn.rollback()
  474. except self.engine.module.Error as e1:
  475. transaction_logger.warning(
  476. "[TXN EROLL] {%s} %s",
  477. name,
  478. e1,
  479. )
  480. continue
  481. raise
  482. finally:
  483. # we're either about to retry with a new cursor, or we're about to
  484. # release the connection. Once we release the connection, it could
  485. # get used for another query, which might do a conn.rollback().
  486. #
  487. # In the latter case, even though that probably wouldn't affect the
  488. # results of this transaction, python's sqlite will reset all
  489. # statements on the connection [1], which will make our cursor
  490. # invalid [2].
  491. #
  492. # In any case, continuing to read rows after commit()ing seems
  493. # dubious from the PoV of ACID transactional semantics
  494. # (sqlite explicitly says that once you commit, you may see rows
  495. # from subsequent updates.)
  496. #
  497. # In psycopg2, cursors are essentially a client-side fabrication -
  498. # all the data is transferred to the client side when the statement
  499. # finishes executing - so in theory we could go on streaming results
  500. # from the cursor, but attempting to do so would make us
  501. # incompatible with sqlite, so let's make sure we're not doing that
  502. # by closing the cursor.
  503. #
  504. # (*named* cursors in psycopg2 are different and are proper server-
  505. # side things, but (a) we don't use them and (b) they are implicitly
  506. # closed by ending the transaction anyway.)
  507. #
  508. # In short, if we haven't finished with the cursor yet, that's a
  509. # problem waiting to bite us.
  510. #
  511. # TL;DR: we're done with the cursor, so we can close it.
  512. #
  513. # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
  514. # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
  515. cursor.close()
  516. except Exception as e:
  517. transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
  518. raise
  519. finally:
  520. end = monotonic_time()
  521. duration = end - start
  522. current_context().add_database_transaction(duration)
  523. transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
  524. self._current_txn_total_time += duration
  525. self._txn_perf_counters.update(desc, duration)
  526. sql_txn_timer.labels(desc).observe(duration)
  527. async def runInteraction(
  528. self,
  529. desc: str,
  530. func: Callable[..., R],
  531. *args: Any,
  532. db_autocommit: bool = False,
  533. **kwargs: Any,
  534. ) -> R:
  535. """Starts a transaction on the database and runs a given function
  536. Arguments:
  537. desc: description of the transaction, for logging and metrics
  538. func: callback function, which will be called with a
  539. database transaction (twisted.enterprise.adbapi.Transaction) as
  540. its first argument, followed by `args` and `kwargs`.
  541. db_autocommit: Whether to run the function in "autocommit" mode,
  542. i.e. outside of a transaction. This is useful for transactions
  543. that are only a single query.
  544. Currently, this is only implemented for Postgres. SQLite will still
  545. run the function inside a transaction.
  546. WARNING: This means that if func fails half way through then
  547. the changes will *not* be rolled back. `func` may also get
  548. called multiple times if the transaction is retried, so must
  549. correctly handle that case.
  550. args: positional args to pass to `func`
  551. kwargs: named args to pass to `func`
  552. Returns:
  553. The result of func
  554. """
  555. after_callbacks = [] # type: List[_CallbackListEntry]
  556. exception_callbacks = [] # type: List[_CallbackListEntry]
  557. if not current_context():
  558. logger.warning("Starting db txn '%s' from sentinel context", desc)
  559. try:
  560. with opentracing.start_active_span(f"db.{desc}"):
  561. result = await self.runWithConnection(
  562. self.new_transaction,
  563. desc,
  564. after_callbacks,
  565. exception_callbacks,
  566. func,
  567. *args,
  568. db_autocommit=db_autocommit,
  569. **kwargs,
  570. )
  571. for after_callback, after_args, after_kwargs in after_callbacks:
  572. after_callback(*after_args, **after_kwargs)
  573. except Exception:
  574. for after_callback, after_args, after_kwargs in exception_callbacks:
  575. after_callback(*after_args, **after_kwargs)
  576. raise
  577. return cast(R, result)
  578. async def runWithConnection(
  579. self,
  580. func: Callable[..., R],
  581. *args: Any,
  582. db_autocommit: bool = False,
  583. **kwargs: Any,
  584. ) -> R:
  585. """Wraps the .runWithConnection() method on the underlying db_pool.
  586. Arguments:
  587. func: callback function, which will be called with a
  588. database connection (twisted.enterprise.adbapi.Connection) as
  589. its first argument, followed by `args` and `kwargs`.
  590. args: positional args to pass to `func`
  591. db_autocommit: Whether to run the function in "autocommit" mode,
  592. i.e. outside of a transaction. This is useful for transaction
  593. that are only a single query. Currently only affects postgres.
  594. kwargs: named args to pass to `func`
  595. Returns:
  596. The result of func
  597. """
  598. curr_context = current_context()
  599. if not curr_context:
  600. logger.warning(
  601. "Starting db connection from sentinel context: metrics will be lost"
  602. )
  603. parent_context = None
  604. else:
  605. assert isinstance(curr_context, LoggingContext)
  606. parent_context = curr_context
  607. start_time = monotonic_time()
  608. def inner_func(conn, *args, **kwargs):
  609. # We shouldn't be in a transaction. If we are then something
  610. # somewhere hasn't committed after doing work. (This is likely only
  611. # possible during startup, as `run*` will ensure changes are
  612. # committed/rolled back before putting the connection back in the
  613. # pool).
  614. assert not self.engine.in_transaction(conn)
  615. with LoggingContext(
  616. str(curr_context), parent_context=parent_context
  617. ) as context:
  618. with opentracing.start_active_span(
  619. operation_name="db.connection",
  620. ):
  621. sched_duration_sec = monotonic_time() - start_time
  622. sql_scheduling_timer.observe(sched_duration_sec)
  623. context.add_database_scheduled(sched_duration_sec)
  624. if self.engine.is_connection_closed(conn):
  625. logger.debug("Reconnecting closed database connection")
  626. conn.reconnect()
  627. opentracing.log_kv({"message": "reconnected"})
  628. try:
  629. if db_autocommit:
  630. self.engine.attempt_to_set_autocommit(conn, True)
  631. db_conn = LoggingDatabaseConnection(
  632. conn, self.engine, "runWithConnection"
  633. )
  634. return func(db_conn, *args, **kwargs)
  635. finally:
  636. if db_autocommit:
  637. self.engine.attempt_to_set_autocommit(conn, False)
  638. return await make_deferred_yieldable(
  639. self._db_pool.runWithConnection(inner_func, *args, **kwargs)
  640. )
  641. @staticmethod
  642. def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
  643. """Converts a SQL cursor into an list of dicts.
  644. Args:
  645. cursor: The DBAPI cursor which has executed a query.
  646. Returns:
  647. A list of dicts where the key is the column header.
  648. """
  649. assert cursor.description is not None, "cursor.description was None"
  650. col_headers = [intern(str(column[0])) for column in cursor.description]
  651. results = [dict(zip(col_headers, row)) for row in cursor]
  652. return results
  653. @overload
  654. async def execute(
  655. self, desc: str, decoder: Literal[None], query: str, *args: Any
  656. ) -> List[Tuple[Any, ...]]:
  657. ...
  658. @overload
  659. async def execute(
  660. self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
  661. ) -> R:
  662. ...
  663. async def execute(
  664. self,
  665. desc: str,
  666. decoder: Optional[Callable[[Cursor], R]],
  667. query: str,
  668. *args: Any,
  669. ) -> R:
  670. """Runs a single query for a result set.
  671. Args:
  672. desc: description of the transaction, for logging and metrics
  673. decoder - The function which can resolve the cursor results to
  674. something meaningful.
  675. query - The query string to execute
  676. *args - Query args.
  677. Returns:
  678. The result of decoder(results)
  679. """
  680. def interaction(txn):
  681. txn.execute(query, args)
  682. if decoder:
  683. return decoder(txn)
  684. else:
  685. return txn.fetchall()
  686. return await self.runInteraction(desc, interaction)
  687. # "Simple" SQL API methods that operate on a single table with no JOINs,
  688. # no complex WHERE clauses, just a dict of values for columns.
  689. async def simple_insert(
  690. self,
  691. table: str,
  692. values: Dict[str, Any],
  693. or_ignore: bool = False,
  694. desc: str = "simple_insert",
  695. ) -> bool:
  696. """Executes an INSERT query on the named table.
  697. Args:
  698. table: string giving the table name
  699. values: dict of new column names and values for them
  700. or_ignore: bool stating whether an exception should be raised
  701. when a conflicting row already exists. If True, False will be
  702. returned by the function instead
  703. desc: description of the transaction, for logging and metrics
  704. Returns:
  705. Whether the row was inserted or not. Only useful when `or_ignore` is True
  706. """
  707. try:
  708. await self.runInteraction(desc, self.simple_insert_txn, table, values)
  709. except self.engine.module.IntegrityError:
  710. # We have to do or_ignore flag at this layer, since we can't reuse
  711. # a cursor after we receive an error from the db.
  712. if not or_ignore:
  713. raise
  714. return False
  715. return True
  716. @staticmethod
  717. def simple_insert_txn(
  718. txn: LoggingTransaction, table: str, values: Dict[str, Any]
  719. ) -> None:
  720. keys, vals = zip(*values.items())
  721. sql = "INSERT INTO %s (%s) VALUES(%s)" % (
  722. table,
  723. ", ".join(k for k in keys),
  724. ", ".join("?" for _ in keys),
  725. )
  726. txn.execute(sql, vals)
  727. async def simple_insert_many(
  728. self, table: str, values: List[Dict[str, Any]], desc: str
  729. ) -> None:
  730. """Executes an INSERT query on the named table.
  731. Args:
  732. table: string giving the table name
  733. values: dict of new column names and values for them
  734. desc: description of the transaction, for logging and metrics
  735. """
  736. await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
  737. @staticmethod
  738. def simple_insert_many_txn(
  739. txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
  740. ) -> None:
  741. """Executes an INSERT query on the named table.
  742. Args:
  743. txn: The transaction to use.
  744. table: string giving the table name
  745. values: dict of new column names and values for them
  746. """
  747. if not values:
  748. return
  749. # This is a *slight* abomination to get a list of tuples of key names
  750. # and a list of tuples of value names.
  751. #
  752. # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
  753. # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
  754. #
  755. # The sort is to ensure that we don't rely on dictionary iteration
  756. # order.
  757. keys, vals = zip(
  758. *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
  759. )
  760. for k in keys:
  761. if k != keys[0]:
  762. raise RuntimeError("All items must have the same keys")
  763. sql = "INSERT INTO %s (%s) VALUES(%s)" % (
  764. table,
  765. ", ".join(k for k in keys[0]),
  766. ", ".join("?" for _ in keys[0]),
  767. )
  768. txn.execute_batch(sql, vals)
  769. async def simple_upsert(
  770. self,
  771. table: str,
  772. keyvalues: Dict[str, Any],
  773. values: Dict[str, Any],
  774. insertion_values: Optional[Dict[str, Any]] = None,
  775. desc: str = "simple_upsert",
  776. lock: bool = True,
  777. ) -> Optional[bool]:
  778. """
  779. `lock` should generally be set to True (the default), but can be set
  780. to False if either of the following are true:
  781. * there is a UNIQUE INDEX on the key columns. In this case a conflict
  782. will cause an IntegrityError in which case this function will retry
  783. the update.
  784. * we somehow know that we are the only thread which will be updating
  785. this table.
  786. Args:
  787. table: The table to upsert into
  788. keyvalues: The unique key columns and their new values
  789. values: The nonunique columns and their new values
  790. insertion_values: additional key/values to use only when inserting
  791. desc: description of the transaction, for logging and metrics
  792. lock: True to lock the table when doing the upsert.
  793. Returns:
  794. Native upserts always return None. Emulated upserts return True if a
  795. new entry was created, False if an existing one was updated.
  796. """
  797. insertion_values = insertion_values or {}
  798. attempts = 0
  799. while True:
  800. try:
  801. # We can autocommit if we are going to use native upserts
  802. autocommit = (
  803. self.engine.can_native_upsert
  804. and table not in self._unsafe_to_upsert_tables
  805. )
  806. return await self.runInteraction(
  807. desc,
  808. self.simple_upsert_txn,
  809. table,
  810. keyvalues,
  811. values,
  812. insertion_values,
  813. lock=lock,
  814. db_autocommit=autocommit,
  815. )
  816. except self.engine.module.IntegrityError as e:
  817. attempts += 1
  818. if attempts >= 5:
  819. # don't retry forever, because things other than races
  820. # can cause IntegrityErrors
  821. raise
  822. # presumably we raced with another transaction: let's retry.
  823. logger.warning(
  824. "IntegrityError when upserting into %s; retrying: %s", table, e
  825. )
  826. def simple_upsert_txn(
  827. self,
  828. txn: LoggingTransaction,
  829. table: str,
  830. keyvalues: Dict[str, Any],
  831. values: Dict[str, Any],
  832. insertion_values: Optional[Dict[str, Any]] = None,
  833. lock: bool = True,
  834. ) -> Optional[bool]:
  835. """
  836. Pick the UPSERT method which works best on the platform. Either the
  837. native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
  838. Args:
  839. txn: The transaction to use.
  840. table: The table to upsert into
  841. keyvalues: The unique key tables and their new values
  842. values: The nonunique columns and their new values
  843. insertion_values: additional key/values to use only when inserting
  844. lock: True to lock the table when doing the upsert.
  845. Returns:
  846. Native upserts always return None. Emulated upserts return True if a
  847. new entry was created, False if an existing one was updated.
  848. """
  849. insertion_values = insertion_values or {}
  850. if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
  851. self.simple_upsert_txn_native_upsert(
  852. txn, table, keyvalues, values, insertion_values=insertion_values
  853. )
  854. return None
  855. else:
  856. return self.simple_upsert_txn_emulated(
  857. txn,
  858. table,
  859. keyvalues,
  860. values,
  861. insertion_values=insertion_values,
  862. lock=lock,
  863. )
  864. def simple_upsert_txn_emulated(
  865. self,
  866. txn: LoggingTransaction,
  867. table: str,
  868. keyvalues: Dict[str, Any],
  869. values: Dict[str, Any],
  870. insertion_values: Optional[Dict[str, Any]] = None,
  871. lock: bool = True,
  872. ) -> bool:
  873. """
  874. Args:
  875. table: The table to upsert into
  876. keyvalues: The unique key tables and their new values
  877. values: The nonunique columns and their new values
  878. insertion_values: additional key/values to use only when inserting
  879. lock: True to lock the table when doing the upsert.
  880. Returns:
  881. Returns True if a new entry was created, False if an existing
  882. one was updated.
  883. """
  884. insertion_values = insertion_values or {}
  885. # We need to lock the table :(, unless we're *really* careful
  886. if lock:
  887. self.engine.lock_table(txn, table)
  888. def _getwhere(key):
  889. # If the value we're passing in is None (aka NULL), we need to use
  890. # IS, not =, as NULL = NULL equals NULL (False).
  891. if keyvalues[key] is None:
  892. return "%s IS ?" % (key,)
  893. else:
  894. return "%s = ?" % (key,)
  895. if not values:
  896. # If `values` is empty, then all of the values we care about are in
  897. # the unique key, so there is nothing to UPDATE. We can just do a
  898. # SELECT instead to see if it exists.
  899. sql = "SELECT 1 FROM %s WHERE %s" % (
  900. table,
  901. " AND ".join(_getwhere(k) for k in keyvalues),
  902. )
  903. sqlargs = list(keyvalues.values())
  904. txn.execute(sql, sqlargs)
  905. if txn.fetchall():
  906. # We have an existing record.
  907. return False
  908. else:
  909. # First try to update.
  910. sql = "UPDATE %s SET %s WHERE %s" % (
  911. table,
  912. ", ".join("%s = ?" % (k,) for k in values),
  913. " AND ".join(_getwhere(k) for k in keyvalues),
  914. )
  915. sqlargs = list(values.values()) + list(keyvalues.values())
  916. txn.execute(sql, sqlargs)
  917. if txn.rowcount > 0:
  918. # successfully updated at least one row.
  919. return False
  920. # We didn't find any existing rows, so insert a new one
  921. allvalues = {} # type: Dict[str, Any]
  922. allvalues.update(keyvalues)
  923. allvalues.update(values)
  924. allvalues.update(insertion_values)
  925. sql = "INSERT INTO %s (%s) VALUES (%s)" % (
  926. table,
  927. ", ".join(k for k in allvalues),
  928. ", ".join("?" for _ in allvalues),
  929. )
  930. txn.execute(sql, list(allvalues.values()))
  931. # successfully inserted
  932. return True
  933. def simple_upsert_txn_native_upsert(
  934. self,
  935. txn: LoggingTransaction,
  936. table: str,
  937. keyvalues: Dict[str, Any],
  938. values: Dict[str, Any],
  939. insertion_values: Optional[Dict[str, Any]] = None,
  940. ) -> None:
  941. """
  942. Use the native UPSERT functionality in recent PostgreSQL versions.
  943. Args:
  944. table: The table to upsert into
  945. keyvalues: The unique key tables and their new values
  946. values: The nonunique columns and their new values
  947. insertion_values: additional key/values to use only when inserting
  948. """
  949. allvalues = {} # type: Dict[str, Any]
  950. allvalues.update(keyvalues)
  951. allvalues.update(insertion_values or {})
  952. if not values:
  953. latter = "NOTHING"
  954. else:
  955. allvalues.update(values)
  956. latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
  957. sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
  958. table,
  959. ", ".join(k for k in allvalues),
  960. ", ".join("?" for _ in allvalues),
  961. ", ".join(k for k in keyvalues),
  962. latter,
  963. )
  964. txn.execute(sql, list(allvalues.values()))
  965. async def simple_upsert_many(
  966. self,
  967. table: str,
  968. key_names: Collection[str],
  969. key_values: Collection[Iterable[Any]],
  970. value_names: Collection[str],
  971. value_values: Iterable[Iterable[Any]],
  972. desc: str,
  973. ) -> None:
  974. """
  975. Upsert, many times.
  976. Args:
  977. table: The table to upsert into
  978. key_names: The key column names.
  979. key_values: A list of each row's key column values.
  980. value_names: The value column names
  981. value_values: A list of each row's value column values.
  982. Ignored if value_names is empty.
  983. """
  984. # We can autocommit if we are going to use native upserts
  985. autocommit = (
  986. self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
  987. )
  988. return await self.runInteraction(
  989. desc,
  990. self.simple_upsert_many_txn,
  991. table,
  992. key_names,
  993. key_values,
  994. value_names,
  995. value_values,
  996. db_autocommit=autocommit,
  997. )
  998. def simple_upsert_many_txn(
  999. self,
  1000. txn: LoggingTransaction,
  1001. table: str,
  1002. key_names: Collection[str],
  1003. key_values: Collection[Iterable[Any]],
  1004. value_names: Collection[str],
  1005. value_values: Iterable[Iterable[Any]],
  1006. ) -> None:
  1007. """
  1008. Upsert, many times.
  1009. Args:
  1010. table: The table to upsert into
  1011. key_names: The key column names.
  1012. key_values: A list of each row's key column values.
  1013. value_names: The value column names
  1014. value_values: A list of each row's value column values.
  1015. Ignored if value_names is empty.
  1016. """
  1017. if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
  1018. return self.simple_upsert_many_txn_native_upsert(
  1019. txn, table, key_names, key_values, value_names, value_values
  1020. )
  1021. else:
  1022. return self.simple_upsert_many_txn_emulated(
  1023. txn, table, key_names, key_values, value_names, value_values
  1024. )
  1025. def simple_upsert_many_txn_emulated(
  1026. self,
  1027. txn: LoggingTransaction,
  1028. table: str,
  1029. key_names: Iterable[str],
  1030. key_values: Collection[Iterable[Any]],
  1031. value_names: Collection[str],
  1032. value_values: Iterable[Iterable[Any]],
  1033. ) -> None:
  1034. """
  1035. Upsert, many times, but without native UPSERT support or batching.
  1036. Args:
  1037. table: The table to upsert into
  1038. key_names: The key column names.
  1039. key_values: A list of each row's key column values.
  1040. value_names: The value column names
  1041. value_values: A list of each row's value column values.
  1042. Ignored if value_names is empty.
  1043. """
  1044. # No value columns, therefore make a blank list so that the following
  1045. # zip() works correctly.
  1046. if not value_names:
  1047. value_values = [() for x in range(len(key_values))]
  1048. for keyv, valv in zip(key_values, value_values):
  1049. _keys = {x: y for x, y in zip(key_names, keyv)}
  1050. _vals = {x: y for x, y in zip(value_names, valv)}
  1051. self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
  1052. def simple_upsert_many_txn_native_upsert(
  1053. self,
  1054. txn: LoggingTransaction,
  1055. table: str,
  1056. key_names: Collection[str],
  1057. key_values: Collection[Iterable[Any]],
  1058. value_names: Collection[str],
  1059. value_values: Iterable[Iterable[Any]],
  1060. ) -> None:
  1061. """
  1062. Upsert, many times, using batching where possible.
  1063. Args:
  1064. table: The table to upsert into
  1065. key_names: The key column names.
  1066. key_values: A list of each row's key column values.
  1067. value_names: The value column names
  1068. value_values: A list of each row's value column values.
  1069. Ignored if value_names is empty.
  1070. """
  1071. allnames = [] # type: List[str]
  1072. allnames.extend(key_names)
  1073. allnames.extend(value_names)
  1074. if not value_names:
  1075. # No value columns, therefore make a blank list so that the
  1076. # following zip() works correctly.
  1077. latter = "NOTHING"
  1078. value_values = [() for x in range(len(key_values))]
  1079. else:
  1080. latter = "UPDATE SET " + ", ".join(
  1081. k + "=EXCLUDED." + k for k in value_names
  1082. )
  1083. sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
  1084. table,
  1085. ", ".join(k for k in allnames),
  1086. ", ".join("?" for _ in allnames),
  1087. ", ".join(key_names),
  1088. latter,
  1089. )
  1090. args = []
  1091. for x, y in zip(key_values, value_values):
  1092. args.append(tuple(x) + tuple(y))
  1093. return txn.execute_batch(sql, args)
  1094. @overload
  1095. async def simple_select_one(
  1096. self,
  1097. table: str,
  1098. keyvalues: Dict[str, Any],
  1099. retcols: Iterable[str],
  1100. allow_none: Literal[False] = False,
  1101. desc: str = "simple_select_one",
  1102. ) -> Dict[str, Any]:
  1103. ...
  1104. @overload
  1105. async def simple_select_one(
  1106. self,
  1107. table: str,
  1108. keyvalues: Dict[str, Any],
  1109. retcols: Iterable[str],
  1110. allow_none: Literal[True] = True,
  1111. desc: str = "simple_select_one",
  1112. ) -> Optional[Dict[str, Any]]:
  1113. ...
  1114. async def simple_select_one(
  1115. self,
  1116. table: str,
  1117. keyvalues: Dict[str, Any],
  1118. retcols: Iterable[str],
  1119. allow_none: bool = False,
  1120. desc: str = "simple_select_one",
  1121. ) -> Optional[Dict[str, Any]]:
  1122. """Executes a SELECT query on the named table, which is expected to
  1123. return a single row, returning multiple columns from it.
  1124. Args:
  1125. table: string giving the table name
  1126. keyvalues: dict of column names and values to select the row with
  1127. retcols: list of strings giving the names of the columns to return
  1128. allow_none: If true, return None instead of failing if the SELECT
  1129. statement returns no rows
  1130. desc: description of the transaction, for logging and metrics
  1131. """
  1132. return await self.runInteraction(
  1133. desc,
  1134. self.simple_select_one_txn,
  1135. table,
  1136. keyvalues,
  1137. retcols,
  1138. allow_none,
  1139. db_autocommit=True,
  1140. )
  1141. @overload
  1142. async def simple_select_one_onecol(
  1143. self,
  1144. table: str,
  1145. keyvalues: Dict[str, Any],
  1146. retcol: str,
  1147. allow_none: Literal[False] = False,
  1148. desc: str = "simple_select_one_onecol",
  1149. ) -> Any:
  1150. ...
  1151. @overload
  1152. async def simple_select_one_onecol(
  1153. self,
  1154. table: str,
  1155. keyvalues: Dict[str, Any],
  1156. retcol: str,
  1157. allow_none: Literal[True] = True,
  1158. desc: str = "simple_select_one_onecol",
  1159. ) -> Optional[Any]:
  1160. ...
  1161. async def simple_select_one_onecol(
  1162. self,
  1163. table: str,
  1164. keyvalues: Dict[str, Any],
  1165. retcol: str,
  1166. allow_none: bool = False,
  1167. desc: str = "simple_select_one_onecol",
  1168. ) -> Optional[Any]:
  1169. """Executes a SELECT query on the named table, which is expected to
  1170. return a single row, returning a single column from it.
  1171. Args:
  1172. table: string giving the table name
  1173. keyvalues: dict of column names and values to select the row with
  1174. retcol: string giving the name of the column to return
  1175. allow_none: If true, return None instead of failing if the SELECT
  1176. statement returns no rows
  1177. desc: description of the transaction, for logging and metrics
  1178. """
  1179. return await self.runInteraction(
  1180. desc,
  1181. self.simple_select_one_onecol_txn,
  1182. table,
  1183. keyvalues,
  1184. retcol,
  1185. allow_none=allow_none,
  1186. db_autocommit=True,
  1187. )
  1188. @overload
  1189. @classmethod
  1190. def simple_select_one_onecol_txn(
  1191. cls,
  1192. txn: LoggingTransaction,
  1193. table: str,
  1194. keyvalues: Dict[str, Any],
  1195. retcol: str,
  1196. allow_none: Literal[False] = False,
  1197. ) -> Any:
  1198. ...
  1199. @overload
  1200. @classmethod
  1201. def simple_select_one_onecol_txn(
  1202. cls,
  1203. txn: LoggingTransaction,
  1204. table: str,
  1205. keyvalues: Dict[str, Any],
  1206. retcol: str,
  1207. allow_none: Literal[True] = True,
  1208. ) -> Optional[Any]:
  1209. ...
  1210. @classmethod
  1211. def simple_select_one_onecol_txn(
  1212. cls,
  1213. txn: LoggingTransaction,
  1214. table: str,
  1215. keyvalues: Dict[str, Any],
  1216. retcol: str,
  1217. allow_none: bool = False,
  1218. ) -> Optional[Any]:
  1219. ret = cls.simple_select_onecol_txn(
  1220. txn, table=table, keyvalues=keyvalues, retcol=retcol
  1221. )
  1222. if ret:
  1223. return ret[0]
  1224. else:
  1225. if allow_none:
  1226. return None
  1227. else:
  1228. raise StoreError(404, "No row found")
  1229. @staticmethod
  1230. def simple_select_onecol_txn(
  1231. txn: LoggingTransaction,
  1232. table: str,
  1233. keyvalues: Dict[str, Any],
  1234. retcol: str,
  1235. ) -> List[Any]:
  1236. sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
  1237. if keyvalues:
  1238. sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
  1239. txn.execute(sql, list(keyvalues.values()))
  1240. else:
  1241. txn.execute(sql)
  1242. return [r[0] for r in txn]
  1243. async def simple_select_onecol(
  1244. self,
  1245. table: str,
  1246. keyvalues: Optional[Dict[str, Any]],
  1247. retcol: str,
  1248. desc: str = "simple_select_onecol",
  1249. ) -> List[Any]:
  1250. """Executes a SELECT query on the named table, which returns a list
  1251. comprising of the values of the named column from the selected rows.
  1252. Args:
  1253. table: table name
  1254. keyvalues: column names and values to select the rows with
  1255. retcol: column whos value we wish to retrieve.
  1256. desc: description of the transaction, for logging and metrics
  1257. Returns:
  1258. Results in a list
  1259. """
  1260. return await self.runInteraction(
  1261. desc,
  1262. self.simple_select_onecol_txn,
  1263. table,
  1264. keyvalues,
  1265. retcol,
  1266. db_autocommit=True,
  1267. )
  1268. async def simple_select_list(
  1269. self,
  1270. table: str,
  1271. keyvalues: Optional[Dict[str, Any]],
  1272. retcols: Iterable[str],
  1273. desc: str = "simple_select_list",
  1274. ) -> List[Dict[str, Any]]:
  1275. """Executes a SELECT query on the named table, which may return zero or
  1276. more rows, returning the result as a list of dicts.
  1277. Args:
  1278. table: the table name
  1279. keyvalues:
  1280. column names and values to select the rows with, or None to not
  1281. apply a WHERE clause.
  1282. retcols: the names of the columns to return
  1283. desc: description of the transaction, for logging and metrics
  1284. Returns:
  1285. A list of dictionaries.
  1286. """
  1287. return await self.runInteraction(
  1288. desc,
  1289. self.simple_select_list_txn,
  1290. table,
  1291. keyvalues,
  1292. retcols,
  1293. db_autocommit=True,
  1294. )
  1295. @classmethod
  1296. def simple_select_list_txn(
  1297. cls,
  1298. txn: LoggingTransaction,
  1299. table: str,
  1300. keyvalues: Optional[Dict[str, Any]],
  1301. retcols: Iterable[str],
  1302. ) -> List[Dict[str, Any]]:
  1303. """Executes a SELECT query on the named table, which may return zero or
  1304. more rows, returning the result as a list of dicts.
  1305. Args:
  1306. txn: Transaction object
  1307. table: the table name
  1308. keyvalues:
  1309. column names and values to select the rows with, or None to not
  1310. apply a WHERE clause.
  1311. retcols: the names of the columns to return
  1312. """
  1313. if keyvalues:
  1314. sql = "SELECT %s FROM %s WHERE %s" % (
  1315. ", ".join(retcols),
  1316. table,
  1317. " AND ".join("%s = ?" % (k,) for k in keyvalues),
  1318. )
  1319. txn.execute(sql, list(keyvalues.values()))
  1320. else:
  1321. sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
  1322. txn.execute(sql)
  1323. return cls.cursor_to_dict(txn)
  1324. async def simple_select_many_batch(
  1325. self,
  1326. table: str,
  1327. column: str,
  1328. iterable: Iterable[Any],
  1329. retcols: Iterable[str],
  1330. keyvalues: Optional[Dict[str, Any]] = None,
  1331. desc: str = "simple_select_many_batch",
  1332. batch_size: int = 100,
  1333. ) -> List[Any]:
  1334. """Executes a SELECT query on the named table, which may return zero or
  1335. more rows, returning the result as a list of dicts.
  1336. Filters rows by whether the value of `column` is in `iterable`.
  1337. Args:
  1338. table: string giving the table name
  1339. column: column name to test for inclusion against `iterable`
  1340. iterable: list
  1341. retcols: list of strings giving the names of the columns to return
  1342. keyvalues: dict of column names and values to select the rows with
  1343. desc: description of the transaction, for logging and metrics
  1344. batch_size: the number of rows for each select query
  1345. """
  1346. keyvalues = keyvalues or {}
  1347. results = [] # type: List[Dict[str, Any]]
  1348. if not iterable:
  1349. return results
  1350. # iterables can not be sliced, so convert it to a list first
  1351. it_list = list(iterable)
  1352. chunks = [
  1353. it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
  1354. ]
  1355. for chunk in chunks:
  1356. rows = await self.runInteraction(
  1357. desc,
  1358. self.simple_select_many_txn,
  1359. table,
  1360. column,
  1361. chunk,
  1362. keyvalues,
  1363. retcols,
  1364. db_autocommit=True,
  1365. )
  1366. results.extend(rows)
  1367. return results
  1368. @classmethod
  1369. def simple_select_many_txn(
  1370. cls,
  1371. txn: LoggingTransaction,
  1372. table: str,
  1373. column: str,
  1374. iterable: Iterable[Any],
  1375. keyvalues: Dict[str, Any],
  1376. retcols: Iterable[str],
  1377. ) -> List[Dict[str, Any]]:
  1378. """Executes a SELECT query on the named table, which may return zero or
  1379. more rows, returning the result as a list of dicts.
  1380. Filters rows by whether the value of `column` is in `iterable`.
  1381. Args:
  1382. txn: Transaction object
  1383. table: string giving the table name
  1384. column: column name to test for inclusion against `iterable`
  1385. iterable: list
  1386. keyvalues: dict of column names and values to select the rows with
  1387. retcols: list of strings giving the names of the columns to return
  1388. """
  1389. if not iterable:
  1390. return []
  1391. clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
  1392. clauses = [clause]
  1393. for key, value in keyvalues.items():
  1394. clauses.append("%s = ?" % (key,))
  1395. values.append(value)
  1396. sql = "SELECT %s FROM %s WHERE %s" % (
  1397. ", ".join(retcols),
  1398. table,
  1399. " AND ".join(clauses),
  1400. )
  1401. txn.execute(sql, values)
  1402. return cls.cursor_to_dict(txn)
  1403. async def simple_update(
  1404. self,
  1405. table: str,
  1406. keyvalues: Dict[str, Any],
  1407. updatevalues: Dict[str, Any],
  1408. desc: str,
  1409. ) -> int:
  1410. return await self.runInteraction(
  1411. desc, self.simple_update_txn, table, keyvalues, updatevalues
  1412. )
  1413. @staticmethod
  1414. def simple_update_txn(
  1415. txn: LoggingTransaction,
  1416. table: str,
  1417. keyvalues: Dict[str, Any],
  1418. updatevalues: Dict[str, Any],
  1419. ) -> int:
  1420. if keyvalues:
  1421. where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
  1422. else:
  1423. where = ""
  1424. update_sql = "UPDATE %s SET %s %s" % (
  1425. table,
  1426. ", ".join("%s = ?" % (k,) for k in updatevalues),
  1427. where,
  1428. )
  1429. txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
  1430. return txn.rowcount
  1431. async def simple_update_one(
  1432. self,
  1433. table: str,
  1434. keyvalues: Dict[str, Any],
  1435. updatevalues: Dict[str, Any],
  1436. desc: str = "simple_update_one",
  1437. ) -> None:
  1438. """Executes an UPDATE query on the named table, setting new values for
  1439. columns in a row matching the key values.
  1440. Args:
  1441. table: string giving the table name
  1442. keyvalues: dict of column names and values to select the row with
  1443. updatevalues: dict giving column names and values to update
  1444. desc: description of the transaction, for logging and metrics
  1445. """
  1446. await self.runInteraction(
  1447. desc,
  1448. self.simple_update_one_txn,
  1449. table,
  1450. keyvalues,
  1451. updatevalues,
  1452. db_autocommit=True,
  1453. )
  1454. @classmethod
  1455. def simple_update_one_txn(
  1456. cls,
  1457. txn: LoggingTransaction,
  1458. table: str,
  1459. keyvalues: Dict[str, Any],
  1460. updatevalues: Dict[str, Any],
  1461. ) -> None:
  1462. rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
  1463. if rowcount == 0:
  1464. raise StoreError(404, "No row found (%s)" % (table,))
  1465. if rowcount > 1:
  1466. raise StoreError(500, "More than one row matched (%s)" % (table,))
  1467. # Ideally we could use the overload decorator here to specify that the
  1468. # return type is only optional if allow_none is True, but this does not work
  1469. # when you call a static method from an instance.
  1470. # See https://github.com/python/mypy/issues/7781
  1471. @staticmethod
  1472. def simple_select_one_txn(
  1473. txn: LoggingTransaction,
  1474. table: str,
  1475. keyvalues: Dict[str, Any],
  1476. retcols: Iterable[str],
  1477. allow_none: bool = False,
  1478. ) -> Optional[Dict[str, Any]]:
  1479. select_sql = "SELECT %s FROM %s WHERE %s" % (
  1480. ", ".join(retcols),
  1481. table,
  1482. " AND ".join("%s = ?" % (k,) for k in keyvalues),
  1483. )
  1484. txn.execute(select_sql, list(keyvalues.values()))
  1485. row = txn.fetchone()
  1486. if not row:
  1487. if allow_none:
  1488. return None
  1489. raise StoreError(404, "No row found (%s)" % (table,))
  1490. if txn.rowcount > 1:
  1491. raise StoreError(500, "More than one row matched (%s)" % (table,))
  1492. return dict(zip(retcols, row))
  1493. async def simple_delete_one(
  1494. self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
  1495. ) -> None:
  1496. """Executes a DELETE query on the named table, expecting to delete a
  1497. single row.
  1498. Args:
  1499. table: string giving the table name
  1500. keyvalues: dict of column names and values to select the row with
  1501. desc: description of the transaction, for logging and metrics
  1502. """
  1503. await self.runInteraction(
  1504. desc,
  1505. self.simple_delete_one_txn,
  1506. table,
  1507. keyvalues,
  1508. db_autocommit=True,
  1509. )
  1510. @staticmethod
  1511. def simple_delete_one_txn(
  1512. txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
  1513. ) -> None:
  1514. """Executes a DELETE query on the named table, expecting to delete a
  1515. single row.
  1516. Args:
  1517. table: string giving the table name
  1518. keyvalues: dict of column names and values to select the row with
  1519. """
  1520. sql = "DELETE FROM %s WHERE %s" % (
  1521. table,
  1522. " AND ".join("%s = ?" % (k,) for k in keyvalues),
  1523. )
  1524. txn.execute(sql, list(keyvalues.values()))
  1525. if txn.rowcount == 0:
  1526. raise StoreError(404, "No row found (%s)" % (table,))
  1527. if txn.rowcount > 1:
  1528. raise StoreError(500, "More than one row matched (%s)" % (table,))
  1529. async def simple_delete(
  1530. self, table: str, keyvalues: Dict[str, Any], desc: str
  1531. ) -> int:
  1532. """Executes a DELETE query on the named table.
  1533. Filters rows by the key-value pairs.
  1534. Args:
  1535. table: string giving the table name
  1536. keyvalues: dict of column names and values to select the row with
  1537. desc: description of the transaction, for logging and metrics
  1538. Returns:
  1539. The number of deleted rows.
  1540. """
  1541. return await self.runInteraction(
  1542. desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True
  1543. )
  1544. @staticmethod
  1545. def simple_delete_txn(
  1546. txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
  1547. ) -> int:
  1548. """Executes a DELETE query on the named table.
  1549. Filters rows by the key-value pairs.
  1550. Args:
  1551. table: string giving the table name
  1552. keyvalues: dict of column names and values to select the row with
  1553. Returns:
  1554. The number of deleted rows.
  1555. """
  1556. sql = "DELETE FROM %s WHERE %s" % (
  1557. table,
  1558. " AND ".join("%s = ?" % (k,) for k in keyvalues),
  1559. )
  1560. txn.execute(sql, list(keyvalues.values()))
  1561. return txn.rowcount
  1562. async def simple_delete_many(
  1563. self,
  1564. table: str,
  1565. column: str,
  1566. iterable: Iterable[Any],
  1567. keyvalues: Dict[str, Any],
  1568. desc: str,
  1569. ) -> int:
  1570. """Executes a DELETE query on the named table.
  1571. Filters rows by if value of `column` is in `iterable`.
  1572. Args:
  1573. table: string giving the table name
  1574. column: column name to test for inclusion against `iterable`
  1575. iterable: list
  1576. keyvalues: dict of column names and values to select the rows with
  1577. desc: description of the transaction, for logging and metrics
  1578. Returns:
  1579. Number rows deleted
  1580. """
  1581. return await self.runInteraction(
  1582. desc,
  1583. self.simple_delete_many_txn,
  1584. table,
  1585. column,
  1586. iterable,
  1587. keyvalues,
  1588. db_autocommit=True,
  1589. )
  1590. @staticmethod
  1591. def simple_delete_many_txn(
  1592. txn: LoggingTransaction,
  1593. table: str,
  1594. column: str,
  1595. iterable: Iterable[Any],
  1596. keyvalues: Dict[str, Any],
  1597. ) -> int:
  1598. """Executes a DELETE query on the named table.
  1599. Filters rows by if value of `column` is in `iterable`.
  1600. Args:
  1601. txn: Transaction object
  1602. table: string giving the table name
  1603. column: column name to test for inclusion against `iterable`
  1604. iterable: list
  1605. keyvalues: dict of column names and values to select the rows with
  1606. Returns:
  1607. Number rows deleted
  1608. """
  1609. if not iterable:
  1610. return 0
  1611. sql = "DELETE FROM %s" % table
  1612. clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
  1613. clauses = [clause]
  1614. for key, value in keyvalues.items():
  1615. clauses.append("%s = ?" % (key,))
  1616. values.append(value)
  1617. if clauses:
  1618. sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
  1619. txn.execute(sql, values)
  1620. return txn.rowcount
  1621. def get_cache_dict(
  1622. self,
  1623. db_conn: LoggingDatabaseConnection,
  1624. table: str,
  1625. entity_column: str,
  1626. stream_column: str,
  1627. max_value: int,
  1628. limit: int = 100000,
  1629. ) -> Tuple[Dict[Any, int], int]:
  1630. # Fetch a mapping of room_id -> max stream position for "recent" rooms.
  1631. # It doesn't really matter how many we get, the StreamChangeCache will
  1632. # do the right thing to ensure it respects the max size of cache.
  1633. sql = (
  1634. "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
  1635. " WHERE %(stream)s > ? - %(limit)s"
  1636. " GROUP BY %(entity)s"
  1637. ) % {
  1638. "table": table,
  1639. "entity": entity_column,
  1640. "stream": stream_column,
  1641. "limit": limit,
  1642. }
  1643. txn = db_conn.cursor(txn_name="get_cache_dict")
  1644. txn.execute(sql, (int(max_value),))
  1645. cache = {row[0]: int(row[1]) for row in txn}
  1646. txn.close()
  1647. if cache:
  1648. min_val = min(cache.values())
  1649. else:
  1650. min_val = max_value
  1651. return cache, min_val
  1652. @classmethod
  1653. def simple_select_list_paginate_txn(
  1654. cls,
  1655. txn: LoggingTransaction,
  1656. table: str,
  1657. orderby: str,
  1658. start: int,
  1659. limit: int,
  1660. retcols: Iterable[str],
  1661. filters: Optional[Dict[str, Any]] = None,
  1662. keyvalues: Optional[Dict[str, Any]] = None,
  1663. exclude_keyvalues: Optional[Dict[str, Any]] = None,
  1664. order_direction: str = "ASC",
  1665. ) -> List[Dict[str, Any]]:
  1666. """
  1667. Executes a SELECT query on the named table with start and limit,
  1668. of row numbers, which may return zero or number of rows from start to limit,
  1669. returning the result as a list of dicts.
  1670. Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
  1671. select attributes with exact matches. All constraints are joined together
  1672. using 'AND'.
  1673. Args:
  1674. txn: Transaction object
  1675. table: the table name
  1676. orderby: Column to order the results by.
  1677. start: Index to begin the query at.
  1678. limit: Number of results to return.
  1679. retcols: the names of the columns to return
  1680. filters:
  1681. column names and values to filter the rows with, or None to not
  1682. apply a WHERE ? LIKE ? clause.
  1683. keyvalues:
  1684. column names and values to select the rows with, or None to not
  1685. apply a WHERE key = value clause.
  1686. exclude_keyvalues:
  1687. column names and values to exclude rows with, or None to not
  1688. apply a WHERE key != value clause.
  1689. order_direction: Whether the results should be ordered "ASC" or "DESC".
  1690. Returns:
  1691. The result as a list of dictionaries.
  1692. """
  1693. if order_direction not in ["ASC", "DESC"]:
  1694. raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
  1695. where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else ""
  1696. arg_list = [] # type: List[Any]
  1697. if filters:
  1698. where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
  1699. arg_list += list(filters.values())
  1700. where_clause += " AND " if filters and keyvalues else ""
  1701. if keyvalues:
  1702. where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
  1703. arg_list += list(keyvalues.values())
  1704. if exclude_keyvalues:
  1705. where_clause += " AND ".join("%s != ?" % (k,) for k in exclude_keyvalues)
  1706. arg_list += list(exclude_keyvalues.values())
  1707. sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
  1708. ", ".join(retcols),
  1709. table,
  1710. where_clause,
  1711. orderby,
  1712. order_direction,
  1713. )
  1714. txn.execute(sql, arg_list + [limit, start])
  1715. return cls.cursor_to_dict(txn)
  1716. async def simple_search_list(
  1717. self,
  1718. table: str,
  1719. term: Optional[str],
  1720. col: str,
  1721. retcols: Iterable[str],
  1722. desc="simple_search_list",
  1723. ) -> Optional[List[Dict[str, Any]]]:
  1724. """Executes a SELECT query on the named table, which may return zero or
  1725. more rows, returning the result as a list of dicts.
  1726. Args:
  1727. table: the table name
  1728. term: term for searching the table matched to a column.
  1729. col: column to query term should be matched to
  1730. retcols: the names of the columns to return
  1731. Returns:
  1732. A list of dictionaries or None.
  1733. """
  1734. return await self.runInteraction(
  1735. desc,
  1736. self.simple_search_list_txn,
  1737. table,
  1738. term,
  1739. col,
  1740. retcols,
  1741. db_autocommit=True,
  1742. )
  1743. @classmethod
  1744. def simple_search_list_txn(
  1745. cls,
  1746. txn: LoggingTransaction,
  1747. table: str,
  1748. term: Optional[str],
  1749. col: str,
  1750. retcols: Iterable[str],
  1751. ) -> Optional[List[Dict[str, Any]]]:
  1752. """Executes a SELECT query on the named table, which may return zero or
  1753. more rows, returning the result as a list of dicts.
  1754. Args:
  1755. txn: Transaction object
  1756. table: the table name
  1757. term: term for searching the table matched to a column.
  1758. col: column to query term should be matched to
  1759. retcols: the names of the columns to return
  1760. Returns:
  1761. None if no term is given, otherwise a list of dictionaries.
  1762. """
  1763. if term:
  1764. sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
  1765. termvalues = ["%%" + term + "%%"]
  1766. txn.execute(sql, termvalues)
  1767. else:
  1768. return None
  1769. return cls.cursor_to_dict(txn)
  1770. def make_in_list_sql_clause(
  1771. database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
  1772. ) -> Tuple[str, list]:
  1773. """Returns an SQL clause that checks the given column is in the iterable.
  1774. On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
  1775. it expands to `column = ANY(?)`. While both DBs support the `IN` form,
  1776. using the `ANY` form on postgres means that it views queries with
  1777. different length iterables as the same, helping the query stats.
  1778. Args:
  1779. database_engine
  1780. column: Name of the column
  1781. iterable: The values to check the column against.
  1782. Returns:
  1783. A tuple of SQL query and the args
  1784. """
  1785. if database_engine.supports_using_any_list:
  1786. # This should hopefully be faster, but also makes postgres query
  1787. # stats easier to understand.
  1788. return "%s = ANY(?)" % (column,), [list(iterable)]
  1789. else:
  1790. return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
  1791. KV = TypeVar("KV")
  1792. def make_tuple_comparison_clause(keys: List[Tuple[str, KV]]) -> Tuple[str, List[KV]]:
  1793. """Returns a tuple comparison SQL clause
  1794. Builds a SQL clause that looks like "(a, b) > (?, ?)"
  1795. Args:
  1796. keys: A set of (column, value) pairs to be compared.
  1797. Returns:
  1798. A tuple of SQL query and the args
  1799. """
  1800. return (
  1801. "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
  1802. [k[1] for k in keys],
  1803. )