Przeglądaj źródła

Make scripts/ and scripts-dev/ pass pyflakes (and the rest of the codebase on py3) (#4068)

tags/v0.33.8rc1
Amber Brown 5 lat temu
committed by GitHub
rodzic
commit
e1728dfcbe
Nie znaleziono w bazie danych klucza dla tego podpisu ID klucza GPG: 4AEE18F83AFDEB23
27 zmienionych plików z 511 dodań i 518 usunięć
  1. +1
    -1
      .travis.yml
  2. +1
    -0
      changelog.d/4068.misc
  3. +15
    -21
      scripts-dev/check_auth.py
  4. +18
    -14
      scripts-dev/check_event_hash.py
  5. +19
    -17
      scripts-dev/check_signature.py
  6. +22
    -18
      scripts-dev/convert_server_keys.py
  7. +34
    -20
      scripts-dev/definitions.py
  8. +8
    -5
      scripts-dev/dump_macaroon.py
  9. +48
    -51
      scripts-dev/federation_client.py
  10. +38
    -24
      scripts-dev/hash_history.py
  11. +8
    -8
      scripts-dev/list_url_patterns.py
  12. +12
    -12
      scripts-dev/tail-synapse.py
  13. +1
    -4
      scripts/hash_password
  14. +12
    -24
      scripts/move_remote_media_to_new_store.py
  15. +44
    -32
      scripts/register_new_matrix_user
  16. +122
    -150
      scripts/synapse_port_db
  17. +1
    -1
      synapse/app/homeserver.py
  18. +1
    -1
      synapse/config/__main__.py
  19. +59
    -62
      synapse/config/_base.py
  20. +2
    -2
      synapse/storage/_base.py
  21. +1
    -1
      synapse/storage/keys.py
  22. +1
    -1
      synapse/storage/pusher.py
  23. +1
    -1
      synapse/storage/signatures.py
  24. +1
    -1
      synapse/storage/transactions.py
  25. +3
    -1
      synapse/util/caches/stream_change_cache.py
  26. +36
    -44
      synctl
  27. +2
    -2
      tox.ini

+ 1
- 1
.travis.yml Wyświetl plik

@@ -14,7 +14,7 @@ matrix:
- python: 2.7
env: TOX_ENV=packaging

- python: 2.7
- python: 3.6
env: TOX_ENV=pep8

- python: 2.7


+ 1
- 0
changelog.d/4068.misc Wyświetl plik

@@ -0,0 +1 @@
Make the Python scripts in the top-level scripts folders meet pep8 and pass flake8.

+ 15
- 21
scripts-dev/check_auth.py Wyświetl plik

@@ -1,21 +1,20 @@
from synapse.events import FrozenEvent
from synapse.api.auth import Auth

from mock import Mock
from __future__ import print_function

import argparse
import itertools
import json
import sys

from mock import Mock

from synapse.api.auth import Auth
from synapse.events import FrozenEvent


def check_auth(auth, auth_chain, events):
auth_chain.sort(key=lambda e: e.depth)

auth_map = {
e.event_id: e
for e in auth_chain
}
auth_map = {e.event_id: e for e in auth_chain}

create_events = {}
for e in auth_chain:
@@ -25,31 +24,26 @@ def check_auth(auth, auth_chain, events):
for e in itertools.chain(auth_chain, events):
auth_events_list = [auth_map[i] for i, _ in e.auth_events]

auth_events = {
(e.type, e.state_key): e
for e in auth_events_list
}
auth_events = {(e.type, e.state_key): e for e in auth_events_list}

auth_events[("m.room.create", "")] = create_events[e.room_id]

try:
auth.check(e, auth_events=auth_events)
except Exception as ex:
print "Failed:", e.event_id, e.type, e.state_key
print "Auth_events:", auth_events
print ex
print json.dumps(e.get_dict(), sort_keys=True, indent=4)
print("Failed:", e.event_id, e.type, e.state_key)
print("Auth_events:", auth_events)
print(ex)
print(json.dumps(e.get_dict(), sort_keys=True, indent=4))
# raise
print "Success:", e.event_id, e.type, e.state_key
print("Success:", e.event_id, e.type, e.state_key)


if __name__ == '__main__':
parser = argparse.ArgumentParser()

parser.add_argument(
'json',
nargs='?',
type=argparse.FileType('r'),
default=sys.stdin,
'json', nargs='?', type=argparse.FileType('r'), default=sys.stdin
)

args = parser.parse_args()


+ 18
- 14
scripts-dev/check_event_hash.py Wyświetl plik

@@ -1,10 +1,15 @@
from synapse.crypto.event_signing import *
from unpaddedbase64 import encode_base64

import argparse
import hashlib
import sys
import json
import logging
import sys

from unpaddedbase64 import encode_base64

from synapse.crypto.event_signing import (
check_event_content_hash,
compute_event_reference_hash,
)


class dictobj(dict):
@@ -24,27 +29,26 @@ class dictobj(dict):

def main():
parser = argparse.ArgumentParser()
parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'),
default=sys.stdin)
parser.add_argument(
"input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin
)
args = parser.parse_args()
logging.basicConfig()

event_json = dictobj(json.load(args.input_json))

algorithms = {
"sha256": hashlib.sha256,
}
algorithms = {"sha256": hashlib.sha256}

for alg_name in event_json.hashes:
if check_event_content_hash(event_json, algorithms[alg_name]):
print "PASS content hash %s" % (alg_name,)
print("PASS content hash %s" % (alg_name,))
else:
print "FAIL content hash %s" % (alg_name,)
print("FAIL content hash %s" % (alg_name,))

for algorithm in algorithms.values():
name, h_bytes = compute_event_reference_hash(event_json, algorithm)
print "Reference hash %s: %s" % (name, encode_base64(h_bytes))
print("Reference hash %s: %s" % (name, encode_base64(h_bytes)))

if __name__=="__main__":
main()

if __name__ == "__main__":
main()

+ 19
- 17
scripts-dev/check_signature.py Wyświetl plik

@@ -1,15 +1,15 @@

from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes, write_signing_keys
from unpaddedbase64 import decode_base64

import urllib2
import argparse
import json
import logging
import sys
import urllib2

import dns.resolver
import pprint
import argparse
import logging
from signedjson.key import decode_verify_key_bytes, write_signing_keys
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64


def get_targets(server_name):
if ":" in server_name:
@@ -23,6 +23,7 @@ def get_targets(server_name):
except dns.resolver.NXDOMAIN:
yield (server_name, 8448)


def get_server_keys(server_name, target, port):
url = "https://%s:%i/_matrix/key/v1" % (target, port)
keys = json.load(urllib2.urlopen(url))
@@ -33,12 +34,14 @@ def get_server_keys(server_name, target, port):
verify_keys[key_id] = verify_key
return verify_keys


def main():

parser = argparse.ArgumentParser()
parser.add_argument("signature_name")
parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'),
default=sys.stdin)
parser.add_argument(
"input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin
)

args = parser.parse_args()
logging.basicConfig()
@@ -48,24 +51,23 @@ def main():
for target, port in get_targets(server_name):
try:
keys = get_server_keys(server_name, target, port)
print "Using keys from https://%s:%s/_matrix/key/v1" % (target, port)
print("Using keys from https://%s:%s/_matrix/key/v1" % (target, port))
write_signing_keys(sys.stdout, keys.values())
break
except:
except Exception:
logging.exception("Error talking to %s:%s", target, port)

json_to_check = json.load(args.input_json)
print "Checking JSON:"
print("Checking JSON:")
for key_id in json_to_check["signatures"][args.signature_name]:
try:
key = keys[key_id]
verify_signed_json(json_to_check, args.signature_name, key)
print "PASS %s" % (key_id,)
except:
print("PASS %s" % (key_id,))
except Exception:
logging.exception("Check for key %s failed" % (key_id,))
print "FAIL %s" % (key_id,)
print("FAIL %s" % (key_id,))


if __name__ == '__main__':
main()


+ 22
- 18
scripts-dev/convert_server_keys.py Wyświetl plik

@@ -1,13 +1,21 @@
import psycopg2
import yaml
import sys
import hashlib
import json
import sys
import time
import hashlib
from unpaddedbase64 import encode_base64

import six

import psycopg2
import yaml
from canonicaljson import encode_canonical_json
from signedjson.key import read_signing_keys
from signedjson.sign import sign_json
from canonicaljson import encode_canonical_json
from unpaddedbase64 import encode_base64

if six.PY2:
db_type = six.moves.builtins.buffer
else:
db_type = memoryview


def select_v1_keys(connection):
@@ -39,7 +47,9 @@ def select_v2_json(connection):
cursor.close()
results = {}
for server_name, key_id, key_json in rows:
results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8"))
results.setdefault(server_name, {})[key_id] = json.loads(
str(key_json).decode("utf-8")
)
return results


@@ -47,10 +57,7 @@ def convert_v1_to_v2(server_name, valid_until, keys, certificate):
return {
"old_verify_keys": {},
"server_name": server_name,
"verify_keys": {
key_id: {"key": key}
for key_id, key in keys.items()
},
"verify_keys": {key_id: {"key": key} for key_id, key in keys.items()},
"valid_until_ts": valid_until,
"tls_fingerprints": [fingerprint(certificate)],
}
@@ -65,7 +72,7 @@ def rows_v2(server, json):
valid_until = json["valid_until_ts"]
key_json = encode_canonical_json(json)
for key_id in json["verify_keys"]:
yield (server, key_id, "-", valid_until, valid_until, buffer(key_json))
yield (server, key_id, "-", valid_until, valid_until, db_type(key_json))


def main():
@@ -87,7 +94,7 @@ def main():

result = {}
for server in keys:
if not server in json:
if server not in json:
v2_json = convert_v1_to_v2(
server, valid_until, keys[server], certificates[server]
)
@@ -96,10 +103,7 @@ def main():

yaml.safe_dump(result, sys.stdout, default_flow_style=False)

rows = list(
row for server, json in result.items()
for row in rows_v2(server, json)
)
rows = list(row for server, json in result.items() for row in rows_v2(server, json))

cursor = connection.cursor()
cursor.executemany(
@@ -107,7 +111,7 @@ def main():
" server_name, key_id, from_server,"
" ts_added_ms, ts_valid_until_ms, key_json"
") VALUES (%s, %s, %s, %s, %s, %s)",
rows
rows,
)
connection.commit()



+ 34
- 20
scripts-dev/definitions.py Wyświetl plik

@@ -1,8 +1,16 @@
#! /usr/bin/python

from __future__ import print_function

import argparse
import ast
import os
import re
import sys

import yaml


class DefinitionVisitor(ast.NodeVisitor):
def __init__(self):
super(DefinitionVisitor, self).__init__()
@@ -42,15 +50,18 @@ def non_empty(defs):
functions = {name: non_empty(f) for name, f in defs['def'].items()}
classes = {name: non_empty(f) for name, f in defs['class'].items()}
result = {}
if functions: result['def'] = functions
if classes: result['class'] = classes
if functions:
result['def'] = functions
if classes:
result['class'] = classes
names = defs['names']
uses = []
for name in names.get('Load', ()):
if name not in names.get('Param', ()) and name not in names.get('Store', ()):
uses.append(name)
uses.extend(defs['attrs'])
if uses: result['uses'] = uses
if uses:
result['uses'] = uses
result['names'] = names
result['attrs'] = defs['attrs']
return result
@@ -95,7 +106,6 @@ def used_names(prefix, item, defs, names):


if __name__ == '__main__':
import sys, os, argparse, re

parser = argparse.ArgumentParser(description='Find definitions.')
parser.add_argument(
@@ -105,24 +115,28 @@ if __name__ == '__main__':
"--ignore", action="append", metavar="REGEXP", help="Ignore a pattern"
)
parser.add_argument(
"--pattern", action="append", metavar="REGEXP",
help="Search for a pattern"
"--pattern", action="append", metavar="REGEXP", help="Search for a pattern"
)
parser.add_argument(
"directories", nargs='+', metavar="DIR",
help="Directories to search for definitions"
"directories",
nargs='+',
metavar="DIR",
help="Directories to search for definitions",
)
parser.add_argument(
"--referrers", default=0, type=int,
help="Include referrers up to the given depth"
"--referrers",
default=0,
type=int,
help="Include referrers up to the given depth",
)
parser.add_argument(
"--referred", default=0, type=int,
help="Include referred down to the given depth"
"--referred",
default=0,
type=int,
help="Include referred down to the given depth",
)
parser.add_argument(
"--format", default="yaml",
help="Output format, one of 'yaml' or 'dot'"
"--format", default="yaml", help="Output format, one of 'yaml' or 'dot'"
)
args = parser.parse_args()

@@ -162,7 +176,7 @@ if __name__ == '__main__':
for used_by in entry.get("used", ()):
referrers.add(used_by)
for name, definition in names.items():
if not name in referrers:
if name not in referrers:
continue
if ignore and any(pattern.match(name) for pattern in ignore):
continue
@@ -176,7 +190,7 @@ if __name__ == '__main__':
for uses in entry.get("uses", ()):
referred.add(uses)
for name, definition in names.items():
if not name in referred:
if name not in referred:
continue
if ignore and any(pattern.match(name) for pattern in ignore):
continue
@@ -185,12 +199,12 @@ if __name__ == '__main__':
if args.format == 'yaml':
yaml.dump(result, sys.stdout, default_flow_style=False)
elif args.format == 'dot':
print "digraph {"
print("digraph {")
for name, entry in result.items():
print name
print(name)
for used_by in entry.get("used", ()):
if used_by in result:
print used_by, "->", name
print "}"
print(used_by, "->", name)
print("}")
else:
raise ValueError("Unknown format %r" % (args.format))

+ 8
- 5
scripts-dev/dump_macaroon.py Wyświetl plik

@@ -1,8 +1,11 @@
#!/usr/bin/env python2

import pymacaroons
from __future__ import print_function

import sys

import pymacaroons

if len(sys.argv) == 1:
sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],))
sys.exit(1)
@@ -11,14 +14,14 @@ macaroon_string = sys.argv[1]
key = sys.argv[2] if len(sys.argv) > 2 else None

macaroon = pymacaroons.Macaroon.deserialize(macaroon_string)
print macaroon.inspect()
print(macaroon.inspect())

print ""
print("")

verifier = pymacaroons.Verifier()
verifier.satisfy_general(lambda c: True)
try:
verifier.verify(macaroon, key)
print "Signature is correct"
print("Signature is correct")
except Exception as e:
print str(e)
print(str(e))

+ 48
- 51
scripts-dev/federation_client.py Wyświetl plik

@@ -18,21 +18,21 @@
from __future__ import print_function

import argparse
import base64
import json
import sys
from urlparse import urlparse, urlunparse

import nacl.signing
import json
import base64
import requests
import sys

from requests.adapters import HTTPAdapter
import srvlookup
import yaml
from requests.adapters import HTTPAdapter

# uncomment the following to enable debug logging of http requests
#from httplib import HTTPConnection
#HTTPConnection.debuglevel = 1
# from httplib import HTTPConnection
# HTTPConnection.debuglevel = 1


def encode_base64(input_bytes):
"""Encode bytes as a base64 string without any padding."""
@@ -58,15 +58,15 @@ def decode_base64(input_string):

def encode_canonical_json(value):
return json.dumps(
value,
# Encode code-points outside of ASCII as UTF-8 rather than \u escapes
ensure_ascii=False,
# Remove unecessary white space.
separators=(',',':'),
# Sort the keys of dictionaries.
sort_keys=True,
# Encode the resulting unicode as UTF-8 bytes.
).encode("UTF-8")
value,
# Encode code-points outside of ASCII as UTF-8 rather than \u escapes
ensure_ascii=False,
# Remove unecessary white space.
separators=(',', ':'),
# Sort the keys of dictionaries.
sort_keys=True,
# Encode the resulting unicode as UTF-8 bytes.
).encode("UTF-8")


def sign_json(json_object, signing_key, signing_name):
@@ -88,6 +88,7 @@ def sign_json(json_object, signing_key, signing_name):

NACL_ED25519 = "ed25519"


def decode_signing_key_base64(algorithm, version, key_base64):
"""Decode a base64 encoded signing key
Args:
@@ -143,14 +144,12 @@ def request_json(method, origin_name, origin_key, destination, path, content):
authorization_headers = []

for key, sig in signed_json["signatures"][origin_name].items():
header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
origin_name, key, sig,
)
header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (origin_name, key, sig)
authorization_headers.append(bytes(header))
print ("Authorization: %s" % header, file=sys.stderr)
print("Authorization: %s" % header, file=sys.stderr)

dest = "matrix://%s%s" % (destination, path)
print ("Requesting %s" % dest, file=sys.stderr)
print("Requesting %s" % dest, file=sys.stderr)

s = requests.Session()
s.mount("matrix://", MatrixConnectionAdapter())
@@ -158,10 +157,7 @@ def request_json(method, origin_name, origin_key, destination, path, content):
result = s.request(
method=method,
url=dest,
headers={
"Host": destination,
"Authorization": authorization_headers[0]
},
headers={"Host": destination, "Authorization": authorization_headers[0]},
verify=False,
data=content,
)
@@ -171,50 +167,50 @@ def request_json(method, origin_name, origin_key, destination, path, content):

def main():
parser = argparse.ArgumentParser(
description=
"Signs and sends a federation request to a matrix homeserver",
description="Signs and sends a federation request to a matrix homeserver"
)

parser.add_argument(
"-N", "--server-name",
"-N",
"--server-name",
help="Name to give as the local homeserver. If unspecified, will be "
"read from the config file.",
"read from the config file.",
)

parser.add_argument(
"-k", "--signing-key-path",
"-k",
"--signing-key-path",
help="Path to the file containing the private ed25519 key to sign the "
"request with.",
"request with.",
)

parser.add_argument(
"-c", "--config",
"-c",
"--config",
default="homeserver.yaml",
help="Path to server config file. Ignored if --server-name and "
"--signing-key-path are both given.",
"--signing-key-path are both given.",
)

parser.add_argument(
"-d", "--destination",
"-d",
"--destination",
default="matrix.org",
help="name of the remote homeserver. We will do SRV lookups and "
"connect appropriately.",
"connect appropriately.",
)

parser.add_argument(
"-X", "--method",
"-X",
"--method",
help="HTTP method to use for the request. Defaults to GET if --data is"
"unspecified, POST if it is."
"unspecified, POST if it is.",
)

parser.add_argument(
"--body",
help="Data to send as the body of the HTTP request"
)
parser.add_argument("--body", help="Data to send as the body of the HTTP request")

parser.add_argument(
"path",
help="request path. We will add '/_matrix/federation/v1/' to this."
"path", help="request path. We will add '/_matrix/federation/v1/' to this."
)

args = parser.parse_args()
@@ -227,13 +223,15 @@ def main():

result = request_json(
args.method,
args.server_name, key, args.destination,
args.server_name,
key,
args.destination,
"/_matrix/federation/v1/" + args.path,
content=args.body,
)

json.dump(result, sys.stdout)
print ("")
print("")


def read_args_from_config(args):
@@ -253,7 +251,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
return s, 8448

if ":" in s:
out = s.rsplit(":",1)
out = s.rsplit(":", 1)
try:
port = int(out[1])
except ValueError:
@@ -263,7 +261,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
try:
srv = srvlookup.lookup("matrix", "tcp", s)[0]
return srv.host, srv.port
except:
except Exception:
return s, 8448

def get_connection(self, url, proxies=None):
@@ -272,10 +270,9 @@ class MatrixConnectionAdapter(HTTPAdapter):
(host, port) = self.lookup(parsed.netloc)
netloc = "%s:%d" % (host, port)
print("Connecting to %s" % (netloc,), file=sys.stderr)
url = urlunparse((
"https", netloc, parsed.path, parsed.params, parsed.query,
parsed.fragment,
))
url = urlunparse(
("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)
)
return super(MatrixConnectionAdapter, self).get_connection(url, proxies)




+ 38
- 24
scripts-dev/hash_history.py Wyświetl plik

@@ -1,23 +1,31 @@
from synapse.storage.pdu import PduStore
from synapse.storage.signatures import SignatureStore
from synapse.storage._base import SQLBaseStore
from synapse.federation.units import Pdu
from synapse.crypto.event_signing import (
add_event_pdu_content_hash, compute_pdu_event_reference_hash
)
from synapse.api.events.utils import prune_pdu
from unpaddedbase64 import encode_base64, decode_base64
from canonicaljson import encode_canonical_json
from __future__ import print_function

import sqlite3
import sys

from unpaddedbase64 import decode_base64, encode_base64

from synapse.crypto.event_signing import (
add_event_pdu_content_hash,
compute_pdu_event_reference_hash,
)
from synapse.federation.units import Pdu
from synapse.storage._base import SQLBaseStore
from synapse.storage.pdu import PduStore
from synapse.storage.signatures import SignatureStore


class Store(object):
_get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"]
_get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"]
_get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"]
_get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"]
_get_pdu_origin_signatures_txn = SignatureStore.__dict__[
"_get_pdu_origin_signatures_txn"
]
_store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"]
_store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"]
_store_pdu_reference_hash_txn = SignatureStore.__dict__[
"_store_pdu_reference_hash_txn"
]
_store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"]
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]

@@ -26,9 +34,7 @@ store = Store()


def select_pdus(cursor):
cursor.execute(
"SELECT pdu_id, origin FROM pdus ORDER BY depth ASC"
)
cursor.execute("SELECT pdu_id, origin FROM pdus ORDER BY depth ASC")

ids = cursor.fetchall()

@@ -41,23 +47,30 @@ def select_pdus(cursor):
for pdu in pdus:
try:
if pdu.prev_pdus:
print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus
print("PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus)
for pdu_id, origin, hashes in pdu.prev_pdus:
ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)]
hashes[ref_alg] = encode_base64(ref_hsh)
store._store_prev_pdu_hash_txn(cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh)
print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus
store._store_prev_pdu_hash_txn(
cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh
)
print("SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus)
pdu = add_event_pdu_content_hash(pdu)
ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu)
reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh)
store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh)
store._store_pdu_reference_hash_txn(
cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh
)

for alg, hsh_base64 in pdu.hashes.items():
print alg, hsh_base64
store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64))
print(alg, hsh_base64)
store._store_pdu_content_hash_txn(
cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64)
)

except Exception:
print("FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus)

except:
print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus

def main():
conn = sqlite3.connect(sys.argv[1])
@@ -65,5 +78,6 @@ def main():
select_pdus(cursor)
conn.commit()

if __name__=='__main__':

if __name__ == '__main__':
main()

+ 8
- 8
scripts-dev/list_url_patterns.py Wyświetl plik

@@ -1,18 +1,17 @@
#! /usr/bin/python

import ast
import argparse
import ast
import os
import sys

import yaml

PATTERNS_V1 = []
PATTERNS_V2 = []

RESULT = {
"v1": PATTERNS_V1,
"v2": PATTERNS_V2,
}
RESULT = {"v1": PATTERNS_V1, "v2": PATTERNS_V2}


class CallVisitor(ast.NodeVisitor):
def visit_Call(self, node):
@@ -21,7 +20,6 @@ class CallVisitor(ast.NodeVisitor):
else:
return


if name == "client_path_patterns":
PATTERNS_V1.append(node.args[0].s)
elif name == "client_v2_patterns":
@@ -42,8 +40,10 @@ def find_patterns_in_file(filepath):
parser = argparse.ArgumentParser(description='Find url patterns.')

parser.add_argument(
"directories", nargs='+', metavar="DIR",
help="Directories to search for definitions"
"directories",
nargs='+',
metavar="DIR",
help="Directories to search for definitions",
)

args = parser.parse_args()


+ 12
- 12
scripts-dev/tail-synapse.py Wyświetl plik

@@ -1,8 +1,9 @@
import requests
import collections
import json
import sys
import time
import json

import requests

Entry = collections.namedtuple("Entry", "name position rows")

@@ -30,11 +31,11 @@ def parse_response(content):


def replicate(server, streams):
return parse_response(requests.get(
server + "/_synapse/replication",
verify=False,
params=streams
).content)
return parse_response(
requests.get(
server + "/_synapse/replication", verify=False, params=streams
).content
)


def main():
@@ -45,7 +46,7 @@ def main():
try:
streams = {
row.name: row.position
for row in replicate(server, {"streams":"-1"})["streams"].rows
for row in replicate(server, {"streams": "-1"})["streams"].rows
}
except requests.exceptions.ConnectionError as e:
time.sleep(0.1)
@@ -53,8 +54,8 @@ def main():
while True:
try:
results = replicate(server, streams)
except:
sys.stdout.write("connection_lost("+ repr(streams) + ")\n")
except Exception:
sys.stdout.write("connection_lost(" + repr(streams) + ")\n")
break
for update in results.values():
for row in update.rows:
@@ -62,6 +63,5 @@ def main():
streams[update.name] = update.position



if __name__=='__main__':
if __name__ == '__main__':
main()

+ 1
- 4
scripts/hash_password Wyświetl plik

@@ -1,12 +1,10 @@
#!/usr/bin/env python

import argparse
import getpass
import sys

import bcrypt
import getpass

import yaml

bcrypt_rounds=12
@@ -52,4 +50,3 @@ if __name__ == "__main__":
password = prompt_for_pass()

print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))


+ 12
- 24
scripts/move_remote_media_to_new_store.py Wyświetl plik

@@ -36,12 +36,9 @@ from __future__ import print_function

import argparse
import logging

import sys

import os

import shutil
import sys

from synapse.rest.media.v1.filepath import MediaFilePaths

@@ -77,24 +74,23 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
if not os.path.exists(original_file):
logger.warn(
"Original for %s/%s (%s) does not exist",
origin_server, file_id, original_file,
origin_server,
file_id,
original_file,
)
else:
mkdir_and_move(
original_file,
dest_paths.remote_media_filepath(origin_server, file_id),
original_file, dest_paths.remote_media_filepath(origin_server, file_id)
)

# now look for thumbnails
original_thumb_dir = src_paths.remote_media_thumbnail_dir(
origin_server, file_id,
)
original_thumb_dir = src_paths.remote_media_thumbnail_dir(origin_server, file_id)
if not os.path.exists(original_thumb_dir):
return

mkdir_and_move(
original_thumb_dir,
dest_paths.remote_media_thumbnail_dir(origin_server, file_id)
dest_paths.remote_media_thumbnail_dir(origin_server, file_id),
)


@@ -109,24 +105,16 @@ def mkdir_and_move(original_file, dest_file):

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class = argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"-v", action='store_true', help='enable debug logging')
parser.add_argument(
"src_repo",
help="Path to source content repo",
)
parser.add_argument(
"dest_repo",
help="Path to source content repo",
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("-v", action='store_true', help='enable debug logging')
parser.add_argument("src_repo", help="Path to source content repo")
parser.add_argument("dest_repo", help="Path to source content repo")
args = parser.parse_args()

logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}
logging.basicConfig(**logging_config)



+ 44
- 32
scripts/register_new_matrix_user Wyświetl plik

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import getpass
@@ -22,19 +23,23 @@ import hmac
import json
import sys
import urllib2

from six import input

import yaml


def request_registration(user, password, server_location, shared_secret, admin=False):
req = urllib2.Request(
"%s/_matrix/client/r0/admin/register" % (server_location,),
headers={'Content-Type': 'application/json'}
headers={'Content-Type': 'application/json'},
)

try:
if sys.version_info[:3] >= (2, 7, 9):
# As of version 2.7.9, urllib2 now checks SSL certs
import ssl

f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
else:
f = urllib2.urlopen(req)
@@ -42,18 +47,15 @@ def request_registration(user, password, server_location, shared_secret, admin=F
f.close()
nonce = json.loads(body)["nonce"]
except urllib2.HTTPError as e:
print "ERROR! Received %d %s" % (e.code, e.reason,)
print("ERROR! Received %d %s" % (e.code, e.reason))
if 400 <= e.code < 500:
if e.info().type == "application/json":
resp = json.load(e)
if "error" in resp:
print resp["error"]
print(resp["error"])
sys.exit(1)

mac = hmac.new(
key=shared_secret,
digestmod=hashlib.sha1,
)
mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1)

mac.update(nonce)
mac.update("\x00")
@@ -75,30 +77,31 @@ def request_registration(user, password, server_location, shared_secret, admin=F

server_location = server_location.rstrip("/")

print "Sending registration request..."
print("Sending registration request...")

req = urllib2.Request(
"%s/_matrix/client/r0/admin/register" % (server_location,),
data=json.dumps(data),
headers={'Content-Type': 'application/json'}
headers={'Content-Type': 'application/json'},
)
try:
if sys.version_info[:3] >= (2, 7, 9):
# As of version 2.7.9, urllib2 now checks SSL certs
import ssl

f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
else:
f = urllib2.urlopen(req)
f.read()
f.close()
print "Success."
print("Success.")
except urllib2.HTTPError as e:
print "ERROR! Received %d %s" % (e.code, e.reason,)
print("ERROR! Received %d %s" % (e.code, e.reason))
if 400 <= e.code < 500:
if e.info().type == "application/json":
resp = json.load(e)
if "error" in resp:
print resp["error"]
print(resp["error"])
sys.exit(1)


@@ -106,35 +109,35 @@ def register_new_user(user, password, server_location, shared_secret, admin):
if not user:
try:
default_user = getpass.getuser()
except:
except Exception:
default_user = None

if default_user:
user = raw_input("New user localpart [%s]: " % (default_user,))
user = input("New user localpart [%s]: " % (default_user,))
if not user:
user = default_user
else:
user = raw_input("New user localpart: ")
user = input("New user localpart: ")

if not user:
print "Invalid user name"
print("Invalid user name")
sys.exit(1)

if not password:
password = getpass.getpass("Password: ")

if not password:
print "Password cannot be blank."
print("Password cannot be blank.")
sys.exit(1)

confirm_password = getpass.getpass("Confirm password: ")

if password != confirm_password:
print "Passwords do not match"
print("Passwords do not match")
sys.exit(1)

if admin is None:
admin = raw_input("Make admin [no]: ")
admin = input("Make admin [no]: ")
if admin in ("y", "yes", "true"):
admin = True
else:
@@ -146,42 +149,51 @@ def register_new_user(user, password, server_location, shared_secret, admin):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Used to register new users with a given home server when"
" registration has been disabled. The home server must be"
" configured with the 'registration_shared_secret' option"
" set.",
" registration has been disabled. The home server must be"
" configured with the 'registration_shared_secret' option"
" set."
)
parser.add_argument(
"-u", "--user",
"-u",
"--user",
default=None,
help="Local part of the new user. Will prompt if omitted.",
)
parser.add_argument(
"-p", "--password",
"-p",
"--password",
default=None,
help="New password for user. Will prompt if omitted.",
)
admin_group = parser.add_mutually_exclusive_group()
admin_group.add_argument(
"-a", "--admin",
"-a",
"--admin",
action="store_true",
help="Register new user as an admin. Will prompt if --no-admin is not set either.",
help=(
"Register new user as an admin. "
"Will prompt if --no-admin is not set either."
),
)
admin_group.add_argument(
"--no-admin",
action="store_true",
help="Register new user as a regular user. Will prompt if --admin is not set either.",
help=(
"Register new user as a regular user. "
"Will prompt if --admin is not set either."
),
)

group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-c", "--config",
"-c",
"--config",
type=argparse.FileType('r'),
help="Path to server config file. Used to read in shared secret.",
)

group.add_argument(
"-k", "--shared-secret",
help="Shared secret as defined in server config file.",
"-k", "--shared-secret", help="Shared secret as defined in server config file."
)

parser.add_argument(
@@ -189,7 +201,7 @@ if __name__ == "__main__":
default="https://localhost:8448",
nargs='?',
help="URL to use to talk to the home server. Defaults to "
" 'https://localhost:8448'.",
" 'https://localhost:8448'.",
)

args = parser.parse_args()
@@ -198,7 +210,7 @@ if __name__ == "__main__":
config = yaml.safe_load(args.config)
secret = config.get("registration_shared_secret", None)
if not secret:
print "No 'registration_shared_secret' defined in config."
print("No 'registration_shared_secret' defined in config.")
sys.exit(1)
else:
secret = args.shared_secret


+ 122
- 150
scripts/synapse_port_db Wyświetl plik

@@ -15,23 +15,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer, reactor
from twisted.enterprise import adbapi

from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database

import argparse
import curses
import logging
import sys
import time
import traceback
import yaml

from six import string_types

import yaml

from twisted.enterprise import adbapi
from twisted.internet import defer, reactor

from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database

logger = logging.getLogger("synapse_port_db")

@@ -105,6 +105,7 @@ class Store(object):

*All* database interactions should go through this object.
"""

def __init__(self, db_pool, engine):
self.db_pool = db_pool
self.database_engine = engine
@@ -135,7 +136,8 @@ class Store(object):
txn = conn.cursor()
return func(
LoggingTransaction(txn, desc, self.database_engine, [], []),
*args, **kwargs
*args,
**kwargs
)
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
@@ -158,22 +160,20 @@ class Store(object):
def r(txn):
txn.execute(sql, args)
return txn.fetchall()

return self.runInteraction("execute_sql", r)

def insert_many_txn(self, txn, table, headers, rows):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in headers),
", ".join("%s" for _ in headers)
", ".join("%s" for _ in headers),
)

try:
txn.executemany(sql, rows)
except:
logger.exception(
"Failed to insert: %s",
table,
)
except Exception:
logger.exception("Failed to insert: %s", table)
raise


@@ -206,7 +206,7 @@ class Porter(object):
"table_name": table,
"forward_rowid": 1,
"backward_rowid": 0,
}
},
)

forward_chunk = 1
@@ -221,10 +221,10 @@ class Porter(object):
table, forward_chunk, backward_chunk
)
else:

def delete_all(txn):
txn.execute(
"DELETE FROM port_from_sqlite3 WHERE table_name = %s",
(table,)
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
)
txn.execute("TRUNCATE %s CASCADE" % (table,))

@@ -232,11 +232,7 @@ class Porter(object):

yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={
"table_name": table,
"forward_rowid": 1,
"backward_rowid": 0,
}
values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
)

forward_chunk = 1
@@ -251,12 +247,16 @@ class Porter(object):
)

@defer.inlineCallbacks
def handle_table(self, table, postgres_size, table_size, forward_chunk,
backward_chunk):
def handle_table(
self, table, postgres_size, table_size, forward_chunk, backward_chunk
):
logger.info(
"Table %s: %i/%i (rows %i-%i) already ported",
table, postgres_size, table_size,
backward_chunk+1, forward_chunk-1,
table,
postgres_size,
table_size,
backward_chunk + 1,
forward_chunk - 1,
)

if not table_size:
@@ -271,7 +271,9 @@ class Porter(object):
return

if table in (
"user_directory", "user_directory_search", "users_who_share_rooms",
"user_directory",
"user_directory_search",
"users_who_share_rooms",
"users_in_pubic_room",
):
# We don't port these tables, as they're a faff and we can regenreate
@@ -283,37 +285,35 @@ class Porter(object):
# We need to make sure there is a single row, `(X, null), as that is
# what synapse expects to be there.
yield self.postgres_store._simple_insert(
table=table,
values={"stream_id": None},
table=table, values={"stream_id": None}
)
self.progress.update(table, table_size) # Mark table as done
return

forward_select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
% (table,)
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,)
)

backward_select = (
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
% (table,)
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" % (table,)
)

do_forward = [True]
do_backward = [True]

while True:

def r(txn):
forward_rows = []
backward_rows = []
if do_forward[0]:
txn.execute(forward_select, (forward_chunk, self.batch_size,))
txn.execute(forward_select, (forward_chunk, self.batch_size))
forward_rows = txn.fetchall()
if not forward_rows:
do_forward[0] = False

if do_backward[0]:
txn.execute(backward_select, (backward_chunk, self.batch_size,))
txn.execute(backward_select, (backward_chunk, self.batch_size))
backward_rows = txn.fetchall()
if not backward_rows:
do_backward[0] = False
@@ -325,9 +325,7 @@ class Porter(object):

return headers, forward_rows, backward_rows

headers, frows, brows = yield self.sqlite_store.runInteraction(
"select", r
)
headers, frows, brows = yield self.sqlite_store.runInteraction("select", r)

if frows or brows:
if frows:
@@ -339,9 +337,7 @@ class Porter(object):
rows = self._convert_rows(table, headers, rows)

def insert(txn):
self.postgres_store.insert_many_txn(
txn, table, headers[1:], rows
)
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)

self.postgres_store._simple_update_one_txn(
txn,
@@ -362,8 +358,9 @@ class Porter(object):
return

@defer.inlineCallbacks
def handle_search_table(self, postgres_size, table_size, forward_chunk,
backward_chunk):
def handle_search_table(
self, postgres_size, table_size, forward_chunk, backward_chunk
):
select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es"
@@ -373,8 +370,9 @@ class Porter(object):
)

while True:

def r(txn):
txn.execute(select, (forward_chunk, self.batch_size,))
txn.execute(select, (forward_chunk, self.batch_size))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]

@@ -402,18 +400,21 @@ class Porter(object):
else:
rows_dict.append(d)

txn.executemany(sql, [
(
row["event_id"],
row["room_id"],
row["key"],
row["sender"],
row["value"],
row["origin_server_ts"],
row["stream_ordering"],
)
for row in rows_dict
])
txn.executemany(
sql,
[
(
row["event_id"],
row["room_id"],
row["key"],
row["sender"],
row["value"],
row["origin_server_ts"],
row["stream_ordering"],
)
for row in rows_dict
],
)

self.postgres_store._simple_update_one_txn(
txn,
@@ -437,7 +438,8 @@ class Porter(object):
def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect(
**{
k: v for k, v in db_config.get("args", {}).items()
k: v
for k, v in db_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
@@ -450,13 +452,11 @@ class Porter(object):
def run(self):
try:
sqlite_db_pool = adbapi.ConnectionPool(
self.sqlite_config["name"],
**self.sqlite_config["args"]
self.sqlite_config["name"], **self.sqlite_config["args"]
)

postgres_db_pool = adbapi.ConnectionPool(
self.postgres_config["name"],
**self.postgres_config["args"]
self.postgres_config["name"], **self.postgres_config["args"]
)

sqlite_engine = create_engine(sqlite_config)
@@ -465,9 +465,7 @@ class Porter(object):
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine)

yield self.postgres_store.execute(
postgres_engine.check_database
)
yield self.postgres_store.execute(postgres_engine.check_database)

# Step 1. Set up databases.
self.progress.set_state("Preparing SQLite3")
@@ -477,6 +475,7 @@ class Porter(object):
self.setup_db(postgres_config, postgres_engine)

self.progress.set_state("Creating port tables")

def create_port_table(txn):
txn.execute(
"CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
@@ -501,9 +500,7 @@ class Porter(object):
)

try:
yield self.postgres_store.runInteraction(
"alter_table", alter_table
)
yield self.postgres_store.runInteraction("alter_table", alter_table)
except Exception as e:
pass

@@ -514,11 +511,7 @@ class Porter(object):
# Step 2. Get tables.
self.progress.set_state("Fetching tables")
sqlite_tables = yield self.sqlite_store._simple_select_onecol(
table="sqlite_master",
keyvalues={
"type": "table",
},
retcol="name",
table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
)

postgres_tables = yield self.postgres_store._simple_select_onecol(
@@ -545,18 +538,14 @@ class Porter(object):
# Step 4. Do the copying.
self.progress.set_state("Copying to postgres")
yield defer.gatherResults(
[
self.handle_table(*res)
for res in setup_res
],
consumeErrors=True,
[self.handle_table(*res) for res in setup_res], consumeErrors=True
)

# Step 5. Do final post-processing
yield self._setup_state_group_id_seq()

self.progress.done()
except:
except Exception:
global end_error_exec_info
end_error_exec_info = sys.exc_info()
logger.exception("")
@@ -566,9 +555,7 @@ class Porter(object):
def _convert_rows(self, table, headers, rows):
bool_col_names = BOOLEAN_COLUMNS.get(table, [])

bool_cols = [
i for i, h in enumerate(headers) if h in bool_col_names
]
bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]

class BadValueException(Exception):
pass
@@ -577,18 +564,21 @@ class Porter(object):
if j in bool_cols:
return bool(col)
elif isinstance(col, string_types) and "\0" in col:
logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col)
raise BadValueException();
logger.warn(
"DROPPING ROW: NUL value in table %s col %s: %r",
table,
headers[j],
col,
)
raise BadValueException()
return col

outrows = []
for i, row in enumerate(rows):
try:
outrows.append(tuple(
conv(j, col)
for j, col in enumerate(row)
if j > 0
))
outrows.append(
tuple(conv(j, col) for j, col in enumerate(row) if j > 0)
)
except BadValueException:
pass

@@ -616,9 +606,7 @@ class Porter(object):

return headers, [r for r in rows if r[ts_ind] < yesterday]

headers, rows = yield self.sqlite_store.runInteraction(
"select", r,
)
headers, rows = yield self.sqlite_store.runInteraction("select", r)

rows = self._convert_rows("sent_transactions", headers, rows)

@@ -639,7 +627,7 @@ class Porter(object):
txn.execute(
"SELECT rowid FROM sent_transactions WHERE ts >= ?"
" ORDER BY rowid ASC LIMIT 1",
(yesterday,)
(yesterday,),
)

rows = txn.fetchall()
@@ -657,21 +645,17 @@ class Porter(object):
"table_name": "sent_transactions",
"forward_rowid": next_chunk,
"backward_rowid": 0,
}
},
)

def get_sent_table_size(txn):
txn.execute(
"SELECT count(*) FROM sent_transactions"
" WHERE ts >= ?",
(yesterday,)
"SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
)
size, = txn.fetchone()
return int(size)

remaining_count = yield self.sqlite_store.execute(
get_sent_table_size
)
remaining_count = yield self.sqlite_store.execute(get_sent_table_size)

total_count = remaining_count + inserted_rows

@@ -680,13 +664,11 @@ class Porter(object):
@defer.inlineCallbacks
def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
frows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
forward_chunk,
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
)

brows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
backward_chunk,
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
)

defer.returnValue(frows[0][0] + brows[0][0])
@@ -694,7 +676,7 @@ class Porter(object):
@defer.inlineCallbacks
def _get_already_ported_count(self, table):
rows = yield self.postgres_store.execute_sql(
"SELECT count(*) FROM %s" % (table,),
"SELECT count(*) FROM %s" % (table,)
)

defer.returnValue(rows[0][0])
@@ -717,22 +699,21 @@ class Porter(object):
def _setup_state_group_id_seq(self):
def r(txn):
txn.execute("SELECT MAX(id) FROM state_groups")
next_id = txn.fetchone()[0]+1
txn.execute(
"ALTER SEQUENCE state_group_id_seq RESTART WITH %s",
(next_id,),
)
next_id = txn.fetchone()[0] + 1
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))

return self.postgres_store.runInteraction("setup_state_group_id_seq", r)


##############################################
###### The following is simply UI stuff ######
# The following is simply UI stuff
##############################################


class Progress(object):
"""Used to report progress of the port
"""

def __init__(self):
self.tables = {}

@@ -758,6 +739,7 @@ class Progress(object):
class CursesProgress(Progress):
"""Reports progress to a curses window
"""

def __init__(self, stdscr):
self.stdscr = stdscr

@@ -801,7 +783,7 @@ class CursesProgress(Progress):
duration = int(now) - int(self.start_time)

minutes, seconds = divmod(duration, 60)
duration_str = '%02dm %02ds' % (minutes, seconds,)
duration_str = '%02dm %02ds' % (minutes, seconds)

if self.finished:
status = "Time spent: %s (Done!)" % (duration_str,)
@@ -814,16 +796,12 @@ class CursesProgress(Progress):
est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
else:
est_remaining_str = "Unknown"
status = (
"Time spent: %s (est. remaining: %s)"
% (duration_str, est_remaining_str,)
status = "Time spent: %s (est. remaining: %s)" % (
duration_str,
est_remaining_str,
)

self.stdscr.addstr(
0, 0,
status,
curses.A_BOLD,
)
self.stdscr.addstr(0, 0, status, curses.A_BOLD)

max_len = max([len(t) for t in self.tables.keys()])

@@ -831,9 +809,7 @@ class CursesProgress(Progress):
middle_space = 1

items = self.tables.items()
items.sort(
key=lambda i: (i[1]["perc"], i[0]),
)
items.sort(key=lambda i: (i[1]["perc"], i[0]))

for i, (table, data) in enumerate(items):
if i + 2 >= rows:
@@ -844,9 +820,7 @@ class CursesProgress(Progress):
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)

self.stdscr.addstr(
i + 2, left_margin + max_len - len(table),
table,
curses.A_BOLD | color,
i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color
)

size = 20
@@ -857,15 +831,13 @@ class CursesProgress(Progress):
)

self.stdscr.addstr(
i + 2, left_margin + max_len + middle_space,
i + 2,
left_margin + max_len + middle_space,
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
)

if self.finished:
self.stdscr.addstr(
rows - 1, 0,
"Press any key to exit...",
)
self.stdscr.addstr(rows - 1, 0, "Press any key to exit...")

self.stdscr.refresh()
self.last_update = time.time()
@@ -877,29 +849,25 @@ class CursesProgress(Progress):

def set_state(self, state):
self.stdscr.clear()
self.stdscr.addstr(
0, 0,
state + "...",
curses.A_BOLD,
)
self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
self.stdscr.refresh()


class TerminalProgress(Progress):
"""Just prints progress to the terminal
"""

def update(self, table, num_done):
super(TerminalProgress, self).update(table, num_done)

data = self.tables[table]

print "%s: %d%% (%d/%d)" % (
table, data["perc"],
data["num_done"], data["total"],
print(
"%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
)

def set_state(self, state):
print state + "..."
print(state + "...")


##############################################
@@ -909,34 +877,38 @@ class TerminalProgress(Progress):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="A script to port an existing synapse SQLite database to"
" a new PostgreSQL database."
" a new PostgreSQL database."
)
parser.add_argument("-v", action='store_true')
parser.add_argument(
"--sqlite-database", required=True,
"--sqlite-database",
required=True,
help="The snapshot of the SQLite database file. This must not be"
" currently used by a running synapse server"
" currently used by a running synapse server",
)
parser.add_argument(
"--postgres-config", type=argparse.FileType('r'), required=True,
help="The database config file for the PostgreSQL database"
"--postgres-config",
type=argparse.FileType('r'),
required=True,
help="The database config file for the PostgreSQL database",
)
parser.add_argument(
"--curses", action='store_true',
help="display a curses based progress UI"
"--curses", action='store_true', help="display a curses based progress UI"
)

parser.add_argument(
"--batch-size", type=int, default=1000,
"--batch-size",
type=int,
default=1000,
help="The number of rows to select from the SQLite table each"
" iteration [default=1000]",
" iteration [default=1000]",
)

args = parser.parse_args()

logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}

if args.curses:


+ 1
- 1
synapse/app/homeserver.py Wyświetl plik

@@ -568,7 +568,7 @@ def run(hs):
clock.call_later(5 * 60, start_phone_stats_home)

if hs.config.daemonize and hs.config.print_pidfile:
print (hs.config.pid_file)
print(hs.config.pid_file)

_base.start_reactor(
"synapse-homeserver",


+ 1
- 1
synapse/config/__main__.py Wyświetl plik

@@ -28,7 +28,7 @@ if __name__ == "__main__":
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)

print (getattr(config, key))
print(getattr(config, key))
sys.exit(0)
else:
sys.stderr.write("Unknown command %r\n" % (action,))


+ 59
- 62
synapse/config/_base.py Wyświetl plik

@@ -106,10 +106,7 @@ class Config(object):
@classmethod
def check_file(cls, file_path, config_name):
if file_path is None:
raise ConfigError(
"Missing config for %s."
% (config_name,)
)
raise ConfigError("Missing config for %s." % (config_name,))
try:
os.stat(file_path)
except OSError as e:
@@ -128,9 +125,7 @@ class Config(object):
if e.errno != errno.EEXIST:
raise
if not os.path.isdir(dir_path):
raise ConfigError(
"%s is not a directory" % (dir_path,)
)
raise ConfigError("%s is not a directory" % (dir_path,))
return dir_path

@classmethod
@@ -156,21 +151,20 @@ class Config(object):
return results

def generate_config(
self,
config_dir_path,
server_name,
is_generating_file,
report_stats=None,
self, config_dir_path, server_name, is_generating_file, report_stats=None
):
default_config = "# vim:ft=yaml\n"

default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config",
config_dir_path=config_dir_path,
server_name=server_name,
is_generating_file=is_generating_file,
report_stats=report_stats,
))
default_config += "\n\n".join(
dedent(conf)
for conf in self.invoke_all(
"default_config",
config_dir_path=config_dir_path,
server_name=server_name,
is_generating_file=is_generating_file,
report_stats=report_stats,
)
)

config = yaml.load(default_config)

@@ -178,23 +172,22 @@ class Config(object):

@classmethod
def load_config(cls, description, argv):
config_parser = argparse.ArgumentParser(
description=description,
)
config_parser = argparse.ArgumentParser(description=description)
config_parser.add_argument(
"-c", "--config-path",
"-c",
"--config-path",
action="append",
metavar="CONFIG_FILE",
help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files."
" may specify directories containing *.yaml files.",
)

config_parser.add_argument(
"--keys-directory",
metavar="DIRECTORY",
help="Where files such as certs and signing keys are stored when"
" their location is given explicitly in the config."
" Defaults to the directory containing the last config file",
" their location is given explicitly in the config."
" Defaults to the directory containing the last config file",
)

config_args = config_parser.parse_args(argv)
@@ -203,9 +196,7 @@ class Config(object):

obj = cls()
obj.read_config_files(
config_files,
keys_directory=config_args.keys_directory,
generate_keys=False,
config_files, keys_directory=config_args.keys_directory, generate_keys=False
)
return obj

@@ -213,38 +204,38 @@ class Config(object):
def load_or_generate_config(cls, description, argv):
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument(
"-c", "--config-path",
"-c",
"--config-path",
action="append",
metavar="CONFIG_FILE",
help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files."
" may specify directories containing *.yaml files.",
)
config_parser.add_argument(
"--generate-config",
action="store_true",
help="Generate a config file for the server name"
help="Generate a config file for the server name",
)
config_parser.add_argument(
"--report-stats",
action="store",
help="Whether the generated config reports anonymized usage statistics",
choices=["yes", "no"]
choices=["yes", "no"],
)
config_parser.add_argument(
"--generate-keys",
action="store_true",
help="Generate any missing key files then exit"
help="Generate any missing key files then exit",
)
config_parser.add_argument(
"--keys-directory",
metavar="DIRECTORY",
help="Used with 'generate-*' options to specify where files such as"
" certs and signing keys should be stored in, unless explicitly"
" specified in the config."
" certs and signing keys should be stored in, unless explicitly"
" specified in the config.",
)
config_parser.add_argument(
"-H", "--server-name",
help="The server name to generate a config file for"
"-H", "--server-name", help="The server name to generate a config file for"
)
config_args, remaining_args = config_parser.parse_known_args(argv)

@@ -257,8 +248,8 @@ class Config(object):
if config_args.generate_config:
if config_args.report_stats is None:
config_parser.error(
"Please specify either --report-stats=yes or --report-stats=no\n\n" +
MISSING_REPORT_STATS_SPIEL
"Please specify either --report-stats=yes or --report-stats=no\n\n"
+ MISSING_REPORT_STATS_SPIEL
)
if not config_files:
config_parser.error(
@@ -287,26 +278,32 @@ class Config(object):
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
is_generating_file=True
is_generating_file=True,
)
obj.invoke_all("generate_files", config)
config_file.write(config_str)
print((
"A config file has been generated in %r for server name"
" %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it"
" to your needs."
) % (config_path, server_name))
print(
(
"A config file has been generated in %r for server name"
" %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it"
" to your needs."
)
% (config_path, server_name)
)
print(
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
return
else:
print((
"Config file %r already exists. Generating any missing key"
" files."
) % (config_path,))
print(
(
"Config file %r already exists. Generating any missing key"
" files."
)
% (config_path,)
)
generate_keys = True

parser = argparse.ArgumentParser(
@@ -338,8 +335,7 @@ class Config(object):

return obj

def read_config_files(self, config_files, keys_directory=None,
generate_keys=False):
def read_config_files(self, config_files, keys_directory=None, generate_keys=False):
if not keys_directory:
keys_directory = os.path.dirname(config_files[-1])

@@ -364,8 +360,9 @@ class Config(object):

if "report_stats" not in config:
raise ConfigError(
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
MISSING_REPORT_STATS_SPIEL
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS
+ "\n"
+ MISSING_REPORT_STATS_SPIEL
)

if generate_keys:
@@ -399,16 +396,16 @@ def find_config_files(search_paths):
for entry in os.listdir(config_path):
entry_path = os.path.join(config_path, entry)
if not os.path.isfile(entry_path):
print (
"Found subdirectory in config directory: %r. IGNORING."
) % (entry_path, )
err = "Found subdirectory in config directory: %r. IGNORING."
print(err % (entry_path,))
continue

if not entry.endswith(".yaml"):
print (
"Found file in config directory that does not"
" end in '.yaml': %r. IGNORING."
) % (entry_path, )
err = (
"Found file in config directory that does not end in "
"'.yaml': %r. IGNORING."
)
print(err % (entry_path,))
continue

files.append(entry_path)


+ 2
- 2
synapse/storage/_base.py Wyświetl plik

@@ -18,7 +18,7 @@ import threading
import time

from six import PY2, iteritems, iterkeys, itervalues
from six.moves import intern, range
from six.moves import builtins, intern, range

from canonicaljson import json
from prometheus_client import Histogram
@@ -1233,7 +1233,7 @@ def db_to_json(db_content):

# psycopg2 on Python 2 returns buffer objects, which we need to cast to
# bytes to decode
if PY2 and isinstance(db_content, buffer):
if PY2 and isinstance(db_content, builtins.buffer):
db_content = bytes(db_content)

# Decode it to a Unicode string before feeding it to json.loads, so we


+ 1
- 1
synapse/storage/keys.py Wyświetl plik

@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
db_binary_type = buffer
db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview



+ 1
- 1
synapse/storage/pusher.py Wyświetl plik

@@ -29,7 +29,7 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__)

if six.PY2:
db_binary_type = buffer
db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview



+ 1
- 1
synapse/storage/signatures.py Wyświetl plik

@@ -27,7 +27,7 @@ from ._base import SQLBaseStore
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
db_binary_type = buffer
db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview



+ 1
- 1
synapse/storage/transactions.py Wyświetl plik

@@ -30,7 +30,7 @@ from ._base import SQLBaseStore, db_to_json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
db_binary_type = buffer
db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview



+ 3
- 1
synapse/util/caches/stream_change_cache.py Wyświetl plik

@@ -15,6 +15,8 @@

import logging

from six import integer_types

from sortedcontainers import SortedDict

from synapse.util import caches
@@ -47,7 +49,7 @@ class StreamChangeCache(object):
def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos
"""
assert type(stream_pos) is int or type(stream_pos) is long
assert type(stream_pos) in integer_types

if stream_pos < self._earliest_known_stream_pos:
self.metrics.inc_misses()


+ 36
- 44
synctl Wyświetl plik

@@ -76,8 +76,7 @@ def start(configfile):

try:
subprocess.check_call(args)
write("started synapse.app.homeserver(%r)" %
(configfile,), colour=GREEN)
write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
except subprocess.CalledProcessError as e:
write(
"error starting (exit code: %d); see above for logs" % e.returncode,
@@ -86,21 +85,15 @@ def start(configfile):


def start_worker(app, configfile, worker_configfile):
args = [
sys.executable, "-B",
"-m", app,
"-c", configfile,
"-c", worker_configfile
]
args = [sys.executable, "-B", "-m", app, "-c", configfile, "-c", worker_configfile]

try:
subprocess.check_call(args)
write("started %s(%r)" % (app, worker_configfile), colour=GREEN)
except subprocess.CalledProcessError as e:
write(
"error starting %s(%r) (exit code: %d); see above for logs" % (
app, worker_configfile, e.returncode,
),
"error starting %s(%r) (exit code: %d); see above for logs"
% (app, worker_configfile, e.returncode),
colour=RED,
)

@@ -120,9 +113,9 @@ def stop(pidfile, app):
abort("Cannot stop %s: Unknown error" % (app,))


Worker = collections.namedtuple("Worker", [
"app", "configfile", "pidfile", "cache_factor", "cache_factors",
])
Worker = collections.namedtuple(
"Worker", ["app", "configfile", "pidfile", "cache_factor", "cache_factors"]
)


def main():
@@ -141,24 +134,20 @@ def main():
help="the homeserver config file, defaults to homeserver.yaml",
)
parser.add_argument(
"-w", "--worker",
metavar="WORKERCONFIG",
help="start or stop a single worker",
"-w", "--worker", metavar="WORKERCONFIG", help="start or stop a single worker"
)
parser.add_argument(
"-a", "--all-processes",
"-a",
"--all-processes",
metavar="WORKERCONFIGDIR",
help="start or stop all the workers in the given directory"
" and the main synapse process",
" and the main synapse process",
)

options = parser.parse_args()

if options.worker and options.all_processes:
write(
'Cannot use "--worker" with "--all-processes"',
stream=sys.stderr
)
write('Cannot use "--worker" with "--all-processes"', stream=sys.stderr)
sys.exit(1)

configfile = options.configfile
@@ -167,9 +156,7 @@ def main():
write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), options.configfile
),
" --server-name=<server name>'\n" % (" ".join(SYNAPSE), options.configfile),
stream=sys.stderr,
)
sys.exit(1)
@@ -194,8 +181,7 @@ def main():
worker_configfile = options.worker
if not os.path.exists(worker_configfile):
write(
"No worker config found at %r" % (worker_configfile,),
stream=sys.stderr,
"No worker config found at %r" % (worker_configfile,), stream=sys.stderr
)
sys.exit(1)
worker_configfiles.append(worker_configfile)
@@ -211,9 +197,9 @@ def main():
stream=sys.stderr,
)
sys.exit(1)
worker_configfiles.extend(sorted(glob.glob(
os.path.join(worker_configdir, "*.yaml")
)))
worker_configfiles.extend(
sorted(glob.glob(os.path.join(worker_configdir, "*.yaml")))
)

workers = []
for worker_configfile in worker_configfiles:
@@ -223,14 +209,12 @@ def main():
if worker_app == "synapse.app.homeserver":
# We need to special case all of this to pick up options that may
# be set in the main config file or in this worker config file.
worker_pidfile = (
worker_config.get("pid_file")
or pidfile
worker_pidfile = worker_config.get("pid_file") or pidfile
worker_cache_factor = (
worker_config.get("synctl_cache_factor") or cache_factor
)
worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
worker_cache_factors = (
worker_config.get("synctl_cache_factors")
or cache_factors
worker_config.get("synctl_cache_factors") or cache_factors
)
daemonize = worker_config.get("daemonize") or config.get("daemonize")
assert daemonize, "Main process must have daemonize set to true"
@@ -239,19 +223,27 @@ def main():
for key in worker_config:
if key == "worker_app": # But we allow worker_app
continue
assert not key.startswith("worker_"), \
"Main process cannot use worker_* config"
assert not key.startswith(
"worker_"
), "Main process cannot use worker_* config"
else:
worker_pidfile = worker_config["worker_pid_file"]
worker_daemonize = worker_config["worker_daemonize"]
assert worker_daemonize, "In config %r: expected '%s' to be True" % (
worker_configfile, "worker_daemonize")
worker_configfile,
"worker_daemonize",
)
worker_cache_factor = worker_config.get("synctl_cache_factor")
worker_cache_factors = worker_config.get("synctl_cache_factors", {})
workers.append(Worker(
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
worker_cache_factors,
))
workers.append(
Worker(
worker_app,
worker_configfile,
worker_pidfile,
worker_cache_factor,
worker_cache_factors,
)
)

action = options.action



+ 2
- 2
tox.ini Wyświetl plik

@@ -108,10 +108,10 @@ commands =

[testenv:pep8]
skip_install = True
basepython = python2.7
basepython = python3.6
deps =
flake8
commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}"
commands = /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}"

[testenv:check_isort]
skip_install = True


Ładowanie…
Anuluj
Zapisz