@@ -0,0 +1 @@ | |||
Add checks on startup that PostgreSQL sequences are consistent with their associated tables. |
@@ -106,6 +106,17 @@ Note that the above may fail with an error about duplicate rows if corruption | |||
has already occurred, and such duplicate rows will need to be manually removed. | |||
## Fixing inconsistent sequences error | |||
Synapse uses Postgres sequences to generate IDs for various tables. A sequence | |||
and associated table can get out of sync if, for example, Synapse has been | |||
downgraded and then upgraded again. | |||
To fix the issue shut down Synapse (including any and all workers) and run the | |||
SQL command included in the error message. Once done Synapse should start | |||
successfully. | |||
## Tuning Postgres | |||
The default settings should be fine for most deployments. For larger | |||
@@ -41,6 +41,9 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
self.config = hs.config | |||
self.clock = hs.get_clock() | |||
# Note: we don't check this sequence for consistency as we'd have to | |||
# call `find_max_generated_user_id_localpart` each time, which is | |||
# expensive if there are many entries. | |||
self._user_id_seq = build_sequence_generator( | |||
database.engine, find_max_generated_user_id_localpart, "user_id_seq", | |||
) | |||
@@ -99,6 +99,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): | |||
self._state_group_seq_gen = build_sequence_generator( | |||
self.database_engine, get_max_state_group_txn, "state_group_id_seq" | |||
) | |||
self._state_group_seq_gen.check_consistency( | |||
db_conn, table="state_groups", id_column="id" | |||
) | |||
@cached(max_entries=10000, iterable=True) | |||
async def get_state_group_delta(self, state_group): | |||
@@ -258,6 +258,11 @@ class MultiWriterIdGenerator: | |||
self._sequence_gen = PostgresSequenceGenerator(sequence_name) | |||
# We check that the table and sequence haven't diverged. | |||
self._sequence_gen.check_consistency( | |||
db_conn, table=table, id_column=id_column, positive=positive | |||
) | |||
# This goes and fills out the above state from the database. | |||
self._load_current_ids(db_conn, table, instance_column, id_column) | |||
@@ -13,11 +13,34 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import abc | |||
import logging | |||
import threading | |||
from typing import Callable, List, Optional | |||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine | |||
from synapse.storage.types import Cursor | |||
from synapse.storage.engines import ( | |||
BaseDatabaseEngine, | |||
IncorrectDatabaseSetup, | |||
PostgresEngine, | |||
) | |||
from synapse.storage.types import Connection, Cursor | |||
logger = logging.getLogger(__name__) | |||
_INCONSISTENT_SEQUENCE_ERROR = """ | |||
Postgres sequence '%(seq)s' is inconsistent with associated | |||
table '%(table)s'. This can happen if Synapse has been downgraded and | |||
then upgraded again, or due to a bad migration. | |||
To fix this error, shut down Synapse (including any and all workers) | |||
and run the following SQL: | |||
SELECT setval('%(seq)s', ( | |||
%(max_id_sql)s | |||
)); | |||
See docs/postgres.md for more information. | |||
""" | |||
class SequenceGenerator(metaclass=abc.ABCMeta): | |||
@@ -28,6 +51,19 @@ class SequenceGenerator(metaclass=abc.ABCMeta): | |||
"""Gets the next ID in the sequence""" | |||
... | |||
@abc.abstractmethod | |||
def check_consistency( | |||
self, db_conn: Connection, table: str, id_column: str, positive: bool = True | |||
): | |||
"""Should be called during start up to test that the current value of | |||
the sequence is greater than or equal to the maximum ID in the table. | |||
This is to handle various cases where the sequence value can get out | |||
of sync with the table, e.g. if Synapse gets rolled back to a previous | |||
version and the rolled forwards again. | |||
""" | |||
... | |||
class PostgresSequenceGenerator(SequenceGenerator): | |||
"""An implementation of SequenceGenerator which uses a postgres sequence""" | |||
@@ -45,6 +81,50 @@ class PostgresSequenceGenerator(SequenceGenerator): | |||
) | |||
return [i for (i,) in txn] | |||
def check_consistency( | |||
self, db_conn: Connection, table: str, id_column: str, positive: bool = True | |||
): | |||
txn = db_conn.cursor() | |||
# First we get the current max ID from the table. | |||
table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % { | |||
"id": id_column, | |||
"table": table, | |||
"agg": "MAX" if positive else "-MIN", | |||
} | |||
txn.execute(table_sql) | |||
row = txn.fetchone() | |||
if not row: | |||
# Table is empty, so nothing to do. | |||
txn.close() | |||
return | |||
# Now we fetch the current value from the sequence and compare with the | |||
# above. | |||
max_stream_id = row[0] | |||
txn.execute( | |||
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} | |||
) | |||
last_value, is_called = txn.fetchone() | |||
txn.close() | |||
# If `is_called` is False then `last_value` is actually the value that | |||
# will be generated next, so we decrement to get the true "last value". | |||
if not is_called: | |||
last_value -= 1 | |||
if max_stream_id > last_value: | |||
logger.warning( | |||
"Postgres sequence %s is behind table %s: %d < %d", | |||
last_value, | |||
max_stream_id, | |||
) | |||
raise IncorrectDatabaseSetup( | |||
_INCONSISTENT_SEQUENCE_ERROR | |||
% {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql} | |||
) | |||
GetFirstCallbackType = Callable[[Cursor], int] | |||
@@ -81,6 +161,12 @@ class LocalSequenceGenerator(SequenceGenerator): | |||
self._current_max_id += 1 | |||
return self._current_max_id | |||
def check_consistency( | |||
self, db_conn: Connection, table: str, id_column: str, positive: bool = True | |||
): | |||
# There is nothing to do for in memory sequences | |||
pass | |||
def build_sequence_generator( | |||
database_engine: BaseDatabaseEngine, | |||
@@ -12,9 +12,8 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.engines import IncorrectDatabaseSetup | |||
from synapse.storage.util.id_generators import MultiWriterIdGenerator | |||
from tests.unittest import HomeserverTestCase | |||
@@ -59,7 +58,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): | |||
writers=writers, | |||
) | |||
return self.get_success(self.db_pool.runWithConnection(_create)) | |||
return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) | |||
def _insert_rows(self, instance_name: str, number: int): | |||
"""Insert N rows as the given instance, inserting with stream IDs pulled | |||
@@ -411,6 +410,23 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): | |||
self.get_success(_get_next_async()) | |||
self.assertEqual(id_gen_3.get_persisted_upto_position(), 6) | |||
def test_sequence_consistency(self): | |||
"""Test that we error out if the table and sequence diverges. | |||
""" | |||
# Prefill with some rows | |||
self._insert_row_with_id("master", 3) | |||
# Now we add a row *without* updating the stream ID | |||
def _insert(txn): | |||
txn.execute("INSERT INTO foobar VALUES (26, 'master')") | |||
self.get_success(self.db_pool.runInteraction("_insert", _insert)) | |||
# Creating the ID gen should error | |||
with self.assertRaises(IncorrectDatabaseSetup): | |||
self._create_id_generator("first") | |||
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): | |||
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs. | |||
@@ -14,7 +14,6 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import gc | |||
import hashlib | |||
import hmac | |||
@@ -28,6 +27,7 @@ from mock import Mock, patch | |||
from canonicaljson import json | |||
from twisted.internet.defer import Deferred, ensureDeferred, succeed | |||
from twisted.python.failure import Failure | |||
from twisted.python.threadpool import ThreadPool | |||
from twisted.trial import unittest | |||
@@ -476,6 +476,35 @@ class HomeserverTestCase(TestCase): | |||
self.pump() | |||
return self.failureResultOf(d, exc) | |||
def get_success_or_raise(self, d, by=0.0): | |||
"""Drive deferred to completion and return result or raise exception | |||
on failure. | |||
""" | |||
if inspect.isawaitable(d): | |||
deferred = ensureDeferred(d) | |||
if not isinstance(deferred, Deferred): | |||
return d | |||
results = [] # type: list | |||
deferred.addBoth(results.append) | |||
self.pump(by=by) | |||
if not results: | |||
self.fail( | |||
"Success result expected on {!r}, found no result instead".format( | |||
deferred | |||
) | |||
) | |||
result = results[0] | |||
if isinstance(result, Failure): | |||
result.raiseException() | |||
return result | |||
def register_user(self, username, password, admin=False): | |||
""" | |||
Register a user. Requires the Admin API be registered. | |||