You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

81 lines
3.0 KiB

  1. # Copyright 2015, 2016 OpenMarket Ltd
  2. # Copyright 2021 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Optional, Tuple, Union, cast
  16. from canonicaljson import encode_canonical_json
  17. from synapse.api.errors import Codes, SynapseError
  18. from synapse.storage._base import SQLBaseStore, db_to_json
  19. from synapse.storage.database import LoggingTransaction
  20. from synapse.types import JsonDict
  21. from synapse.util.caches.descriptors import cached
  22. class FilteringStore(SQLBaseStore):
  23. @cached(num_args=2)
  24. async def get_user_filter(
  25. self, user_localpart: str, filter_id: Union[int, str]
  26. ) -> JsonDict:
  27. # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
  28. # with a coherent error message rather than 500 M_UNKNOWN.
  29. try:
  30. int(filter_id)
  31. except ValueError:
  32. raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
  33. def_json = await self.db_pool.simple_select_one_onecol(
  34. table="user_filters",
  35. keyvalues={"user_id": user_localpart, "filter_id": filter_id},
  36. retcol="filter_json",
  37. allow_none=False,
  38. desc="get_user_filter",
  39. )
  40. return db_to_json(def_json)
  41. async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
  42. def_json = encode_canonical_json(user_filter)
  43. # Need an atomic transaction to SELECT the maximal ID so far then
  44. # INSERT a new one
  45. def _do_txn(txn: LoggingTransaction) -> int:
  46. sql = (
  47. "SELECT filter_id FROM user_filters "
  48. "WHERE user_id = ? AND filter_json = ?"
  49. )
  50. txn.execute(sql, (user_localpart, bytearray(def_json)))
  51. filter_id_response = txn.fetchone()
  52. if filter_id_response is not None:
  53. return filter_id_response[0]
  54. sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
  55. txn.execute(sql, (user_localpart,))
  56. max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
  57. if max_id is None:
  58. filter_id = 0
  59. else:
  60. filter_id = max_id + 1
  61. sql = (
  62. "INSERT INTO user_filters (user_id, filter_id, filter_json)"
  63. "VALUES(?, ?, ?)"
  64. )
  65. txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
  66. return filter_id
  67. return await self.db_pool.runInteraction("add_user_filter", _do_txn)