You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

225 lines
8.9 KiB

  1. # Copyright 2015, 2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast
  16. import psycopg2.extensions
  17. from synapse.storage.engines._base import (
  18. BaseDatabaseEngine,
  19. IncorrectDatabaseSetup,
  20. IsolationLevel,
  21. )
  22. from synapse.storage.types import Cursor
  23. if TYPE_CHECKING:
  24. from synapse.storage.database import LoggingDatabaseConnection
  25. logger = logging.getLogger(__name__)
  26. class PostgresEngine(
  27. BaseDatabaseEngine[psycopg2.extensions.connection, psycopg2.extensions.cursor]
  28. ):
  29. def __init__(self, database_config: Mapping[str, Any]):
  30. super().__init__(psycopg2, database_config)
  31. psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
  32. # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
  33. # actually want to use bytes than wrap it in `bytearray`.
  34. def _disable_bytes_adapter(_: bytes) -> NoReturn:
  35. raise Exception("Passing bytes to DB is disabled.")
  36. psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
  37. self.synchronous_commit: bool = database_config.get("synchronous_commit", True)
  38. self._version: Optional[int] = None # unknown as yet
  39. self.isolation_level_map: Mapping[int, int] = {
  40. IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
  41. IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
  42. IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
  43. }
  44. self.default_isolation_level = (
  45. psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
  46. )
  47. self.config = database_config
  48. @property
  49. def single_threaded(self) -> bool:
  50. return False
  51. def get_db_locale(self, txn: Cursor) -> Tuple[str, str]:
  52. txn.execute(
  53. "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
  54. )
  55. collation, ctype = cast(Tuple[str, str], txn.fetchone())
  56. return collation, ctype
  57. def check_database(
  58. self,
  59. db_conn: psycopg2.extensions.connection,
  60. allow_outdated_version: bool = False,
  61. ) -> None:
  62. # Get the version of PostgreSQL that we're using. As per the psycopg2
  63. # docs: The number is formed by converting the major, minor, and
  64. # revision numbers into two-decimal-digit numbers and appending them
  65. # together. For example, version 8.1.5 will be returned as 80105
  66. self._version = cast(int, db_conn.server_version)
  67. allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
  68. # Are we on a supported PostgreSQL version?
  69. if not allow_outdated_version and self._version < 100000:
  70. raise RuntimeError("Synapse requires PostgreSQL 10 or above.")
  71. with db_conn.cursor() as txn:
  72. txn.execute("SHOW SERVER_ENCODING")
  73. rows = txn.fetchall()
  74. if rows and rows[0][0] != "UTF8":
  75. raise IncorrectDatabaseSetup(
  76. "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
  77. "See docs/postgres.md for more information." % (rows[0][0],)
  78. )
  79. collation, ctype = self.get_db_locale(txn)
  80. if collation != "C":
  81. logger.warning(
  82. "Database has incorrect collation of %r. Should be 'C'",
  83. collation,
  84. )
  85. if not allow_unsafe_locale:
  86. raise IncorrectDatabaseSetup(
  87. "Database has incorrect collation of %r. Should be 'C'\n"
  88. "See docs/postgres.md for more information. You can override this check by"
  89. "setting 'allow_unsafe_locale' to true in the database config.",
  90. collation,
  91. )
  92. if ctype != "C":
  93. if not allow_unsafe_locale:
  94. logger.warning(
  95. "Database has incorrect ctype of %r. Should be 'C'",
  96. ctype,
  97. )
  98. raise IncorrectDatabaseSetup(
  99. "Database has incorrect ctype of %r. Should be 'C'\n"
  100. "See docs/postgres.md for more information. You can override this check by"
  101. "setting 'allow_unsafe_locale' to true in the database config.",
  102. ctype,
  103. )
  104. def check_new_database(self, txn: Cursor) -> None:
  105. """Gets called when setting up a brand new database. This allows us to
  106. apply stricter checks on new databases versus existing database.
  107. """
  108. collation, ctype = self.get_db_locale(txn)
  109. errors = []
  110. if collation != "C":
  111. errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
  112. if ctype != "C":
  113. errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (ctype,))
  114. if errors:
  115. raise IncorrectDatabaseSetup(
  116. "Database is incorrectly configured:\n\n%s\n\n"
  117. "See docs/postgres.md for more information." % ("\n".join(errors))
  118. )
  119. def convert_param_style(self, sql: str) -> str:
  120. return sql.replace("?", "%s")
  121. def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
  122. db_conn.set_isolation_level(self.default_isolation_level)
  123. # Set the bytea output to escape, vs the default of hex
  124. cursor = db_conn.cursor()
  125. cursor.execute("SET bytea_output TO escape")
  126. # Asynchronous commit, don't wait for the server to call fsync before
  127. # ending the transaction.
  128. # https://www.postgresql.org/docs/current/static/wal-async-commit.html
  129. if not self.synchronous_commit:
  130. cursor.execute("SET synchronous_commit TO OFF")
  131. cursor.close()
  132. db_conn.commit()
  133. @property
  134. def supports_using_any_list(self) -> bool:
  135. """Do we support using `a = ANY(?)` and passing a list"""
  136. return True
  137. @property
  138. def supports_returning(self) -> bool:
  139. """Do we support the `RETURNING` clause in insert/update/delete?"""
  140. return True
  141. def is_deadlock(self, error: Exception) -> bool:
  142. if isinstance(error, psycopg2.DatabaseError):
  143. # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
  144. # "40001" serialization_failure
  145. # "40P01" deadlock_detected
  146. return error.pgcode in ["40001", "40P01"]
  147. return False
  148. def is_connection_closed(self, conn: psycopg2.extensions.connection) -> bool:
  149. return bool(conn.closed)
  150. def lock_table(self, txn: Cursor, table: str) -> None:
  151. txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
  152. @property
  153. def server_version(self) -> str:
  154. """Returns a string giving the server version. For example: '8.1.5'."""
  155. # note that this is a bit of a hack because it relies on check_database
  156. # having been called. Still, that should be a safe bet here.
  157. numver = self._version
  158. assert numver is not None
  159. # https://www.postgresql.org/docs/current/libpq-status.html#LIBPQ-PQSERVERVERSION
  160. if numver >= 100000:
  161. return "%i.%i" % (numver / 10000, numver % 10000)
  162. else:
  163. return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
  164. def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
  165. return conn.status != psycopg2.extensions.STATUS_READY
  166. def attempt_to_set_autocommit(
  167. self, conn: psycopg2.extensions.connection, autocommit: bool
  168. ) -> None:
  169. return conn.set_session(autocommit=autocommit)
  170. def attempt_to_set_isolation_level(
  171. self, conn: psycopg2.extensions.connection, isolation_level: Optional[int]
  172. ) -> None:
  173. if isolation_level is None:
  174. isolation_level = self.default_isolation_level
  175. else:
  176. isolation_level = self.isolation_level_map[isolation_level]
  177. return conn.set_isolation_level(isolation_level)
  178. @staticmethod
  179. def executescript(cursor: psycopg2.extensions.cursor, script: str) -> None:
  180. """Execute a chunk of SQL containing multiple semicolon-delimited statements.
  181. Psycopg2 seems happy to do this in DBAPI2's `execute()` function.
  182. """
  183. cursor.execute(script)