# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import logging
import re
from collections import deque
from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    Any,
    Collection,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    Union,
    cast,
)

import attr

from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@attr.s(slots=True, frozen=True, auto_attribs=True)
class SearchEntry:
    key: str
    value: str
    event_id: str
    room_id: str
    stream_ordering: Optional[int]
    origin_server_ts: int


def _clean_value_for_search(value: str) -> str:
    """
    Replaces any null code points in the string with spaces as
    Postgres and SQLite do not like the insertion of strings with
    null code points into the full-text search tables.
    """
    return value.replace("\u0000", " ")
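
# A doctest-style sketch of the cleaning above (illustrative only):
#
#   >>> _clean_value_for_search("abc\u0000def")
#   'abc def'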


class SearchWorkerStore(SQLBaseStore):
    def store_search_entries_txn(
        self, txn: LoggingTransaction, entries: Iterable[SearchEntry]
    ) -> None:
        """Add entries to the search table

        Args:
            txn:
            entries: entries to be added to the table
        """
        if not self.hs.config.server.enable_search:
            return
        if isinstance(self.database_engine, PostgresEngine):
            sql = """
            INSERT INTO event_search
            (event_id, room_id, key, vector, stream_ordering, origin_server_ts)
            VALUES (?,?,?,to_tsvector('english', ?),?,?)
            """

            args1 = (
                (
                    entry.event_id,
                    entry.room_id,
                    entry.key,
                    _clean_value_for_search(entry.value),
                    entry.stream_ordering,
                    entry.origin_server_ts,
                )
                for entry in entries
            )

            txn.execute_batch(sql, args1)
        elif isinstance(self.database_engine, Sqlite3Engine):
            self.db_pool.simple_insert_many_txn(
                txn,
                table="event_search",
                keys=("event_id", "room_id", "key", "value"),
                values=[
                    (
                        entry.event_id,
                        entry.room_id,
                        entry.key,
                        _clean_value_for_search(entry.value),
                    )
                    for entry in entries
                ],
            )
        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")
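
# A minimal usage sketch (hypothetical identifiers and values) of the entries
# passed to store_search_entries_txn; `store` is assumed to be a
# SearchWorkerStore and `txn` a LoggingTransaction supplied by the pool:
#
#   entries = [
#       SearchEntry(
#           key="content.body",
#           value="hello world",
#           event_id="$abc123:example.com",
#           room_id="!room:example.com",
#           stream_ordering=1234,
#           origin_server_ts=1662000000000,
#       )
#   ]
#   store.store_search_entries_txn(txn, entries)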


class SearchBackgroundUpdateStore(SearchWorkerStore):
    EVENT_SEARCH_UPDATE_NAME = "event_search"
    EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
    EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
    EVENT_SEARCH_DELETE_NON_STRINGS = "event_search_sqlite_delete_non_strings"

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_update_handler(
            self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
        )
        self.db_pool.updates.register_background_update_handler(
            self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
        )
        self.db_pool.updates.register_background_update_handler(
            self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
        )
        self.db_pool.updates.register_background_update_handler(
            self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
        )

    async def _background_reindex_search(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        # we work through the events table from highest stream id to lowest
        target_min_stream_id = progress["target_min_stream_id_inclusive"]
        max_stream_id = progress["max_stream_id_exclusive"]
        rows_inserted = progress.get("rows_inserted", 0)

        TYPES = ["m.room.name", "m.room.message", "m.room.topic"]

        def reindex_search_txn(txn: LoggingTransaction) -> int:
            sql = """
            SELECT stream_ordering, event_id, room_id, type, json, origin_server_ts
            FROM events
            JOIN event_json USING (room_id, event_id)
            WHERE ? <= stream_ordering AND stream_ordering < ?
            AND (%s)
            ORDER BY stream_ordering DESC
            LIMIT ?
            """ % (
                " OR ".join("type = '%s'" % (t,) for t in TYPES),
            )

            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

            # we could stream straight from the results into
            # store_search_entries_txn with a generator function, but that
            # would mean having two cursors open on the database at once.
            # Instead we just build a list of results.
            rows = txn.fetchall()
            if not rows:
                return 0

            min_stream_id = rows[-1][0]

            event_search_rows = []
            for (
                stream_ordering,
                event_id,
                room_id,
                etype,
                json,
                origin_server_ts,
            ) in rows:
                try:
                    try:
                        event_json = db_to_json(json)
                        content = event_json["content"]
                    except Exception:
                        continue

                    if etype == "m.room.message":
                        key = "content.body"
                        value = content["body"]
                    elif etype == "m.room.topic":
                        key = "content.topic"
                        value = content["topic"]
                    elif etype == "m.room.name":
                        key = "content.name"
                        value = content["name"]
                    else:
                        raise Exception("unexpected event type %s" % etype)
                except (KeyError, AttributeError):
                    # If the event is missing a necessary field then
                    # skip over it.
                    continue

                if not isinstance(value, str):
                    # If the event body, name or topic isn't a string
                    # then skip over it
                    continue

                event_search_rows.append(
                    SearchEntry(
                        key=key,
                        value=value,
                        event_id=event_id,
                        room_id=room_id,
                        stream_ordering=stream_ordering,
                        origin_server_ts=origin_server_ts,
                    )
                )

            self.store_search_entries_txn(txn, event_search_rows)

            progress = {
                "target_min_stream_id_inclusive": target_min_stream_id,
                "max_stream_id_exclusive": min_stream_id,
                "rows_inserted": rows_inserted + len(event_search_rows),
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, self.EVENT_SEARCH_UPDATE_NAME, progress
            )

            return len(event_search_rows)

        if self.hs.config.server.enable_search:
            result = await self.db_pool.runInteraction(
                self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
            )
        else:
            # Don't index anything if search is not enabled.
            result = 0

        if not result:
            await self.db_pool.updates._end_background_update(
                self.EVENT_SEARCH_UPDATE_NAME
            )

        return result

    async def _background_reindex_gin_search(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """This handles old synapses which used GIST indexes, if any;
        converting them back to be GIN as per the actual schema.
        """

        def create_index(conn: LoggingDatabaseConnection) -> None:
            conn.rollback()

            # we have to set autocommit, because postgres refuses to
            # CREATE INDEX CONCURRENTLY without it.
            conn.engine.attempt_to_set_autocommit(conn.conn, True)

            try:
                c = conn.cursor()

                # if we skipped the conversion to GIST, we may already/still
                # have an event_search_fts_idx; unfortunately postgres 9.4
                # doesn't support CREATE INDEX IF EXISTS so we just catch the
                # exception and ignore it.
                import psycopg2

                try:
                    c.execute(
                        """
                        CREATE INDEX CONCURRENTLY event_search_fts_idx
                        ON event_search USING GIN (vector)
                        """
                    )
                except psycopg2.ProgrammingError as e:
                    logger.warning(
                        "Ignoring error %r when trying to switch from GIST to GIN", e
                    )

                # we should now be able to delete the GIST index.
                c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist")
            finally:
                conn.engine.attempt_to_set_autocommit(conn.conn, False)

        if isinstance(self.database_engine, PostgresEngine):
            await self.db_pool.runWithConnection(create_index)

        await self.db_pool.updates._end_background_update(
            self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
        )
        return 1

    async def _background_reindex_search_order(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        target_min_stream_id = progress["target_min_stream_id_inclusive"]
        max_stream_id = progress["max_stream_id_exclusive"]
        rows_inserted = progress.get("rows_inserted", 0)
        have_added_index = progress["have_added_indexes"]

        if not have_added_index:

            def create_index(conn: LoggingDatabaseConnection) -> None:
                conn.rollback()
                conn.engine.attempt_to_set_autocommit(conn.conn, True)
                c = conn.cursor()

                # We create with NULLS FIRST so that when we search *backwards*
                # we get the ones with non null origin_server_ts *first*
                c.execute(
                    """
                    CREATE INDEX CONCURRENTLY event_search_room_order
                    ON event_search(room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
                    """
                )
                c.execute(
                    """
                    CREATE INDEX CONCURRENTLY event_search_order
                    ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
                    """
                )
                conn.engine.attempt_to_set_autocommit(conn.conn, False)

            await self.db_pool.runWithConnection(create_index)

            pg = dict(progress)
            pg["have_added_indexes"] = True

            await self.db_pool.runInteraction(
                self.EVENT_SEARCH_ORDER_UPDATE_NAME,
                self.db_pool.updates._background_update_progress_txn,
                self.EVENT_SEARCH_ORDER_UPDATE_NAME,
                pg,
            )

        def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
            sql = """
            UPDATE event_search AS es
            SET stream_ordering = e.stream_ordering, origin_server_ts = e.origin_server_ts
            FROM events AS e
            WHERE e.event_id = es.event_id
            AND ? <= e.stream_ordering AND e.stream_ordering < ?
            RETURNING es.stream_ordering
            """

            min_stream_id = max_stream_id - batch_size
            txn.execute(sql, (min_stream_id, max_stream_id))
            rows = txn.fetchall()

            if min_stream_id < target_min_stream_id:
                # We've reached the end.
                return len(rows), False

            progress = {
                "target_min_stream_id_inclusive": target_min_stream_id,
                "max_stream_id_exclusive": min_stream_id,
                "rows_inserted": rows_inserted + len(rows),
                "have_added_indexes": True,
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
            )

            return len(rows), True

        num_rows, finished = await self.db_pool.runInteraction(
            self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
        )

        if not finished:
            await self.db_pool.updates._end_background_update(
                self.EVENT_SEARCH_ORDER_UPDATE_NAME
            )

        return num_rows

    async def _background_delete_non_strings(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Deletes rows with non-string `value`s from `event_search` if using sqlite.

        Prior to Synapse 1.44.0, malformed events received over federation could
        cause integers to be inserted into the `event_search` table when using
        sqlite.
        """

        def delete_non_strings_txn(txn: LoggingTransaction) -> None:
            txn.execute("DELETE FROM event_search WHERE typeof(value) != 'text'")

        await self.db_pool.runInteraction(
            self.EVENT_SEARCH_DELETE_NON_STRINGS, delete_non_strings_txn
        )

        await self.db_pool.updates._end_background_update(
            self.EVENT_SEARCH_DELETE_NON_STRINGS
        )
        return 1


class SearchStore(SearchBackgroundUpdateStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

    async def search_msgs(
        self, room_ids: Collection[str], search_term: str, keys: Iterable[str]
    ) -> JsonDict:
        """Performs a full text search over events with given keys.

        Args:
            room_ids: List of room ids to search in
            search_term: Search term to search for
            keys: List of keys to search in, currently supports
                "content.body", "content.name", "content.topic"

        Returns:
            A dictionary with the keys "results" (a list of {"event", "rank"}
            dicts), "highlights" (a set of matching words on Postgres,
            otherwise None) and "count" (the total number of matches).
        """
        clauses = []

        args: List[Any] = []

        # Make sure we don't explode because the person is in too many rooms.
        # We filter the results below regardless.
        if len(room_ids) < 500:
            clause, args = make_in_list_sql_clause(
                self.database_engine, "room_id", room_ids
            )
            clauses = [clause]

        local_clauses = []
        for key in keys:
            local_clauses.append("key = ?")
            args.append(key)

        clauses.append("(%s)" % (" OR ".join(local_clauses),))

        count_args = args
        count_clauses = clauses

        if isinstance(self.database_engine, PostgresEngine):
            search_query = search_term
            sql = """
            SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) AS rank,
            room_id, event_id
            FROM event_search
            WHERE vector @@ websearch_to_tsquery('english', ?)
            """
            args = [search_query, search_query] + args

            count_sql = """
            SELECT room_id, count(*) as count FROM event_search
            WHERE vector @@ websearch_to_tsquery('english', ?)
            """
            count_args = [search_query] + count_args
        elif isinstance(self.database_engine, Sqlite3Engine):
            search_query = _parse_query_for_sqlite(search_term)

            sql = """
            SELECT rank(matchinfo(event_search)) as rank, room_id, event_id
            FROM event_search
            WHERE value MATCH ?
            """
            args = [search_query] + args

            count_sql = """
            SELECT room_id, count(*) as count FROM event_search
            WHERE value MATCH ?
            """
            count_args = [search_query] + count_args
        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")

        for clause in clauses:
            sql += " AND " + clause

        for clause in count_clauses:
            count_sql += " AND " + clause

        # We add an arbitrary limit here to ensure we don't try to pull the
        # entire table from the database.
        sql += " ORDER BY rank DESC LIMIT 500"

        # List of tuples of (rank, room_id, event_id).
        results = cast(
            List[Tuple[Union[int, float], str, str]],
            await self.db_pool.execute("search_msgs", sql, *args),
        )

        results = list(filter(lambda row: row[1] in room_ids, results))

        # We set redact_behaviour to block here to prevent redacted events being
        # returned in search results (which is a data leak)
        events = await self.get_events_as_list(  # type: ignore[attr-defined]
            [r[2] for r in results],
            redact_behaviour=EventRedactBehaviour.block,
        )

        event_map = {ev.event_id: ev for ev in events}

        highlights = None
        if isinstance(self.database_engine, PostgresEngine):
            highlights = await self._find_highlights_in_postgres(search_query, events)

        count_sql += " GROUP BY room_id"

        # List of tuples of (room_id, count).
        count_results = cast(
            List[Tuple[str, int]],
            await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
        )

        count = sum(row[1] for row in count_results if row[0] in room_ids)

        return {
            "results": [
                {"event": event_map[r[2]], "rank": r[0]}
                for r in results
                if r[2] in event_map
            ],
            "highlights": highlights,
            "count": count,
        }

    async def search_rooms(
        self,
        room_ids: Collection[str],
        search_term: str,
        keys: Iterable[str],
        limit: int,
        pagination_token: Optional[str] = None,
    ) -> JsonDict:
        """Performs a full text search over events with given keys.

        Args:
            room_ids: The room_ids to search in
            search_term: Search term to search for
            keys: List of keys to search in, currently supports "content.body",
                "content.name", "content.topic"
            limit: The maximum number of results to return
            pagination_token: A pagination token previously returned

        Returns:
            Each match as a dictionary, including a "pagination_token" for
            requesting the next page of results.
        """
        clauses = []

        args: List[Any] = []

        # Make sure we don't explode because the person is in too many rooms.
        # We filter the results below regardless.
        if len(room_ids) < 500:
            clause, args = make_in_list_sql_clause(
                self.database_engine, "room_id", room_ids
            )
            clauses = [clause]

        local_clauses = []
        for key in keys:
            local_clauses.append("key = ?")
            args.append(key)

        clauses.append("(%s)" % (" OR ".join(local_clauses),))

        # take copies of the current args and clauses lists, before adding
        # pagination clauses to main query.
        count_args = list(args)
        count_clauses = list(clauses)
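
        # A pagination token has the form "origin_server_ts,stream_ordering",
        # as produced below when building each result's "pagination_token"
        # (e.g. "1662000000000,12345"; values illustrative).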
        if pagination_token:
            try:
                origin_server_ts_str, stream_str = pagination_token.split(",")
                origin_server_ts = int(origin_server_ts_str)
                stream = int(stream_str)
            except Exception:
                raise SynapseError(400, "Invalid pagination token")

            clauses.append(
                """
                (origin_server_ts < ? OR (origin_server_ts = ? AND stream_ordering < ?))
                """
            )
            args.extend([origin_server_ts, origin_server_ts, stream])

        if isinstance(self.database_engine, PostgresEngine):
            search_query = search_term
            sql = """
            SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
            room_id, event_id, origin_server_ts, stream_ordering
            FROM event_search
            WHERE vector @@ websearch_to_tsquery('english', ?) AND
            """
            args = [search_query, search_query] + args

            count_sql = """
            SELECT room_id, count(*) as count FROM event_search
            WHERE vector @@ websearch_to_tsquery('english', ?) AND
            """
            count_args = [search_query] + count_args
        elif isinstance(self.database_engine, Sqlite3Engine):
            # We use CROSS JOIN here to ensure we use the right indexes.
            # https://sqlite.org/optoverview.html#crossjoin
            #
            # We want to use the full text search index on event_search to
            # extract all possible matches first, then lookup those matches
            # in the events table to get the topological ordering. We need
            # to use the indexes in this order because sqlite refuses to
            # MATCH unless it uses the full text search index
            sql = """
            SELECT
                rank(matchinfo) as rank, room_id, event_id, origin_server_ts, stream_ordering
            FROM (
                SELECT key, event_id, matchinfo(event_search) as matchinfo
                FROM event_search
                WHERE value MATCH ?
            )
            CROSS JOIN events USING (event_id)
            WHERE
            """
            search_query = _parse_query_for_sqlite(search_term)
            args = [search_query] + args

            count_sql = """
            SELECT room_id, count(*) as count FROM event_search
            WHERE value MATCH ? AND
            """
            count_args = [search_query] + count_args
        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")

        sql += " AND ".join(clauses)
        count_sql += " AND ".join(count_clauses)

        # We add an arbitrary limit here to ensure we don't try to pull the
        # entire table from the database.
        if isinstance(self.database_engine, PostgresEngine):
            sql += """
            ORDER BY origin_server_ts DESC NULLS LAST, stream_ordering DESC NULLS LAST
            LIMIT ?
            """
        elif isinstance(self.database_engine, Sqlite3Engine):
            sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
        else:
            raise Exception("Unrecognized database engine")

        # mypy expects to append only a `str`, not an `int`
        args.append(limit)

        # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
        results = cast(
            List[Tuple[Union[int, float], str, str, int, int]],
            await self.db_pool.execute("search_rooms", sql, *args),
        )

        results = list(filter(lambda row: row[1] in room_ids, results))

        # We set redact_behaviour to block here to prevent redacted events being
        # returned in search results (which is a data leak)
        events = await self.get_events_as_list(  # type: ignore[attr-defined]
            [r[2] for r in results],
            redact_behaviour=EventRedactBehaviour.block,
        )

        event_map = {ev.event_id: ev for ev in events}

        highlights = None
        if isinstance(self.database_engine, PostgresEngine):
            highlights = await self._find_highlights_in_postgres(search_query, events)

        count_sql += " GROUP BY room_id"

        # List of tuples of (room_id, count).
        count_results = cast(
            List[Tuple[str, int]],
            await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
        )

        count = sum(row[1] for row in count_results if row[0] in room_ids)

        return {
            "results": [
                {
                    "event": event_map[r[2]],
                    "rank": r[0],
                    "pagination_token": "%s,%s" % (r[3], r[4]),
                }
                for r in results
                if r[2] in event_map
            ],
            "highlights": highlights,
            "count": count,
        }

    async def _find_highlights_in_postgres(
        self, search_query: str, events: List[EventBase]
    ) -> Set[str]:
        """Given a list of events and a search term, return a list of words
        that match from the content of the event.

        This is used to give a list of words that clients can match against to
        highlight the matching parts.

        Args:
            search_query
            events: A list of events

        Returns:
            A set of strings.
        """

        def f(txn: LoggingTransaction) -> Set[str]:
            highlight_words = set()
            for event in events:
                # As a hack we simply join values of all possible keys. This is
                # fine since we're only using them to find possible highlights.
                values = []
                for key in ("body", "name", "topic"):
                    v = event.content.get(key, None)
                    if v:
                        v = _clean_value_for_search(v)
                        values.append(v)

                if not values:
                    continue

                value = " ".join(values)

                # We need to find some values for StartSel and StopSel that
                # aren't in the value so that we can pick results out.
                start_sel = "<"
                stop_sel = ">"

                while start_sel in value:
                    start_sel += "<"
                while stop_sel in value:
                    stop_sel += ">"

                query = (
                    "SELECT ts_headline(?, websearch_to_tsquery('english', ?), %s)"
                    % (
                        _to_postgres_options(
                            {
                                "StartSel": start_sel,
                                "StopSel": stop_sel,
                                "MaxFragments": "50",
                            }
                        )
                    )
                )
                txn.execute(query, (value, search_query))
                (headline,) = txn.fetchall()[0]

                # Now we need to pick the possible highlights out of the headline
                # result.
                matcher_regex = "%s(.*?)%s" % (
                    re.escape(start_sel),
                    re.escape(stop_sel),
                )

                res = re.findall(matcher_regex, headline)
                highlight_words.update([r.lower() for r in res])

            return highlight_words

        return await self.db_pool.runInteraction("_find_highlights", f)
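

# A sketch of the highlight extraction above (values illustrative): with
# StartSel "<" and StopSel ">", ts_headline might return the string
# "say <hello> to the <world>", from which the regex "<(.*?)>" recovers
# the highlight words {"hello", "world"}.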


def _to_postgres_options(options_dict: JsonDict) -> str:
    return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
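
# For example (doctest-style, illustrative):
#
#   >>> _to_postgres_options({"StartSel": "<", "StopSel": ">", "MaxFragments": "50"})
#   "'StartSel=<,StopSel=>,MaxFragments=50'"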


@dataclass
class Phrase:
    phrase: List[str]


class SearchToken(enum.Enum):
    Not = enum.auto()
    Or = enum.auto()
    And = enum.auto()


Token = Union[str, Phrase, SearchToken]
TokenList = List[Token]


def _is_stop_word(word: str) -> bool:
    # TODO Pull these out of the dictionary:
    #  https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop
    return word in {"the", "a", "you", "me", "and", "but"}


def _tokenize_query(query: str) -> TokenList:
    """
    Convert the user-supplied `query` into a TokenList, which can be translated into
    some DB-specific syntax.

    The following constructs are supported:

    - phrase queries using "double quotes"
    - case-insensitive `or` and `and` operators
    - negation of a keyword via unary `-` (NOT), e.g. 'include -exclude'

    The following differs from websearch_to_tsquery:

    - Stop words are not removed.
    - Unclosed phrases are treated differently.
    """
    tokens: TokenList = []

    # Find phrases.
    in_phrase = False
    parts = deque(query.split('"'))
    for i, part in enumerate(parts):
        # The contents inside double quotes are treated as a phrase.
        in_phrase = bool(i % 2)

        # Pull out the individual words, discarding any non-word characters.
        words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE))

        # Phrases have simplified handling of words.
        if in_phrase:
            # Skip stop words.
            phrase = [word for word in words if not _is_stop_word(word)]

            # Consecutive words are implicitly ANDed together.
            if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
                tokens.append(SearchToken.And)

            # Add the phrase.
            tokens.append(Phrase(phrase))
            continue

        # Otherwise, not in a phrase.
        while words:
            word = words.popleft()

            if word.startswith("-"):
                tokens.append(SearchToken.Not)

                # If there is more of the word, put it back to be processed again.
                word = word[1:]
                if word:
                    words.appendleft(word)
            elif word.lower() == "or":
                tokens.append(SearchToken.Or)
            else:
                # Skip stop words.
                if _is_stop_word(word):
                    continue

                # Consecutive words are implicitly ANDed together.
                if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
                    tokens.append(SearchToken.And)

                # Add the search term.
                tokens.append(word)

    return tokens
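
# A sketch of the tokenizer's behaviour (output shown symbolically;
# illustrative only, assuming the implementation above):
#
#   >>> _tokenize_query('"hello world" or hi -bye')
#   [Phrase(phrase=['hello', 'world']), SearchToken.Or, 'hi', SearchToken.Not, 'bye']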


def _tokens_to_sqlite_match_query(tokens: TokenList) -> str:
    """
    Convert the list of tokens to a string suitable for passing to sqlite's MATCH.
    Assume sqlite was compiled with enhanced query syntax.

    Ref: https://www.sqlite.org/fts3.html#full_text_index_queries
    """
    match_query = []
    for token in tokens:
        if isinstance(token, str):
            match_query.append(token)
        elif isinstance(token, Phrase):
            match_query.append('"' + " ".join(token.phrase) + '"')
        elif token == SearchToken.Not:
            # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search
            # term has already been added before this.
            match_query.append(" NOT ")
        elif token == SearchToken.Or:
            match_query.append(" OR ")
        elif token == SearchToken.And:
            match_query.append(" AND ")
        else:
            raise ValueError(f"unknown token {token}")

    return "".join(match_query)
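
# Continuing the sketch above, the tokens for '"hello world" or hi -bye'
# would render as the MATCH string (illustrative):
#
#   '"hello world" OR hi NOT bye'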


def _parse_query_for_sqlite(search_term: str) -> str:
    """Takes a plain unicode string from the user and converts it into a form
    that can be passed to sqlite's matchinfo().
    """
    return _tokens_to_sqlite_match_query(_tokenize_query(search_term))
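
# End-to-end sketch (doctest-style, illustrative): a unary hyphen becomes a
# NOT in the resulting MATCH expression:
#
#   >>> _parse_query_for_sqlite("run -walk")
#   'run NOT walk'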