remote: remote config prototype

This commit is contained in:
Shiz 2022-04-04 20:24:28 +02:00
parent aa4c2eb60b
commit 7ee35a990e
6 changed files with 584 additions and 80 deletions

View File

@ -14,6 +14,7 @@ from .core import (
do_discover_interface, do_config_interface, config_interface,
do_sync_interface, sync_interface, sync_all_interfaces,
)
from .remote import WeegeeRemoteServer, REMOTE_COMMAND_SETS
from .extra import WeegeeServer, WeegeeClient
logger = getLogger(__name__)

View File

@ -12,9 +12,11 @@ logging.basicConfig(level=logging.DEBUG)
from .dazy import Item, export_items, import_items
from .wireguard import WireguardHostType, WireguardConfigFormat
from . import (
REMOTE_COMMAND_SETS,
WeegeeContext, WeegeeConfig,
WeegeeHook, WeegeeHost, WeegeePublicInterface, WeegeeInterface, WeegeePeer, WeegeeConnection,
WeegeeServer, WeegeeClient,
WeegeeRemoteServer,
find_interface_other_peers, config_interface,
sync_all_interfaces,
setup,
@ -94,12 +96,12 @@ def main():
transfer = wg_peer.get_total_transfer()
rx = siify(transfer[0]) if transfer else 'nothing'
tx = siify(transfer[1]) if transfer else 'nothing'
handshake = wg_peer.last_handshake()
last_handshake = timestampify(handshake) if handshake else 'never'
endpoint = wg_peer.last_endpoint()
last_endpoint = f'{endpoint[0]}:{endpoint[1]}' if endpoint else ''
handshake = wg_peer.latest_handshake()
latest_handshake = timestampify(handshake) if handshake else 'never'
endpoint = wg_peer.latest_endpoint()
latest_endpoint = f'{endpoint[0]}:{endpoint[1]}' if endpoint else ''
print(f' {peer.name} ({", ".join(str(x) for x in peer.interface.addresses)}, {wg_peer.name})')
print(f' last handshake {last_handshake}{f" from {last_endpoint}" if last_endpoint else ""}, {tx} sent, {rx} received')
print(f' latest handshake {latest_handshake}{f" from {latest_endpoint}" if latest_endpoint else ""}, {tx} sent, {rx} received')
for wg_peer in wg_peers.values():
print(f' ({wg_peer.name}): found but not expected')
success = False
@ -133,6 +135,41 @@ def main():
import_cmd.add_argument('file', type=argparse.FileType('r', encoding='utf-8'), nargs='?', default=sys.stdin)
import_cmd.set_defaults(func=do_import, parser=import_cmd)
def do_remote(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> None:
if not args.host:
config = ctx.get_config()
args.host = config.default_remote_host or WeegeeHost.LOCAL_HOST_NAME
commands = set()
for c in args.command:
if c.startswith('@'):
c = c[1:]
if c not in REMOTE_COMMAND_SETS:
parser.error(f'unknown command set: @{c}')
commands.update(REMOTE_COMMAND_SETS[c])
else:
commands.add(c)
host = WeegeeHost.load(ctx, args.host)
interfaces = set(args.interface)
addresses = set(args.address)
routes = set(args.route)
server = WeegeeRemoteServer(host.conn.conn,
allowed_commands=commands or None,
allowed_interfaces=interfaces or None,
allowed_addresses=addresses or None,
allowed_routes=routes or None,
)
server.serve(sys.stdin, sys.stdout)
remote = commands.add_parser('remote', help='remote server')
remote.add_argument('-c', '--command', type=str, action='append', default=[], help='allowed command(s) or @-prefixed command set(s)')
remote.add_argument('-i', '--interface', type=str, action='append', default=[], help='allowed interface(s)')
remote.add_argument('-a', '--address', type=ipaddress.ip_interface, action='append', default=[], help='allowed interface address(es)')
remote.add_argument('-r', '--route', type=ipaddress.ip_network, action='append', default=[], help='allowed interface route(s)')
remote.add_argument('host', help='host', nargs='?', default=None)
remote.set_defaults(func=do_remote, parser=remote)
# System commands

View File

@ -228,9 +228,8 @@ class WeegeeHost(WeegeeBase):
'configure',
]
@property
def config_format(self) -> O[WireguardConfigFormat]:
return self.conn.conn.CONFIG_FORMAT
def get_config_format(self, name: str) -> O[WireguardConfigFormat]:
return self.conn.conn.get_config_format(name)
def get_hooks(self, name: str, **kwargs) -> WeegeeHookRunner:
return WeegeeHookRunner([WeegeeHook(self.context, x) for x in getattr(self, f'{name}_hooks')], self.conn.conn, kwargs)
@ -264,18 +263,12 @@ class WeegeeHookedWireguardConnection(WireguardConnection):
host: WeegeeHost
inner: WireguardConnection
def __post_init__(self) -> None:
self.CONFIG_FORMAT = self.inner.CONFIG_FORMAT
def _run(self, *args, **kwargs) -> str:
return self.inner._run(*args, **kwargs)
@staticmethod
def _get_family(obj: U[IPAddress, IPInterface, IPNetwork]) -> str:
return f'ipv{obj.version}'
def has_interface(self, name: str) -> bool:
return self.inner.has_interface(name)
def __getattr__(self, name: str) -> Any:
return getattr(self.inner, name)
def create_interface(self, name: str) -> None:
with self.host.get_hooks('interface_add', i=name):
@ -285,18 +278,6 @@ class WeegeeHookedWireguardConnection(WireguardConnection):
with self.host.get_hooks('interface_del', i=name):
return self.inner.destroy_interface(name)
def set_mtu(self, name: str, mtu: int) -> None:
return self.inner.set_mtu(name, mtu)
def set_up(self, name: str) -> None:
return self.inner.set_up(name)
def set_down(self, name: str) -> None:
return self.inner.set_down(name)
def get_addresses(self, name: str) -> List[IPInterface]:
return self.inner.get_addresses(name)
def add_address(self, name: str, address: IPInterface) -> None:
family_hook = f'address_{self._get_family(address)}_add'
with self.host.get_hooks('address_add', i=name, a=str(address)), self.host.get_hooks(family_hook, i=name, a=str(address)):
@ -311,9 +292,6 @@ class WeegeeHookedWireguardConnection(WireguardConnection):
for address in self.get_addresses(name):
self.delete_address(name, address)
def get_routes(self, name: str) -> List[IPNetwork]:
return self.inner.get_routes(name)
def add_route(self, name: str, route: IPNetwork) -> None:
family_hook = f'route_{self._get_family(route)}_add'
with self.host.get_hooks('route_add', i=name, r=str(route)), self.host.get_hooks(family_hook, i=name, r=str(route)):
@ -601,7 +579,7 @@ def do_sync_interface(interface: WeegeeInterface, peers: Set[WeegeePeer], connec
for host in interface.hosts:
if auto and not host.autosync:
continue
config = do_config_interface(interface, host.config_format, peers, connections, host)
config = do_config_interface(interface, host.get_config_format(interface.interface_name), peers, connections, host)
if not host.sync_interface(interface.interface_name, interface.mtu, interface.addresses, routes, config):
raise ValueError(f'host {host.name} failed to sync interface!')
return other_peers

View File

@ -102,8 +102,9 @@ WEEGEE_CONFIG = WeegeeMeta(
name='wg/config',
version=1,
spec=[
f'default_server_hosts: [str] = []',
f'default_client_hosts: [str] = []',
'default_server_hosts: [str] = []',
'default_client_hosts: [str] = []',
'default_remote_host: ?str = ',
'log_level: str = "info"',
'meta_path: str = "."',
'data_path: ?str = ',

269
weegee/remote.py Normal file
View File

@ -0,0 +1,269 @@
from dataclasses import dataclass, field
import json
import enum
import ipaddress
import inspect
from typing import Optional as O, Tuple, Dict, Set, Any
from .wireguard import WireguardConnection, IPInterface, IPNetwork
class WeegeeRemoteErrorCode(enum.Enum):
InvalidCommand = -1
InvalidParameter = -2
NotFound = -3
NotAllowed = -4
Conflict = -5
InternalError = -10
class WeegeeRemoteError(Exception):
code: WeegeeRemoteErrorCode
def __init__(self, code: WeegeeRemoteErrorCode, message: str) -> None:
super().__init__(message)
self.code = code
REMOTE_PROTOCOL_VERSION = 0
REMOTE_COMMAND_SETS = {}
REMOTE_COMMAND_SETS['status'] = {'key.convert', 'interface.query', 'interface.info', 'peer.query', 'peer.info'}
REMOTE_COMMAND_SETS['sync'] = REMOTE_COMMAND_SETS['status'] | {'key.create', 'interface.set', 'interface.configure'}
REMOTE_COMMAND_SETS['manage'] = REMOTE_COMMAND_SETS['sync'] | {
'interface.create', 'interface.destroy',
'address.query', 'address.create', 'address.destroy',
'route.query', 'route.create', 'route.destroy',
}
@dataclass
class WeegeeRemoteServer:
conn: WireguardConnection
allowed_commands: O[Set[str]] = None
allowed_interfaces: O[Set[str]] = None
allowed_addresses: O[Set[IPInterface]] = None
allowed_routes: O[Set[IPNetwork]] = None
handlers: Dict[str, Any] = field(init=False)
handler_sigs: Dict[str, inspect.Signature] = field(init=False)
def __post_init__(self) -> None:
self.handlers = {
'interface.query': self.handle_interface_query,
'interface.create': self.handle_interface_create,
'interface.destroy': self.handle_interface_destroy,
'interface.set': self.handle_interface_set,
'interface.info': self.handle_interface_info,
'interface.configure': self.handle_interface_configure,
'address.query': self.handle_address_query,
'address.create': self.handle_address_create,
'address.destroy': self.handle_address_destroy,
'route.query': self.handle_route_query,
'route.create': self.handle_route_create,
'route.destroy': self.handle_route_destroy,
'peer.query': self.handle_peer_query,
'peer.info': self.handle_peer_info,
'key.create': self.handle_key_create,
'key.convert': self.handle_key_convert,
}
self.handler_sigs = {name: inspect.signature(fn) for name, fn in self.handlers.items()}
def _check_interface(self, interface: str) -> None:
if self.allowed_interfaces is not None and interface not in self.allowed_interfaces:
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control interface: {interface}')
def _check_existing_interface(self, interface: str) -> None:
self._check_interface(interface)
if not self.conn.has_interface(interface):
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'interface does not exist: {interface}')
def _check_address(self, address: IPInterface) -> None:
if self.allowed_addresse is not None and not any(address in a for a in self.allowed_addresses):
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control address: {address}')
def _check_route(self, route: IPNetwork) -> None:
if self.allowed_routes is not None and not any(route in r for r in self.allowed_routes):
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control route: {route}')
def handle_interface_query(self, interface: O[str] = None) -> Dict[str, Any]:
interfaces = self.conn.get_interfaces()
return {'interfaces': [i for i in (interfaces or []) if interface in (i, None)]}
def handle_interface_create(self, interface: str) -> Dict[str, Any]:
self._check_interface(interface)
if self.conn.has_interface(interface):
raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'interface exists: {interface}')
self.conn.create_interface(interface)
return {}
def handle_interface_destroy(self, interface: str) -> Dict[str, Any]:
self._check_existing_interface(interface)
self.conn.destroy_interface(interface)
return {}
def handle_interface_set(self, interface: str, up: O[bool] = None, mtu: O[int] = None) -> Dict[str, Any]:
self._check_existing_interface(interface)
if up is not None:
if up:
self.conn.set_up(interface)
else:
self.conn.set_down(interface)
if mtu is not None:
self.conn.set_mtu(interface, mtu)
return {}
def handle_interface_info(self, interface: str) -> Dict[str, Any]:
self._check_existing_interface(interface)
return {'config': self.conn.get_config(interface), 'config_format': self.conn.get_config_format(interface).value}
def handle_interface_configure(self, interface: str, config: O[str] = None, sync: O[bool] = True) -> Dict[str, Any]:
self._check_existing_interface(interface)
if config is not None:
self.conn.set_config(name, config, sync)
return {}
def handle_address_query(self, interface: str, address: O[str] = None) -> Dict[str, Any]:
self._check_existing_interface(interface)
if address:
address = ipaddress.ip_interface(address)
return {'addresses': [str(x) for x in self.conn.get_addresses(interface) if address in (None, x)]}
def handle_address_create(self, interface: str, address: str) -> Dict[str, Any]:
self._check_existing_interface(interface)
address = ipaddress.ip_interface(address)
self._check_address(address)
if address in self.conn.get_addresses(interface):
raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'address exists on interface: {address}')
self.conn.add_address(interface, address)
return {}
def handle_address_destroy(self, interface: str, address: str) -> Dict[str, Any]:
self._check_existing_interface(interface)
address = ipaddress.ip_interface(address)
self._check_address(address)
if address not in self.conn.get_addresses(interface):
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'address does not exist on interface: {address}')
self.conn.delete_address(interface, address)
return {}
def handle_route_query(self, interface: str, route: O[str] = None) -> Dict[str, Any]:
self._check_existing_interface(interface)
if route:
route = ipaddress.ip_network(route)
return {'routes': [str(x) for x in self.conn.get_routes(interface) if route in (None, x)]}
def handle_route_create(self, interface: str, route: str) -> Dict[str, Any]:
self._check_existing_interface(interface)
route = ipaddress.ip_network(route)
self._check_route(route)
if route in self.conn.get_routes(interface):
raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'route exists on interface: {route}')
self.conn.add_route(interface, route)
return {}
def handle_route_destroy(self, interface: str, route: str) -> Dict[str, Any]:
self._check_existing_interface(interface)
route = ipaddress.ip_network(route)
self._check_route(route)
if route not in self.conn.get_addresses(interface):
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'route does not exist on interface: {route}')
self.conn.delete_route(interface, route)
return {}
def handle_peer_query(self, interface: str, peer: O[str] = None) -> Dict[str, Any]:
self._check_existing_interface(interface)
return {'peers': [x for x in self.conn.get_peers(interface) if peer in (None, x)]}
def handle_peer_info(self, interface: str, peer: str) -> Dict[str, Any]:
self._check_existing_interface(interface)
info = self.conn.get_peer_info(interface, peer)
if not info:
return None
return {
'latest_handshake': int(info.latest_handshake.timestamp()),
'endpoints': [f'{ip}:{port}' for (ip, port) in info.endpoints],
'allowed_ips': [str(x) for x in info.allowed_ips],
'sent_bytes': info.sent_bytes,
'recv_bytes': info.recv_bytes,
}
def handle_key_create(self, type: str) -> Dict[str, Any]:
if type == 'preshared':
return {'key': self.conn.gen_preshared_key()}
if type == 'private':
return {'key': self.conn.gen_private_key()}
raise WeegeeRemoteError(WeegeeRemoteErrorCode.InvalidParameter, f'unknown key type: {type}')
def handle_key_convert(self, type: str, value: str) -> Dict[str, Any]:
if type == 'public':
return {'key': self.conn.get_public_key(value)}
raise WeegeeRemoteError(WeegeeRemoteErrorCode.InvalidParameter, f'unknown key type: {type}')
def _parse_req(self, data: str) -> O[Tuple[str, Dict[str, Any]]]:
d = json.loads(data)
if d['type'] != 'request':
return None
return d['command'], d['parameters']
def _make_banner(self) -> str:
return json.dumps({
'type': 'banner',
'version': f'weegee/0.1',
'protocol': REMOTE_PROTOCOL_VERSION,
})
def _make_resp(self, command: str, error: bool, /, **kwargs: Any) -> str:
d = {
'type': 'response',
'command': command,
}
if error:
d['ok'] = False
d['error'] = kwargs
else:
d['ok'] = True
d['values'] = kwargs
return json.dumps(d)
def serve(self, infile, outfile) -> None:
outfile.write(self._make_banner() + '\n')
while True:
line = infile.readline()
if not line:
break
req = self._parse_req(line)
if not req:
break
command, params = req
if command not in self.handlers:
data = {'code': WeegeeRemoteErrorCode.InvalidCommand.value, 'msg': f'invalid command: {command}'}
error = True
elif self.allowed_commands is not None and command not in self.allowed_commands:
data = {'code': WeegeeRemoteErrorCode.NotAllowed.value, 'msg': f'not allowed to request command: {command}'}
error = True
else:
signature = self.handler_sigs[command]
try:
signature.bind(**params)
except TypeError as e:
data = {'code': WeegeeRemoteErrorCode.InvalidParameter.value, 'msg': str(e)}
error = True
else:
try:
data = self.handlers[command](**params)
error = False
except WeegeeRemoteError as e:
data = {'code': e.code.value, 'msg': str(e)}
error = True
except NotImplementedError as e:
data = {'code': WeegeeRemoteErrorCode.InvalidCommand.value, 'msg': f'unimplemented command: {command}'}
error = True
except (ValueError, ipaddress.AddressValueError, ipaddress.NetmaskValueError):
data = {'code': WeegeeRemoteErrorCode.InvalidParameter.value, 'msg': f'invalid parameter: {e}'}
error = True
except Exception as e:
data = {'code': WeegeeRemoteErrorCode.InternalError.value, 'msg': str(e)}
error = True
outfile.write(self._make_resp(command, error, **data) + '\n')

View File

@ -4,6 +4,7 @@ from datetime import datetime
import shlex
import ipaddress
import enum
import json
from typing import Optional as O, Union as U, Type, Tuple, List
import subprocess
from logging import getLogger
@ -16,6 +17,14 @@ IPNetwork = U[ipaddress.IPv4Network, ipaddress.IPv6Network]
IPInterface = U[ipaddress.IPv4Interface, ipaddress.IPv6Interface]
@dataclass
class WireguardPeerInfo:
latest_handshake: O[datetime]
endpoints: List[Tuple[IPAddress, int]]
allowed_ips: List[IPNetwork]
sent_bytes: int
recv_bytes: int
class WireguardConfigFormat(enum.Enum):
WG = 'wg'
WGQuick = 'wg-quick'
@ -23,13 +32,8 @@ class WireguardConfigFormat(enum.Enum):
class WireguardConnection:
CONFIG_FORMAT: O[WireguardConfigFormat] = None
def _run(self, *args, shell=False, stdin=None, **kwargs) -> str:
raise NotImplementedError
def _run_wg(self, *args, **kwargs) -> str:
return self._run('wg', *args, **kwargs)
def get_interfaces(self) -> O[List[str]]:
return None
def has_interface(self, name: str) -> bool:
raise NotImplementedError
@ -40,6 +44,9 @@ class WireguardConnection:
def destroy_interface(self, name: str) -> None:
raise NotImplementedError
def get_config_format(self, name: str) -> WireguardConfigFormat:
return None
def set_mtu(self, name: str, mtu: int) -> None:
raise NotImplementedError
@ -50,7 +57,7 @@ class WireguardConnection:
raise NotImplementedError
def get_addresses(self, name: str) -> List[IPInterface]:
raise NotADirectoryError
raise NotImplementedError
def add_address(self, name: str, address: IPInterface) -> None:
raise NotImplementedError
@ -75,6 +82,29 @@ class WireguardConnection:
for route in self.get_routes(name):
self.delete_route(name, route)
def gen_preshared_key(self) -> str:
raise NotImplementedError
def gen_private_key(self) -> str:
raise NotImplementedError
def get_public_key(self, private_key: str) -> str:
raise NotImplementedError
def get_peers(self, name: str) -> List[str]:
raise NotImplementedError
def get_peer_info(self, name: str, peer: str) -> O[WireguardPeerInfo]:
raise NotImplementedError
def get_config(self, name: str) -> str:
raise NotImplementedError
def set_config(self, name: str, config: str, sync: bool = True) -> None:
raise NotImplementedError
def clear_config(self, name: str) -> None:
return self.set_config(name, '', sync=False)
@dataclass
class WireguardConnectionBase(WireguardConnection):
@ -82,7 +112,7 @@ class WireguardConnectionBase(WireguardConnection):
user: O[str]
elevate_user: O[str]
def _run(self, *args, shell=False, stdin=None, **kwargs) -> str:
def _open(self, *args, shell=False, **kwargs) -> str:
cmd = []
logger.debug(f'#{"[" + self.host + "]" if self.host else ""} {args[0] if shell else shlex.join(args)}')
if self.host:
@ -97,20 +127,27 @@ class WireguardConnectionBase(WireguardConnection):
else:
cmd += list(args)
p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='ascii', **kwargs)
def _run(self, *args, stdin=None, **kwargs) -> str:
p = self._open(*args, **kwargs)
stdout, stderr = p.communicate(stdin)
if p.returncode:
raise subprocess.CalledProcessError(p.returncode, cmd, output=stdout, stderr=stderr)
return stdout.strip()
class WireguardLinuxConnection(WireguardConnectionBase):
CONFIG_FORMAT = WireguardConfigFormat.WG
def _run_wg(self, *args, **kwargs) -> str:
return self._run('wg', *args, **kwargs)
class WireguardLinuxConnection(WireguardConnectionBase):
def _run_ip(self, *args) -> str:
return self._run('ip', *args)
def _run_ip46(self, *args) -> str:
return (self._run_ip('-4', *args) + '\n' + self._run_ip('-6', *args)).strip()
def get_interfaces(self) -> List[str]:
return self._run_wg('show', 'interfaces').split()
def has_interface(self, name: str) -> bool:
try:
self._run_ip('link', 'show', name)
@ -118,6 +155,9 @@ class WireguardLinuxConnection(WireguardConnectionBase):
except subprocess.CalledProcessError:
return False
def get_config_format(self, name: str) -> WireguardConfigFormat:
return WireguardConfigFormat.WG
def create_interface(self, name: str) -> None:
self._run_ip('link', 'add', name, 'type', 'wireguard')
@ -163,42 +203,218 @@ class WireguardLinuxConnection(WireguardConnectionBase):
def delete_route(self, name: str, route: IPNetwork) -> None:
self._run_ip('route', 'del', str(route), 'dev', name)
def get_peers(self, name: str) -> List[str]:
return self._run_wg('show', name, 'peers').split()
def get_peer_info(self, name: str, peer: str) -> O[WireguardPeerInfo]:
for line in self.conn._run_wg('show', name, 'dump').splitlines():
line_peer, *parts = line.split()
if line_peer != peer:
continue
psk, endpoints_str, allowed_ips_str, latest_handshake, rx, tx, keepalive, *_ = parts
if endpoints_str == '(none)':
endpoints = []
else:
endpoints = [x.split(':', maxsplit=1) for x in endpoints_str.split(',')]
if allowed_ips_str == '(none)':
allowed_ips = []
else:
allowed_ips = allowed_ips_str.split(',')
latest_handshake_ts = int(latest_handshake)
return WireguardPeerInfo(
datetime.fromtimestamp(latest_handshake_ts) if latest_handshake_ts > 0 else None,
[(ipaddress.ip_address(ip.strip('[]')), port) for (ip, port) in endpoints],
[ipaddress.ip_network(x) for x in allowed_ips],
int(tx),
int(rx),
)
return None
def get_config(self, name: str) -> str:
return self._run_wg('getconf', name)
def set_config(self, name: str, config: str, sync: bool = True) -> None:
if sync:
self._run_wg('syncconf', name, '/dev/stdin', stdin=config)
else:
self._run_wg('setconf', name, '/dev/stdin', stdin=config)
def gen_preshared_key(self) -> str:
return self._run_wg('genpsk')
def gen_private_key(self) -> str:
return self._run_wg('genkey')
def get_public_key(self, private_key: str) -> str:
return self._run_wg('pubkey', stdin=private_key)
class WireguardRemoteConnection(WireguardConnectionBase):
process: subprocess.Popen
PROTOCOL_VERSION = 0
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.process = None
def _make_req(self, name: str, /, **kwargs) -> str:
return json.dumps({
'type': 'request',
'command': name,
'parameters': kwargs,
})
def _parse_banner(self, data: str) -> O[str]:
d = json.loads(data)
if d['type'] != 'banner':
return None
return d['version'], d['protocol']
def _parse_resp(self, name: str, data: str) -> O[Dict[str, Any]]:
try:
d = json.loads(data)
except json.JSONDecodeError as e:
raise subprocess.CalledProcessError(255, '', stderr=repr(e))
if d['type'] != 'response' or d['command'] != command:
return None
if not d['ok']:
raise subprocess.CalledProcessError(d['error']['code'], '', stderr=d['error']['msg'])
return d['values']
def _run(self, name: str, /, **kwargs) -> Dict[str, Any]:
first = True
while i < 10:
if not first and self.process:
self.process.stdin.close()
self.process.terminate()
self.process = None
first = False
self.process = self._open('weegee', 'remote', encoding='utf-8', timeout=5)
line = p.stdout.readline()
if not line:
continue
banner, protocol = self._parse_banner(line)
if banner is None:
continue
if protocol != self.PROTOCOL_VERSION:
raise subprocess.CalledProcessError(255, '', stderr=f'protocol version mismatch: expected {self.PROTOCOL_VERSION}, got {protocol}')
logger.debug('connected to remote %r', banner)
p.stdin.write(self._make_req(name, **kwargs) + '\n')
line = p.stdout.readline()
if not line:
continue
resp = self._parse_resp(name, line)
if resp is None:
continue
break
return resp
def get_interfaces(self) -> O[List[str]]:
resp = self._run('interface.query')
return resp['interfaces']
def has_interface(self, name: str) -> bool:
resp = self._run('interface.query', interface=name)
return bool(resp['interfaces'])
def get_config_format(self, name: str) -> WireguardConfigFormat:
resp = self._run('interface.info', interface=name)
return WireguardConfigFormat(resp['config_format'])
def create_interface(self, name: str) -> None:
self._run('interface.create', interface=name)
def destroy_interface(self, name: str) -> None:
self._run('interface.destroy', interface=name)
def set_mtu(self, name: str, mtu: int) -> None:
self._run('interface.set', interface=name, mtu=mtu)
def set_up(self, name: str) -> None:
self._run('interface.set', interface=name, up=True)
def set_down(self, name: str) -> None:
self._run('interface.set', interface=name, up=False)
def get_addresses(self, name: str) -> List[IPInterface]:
resp = self._run('address.query', interface=name)
return [ipaddress.ip_interface(x) for x in resp['addresses']]
def add_address(self, name: str, address: IPInterface) -> None:
self._run('address.create', interface=name, address=str(address))
def delete_address(self, name: str, address: IPInterface) -> None:
self._run('address.destroy', interface=name, address=str(address))
def get_routes(self, name: str) -> List[IPNetwork]:
resp = self._run('route.query', interface=name)
return [ipaddress.ip_network(x) for x in resp['routes']]
def add_route(self, name: str, route: IPNetwork) -> None:
self._run('route.create', interface=name, route=str(route))
def delete_route(self, name: str, route: IPNetwork) -> None:
self._run('route.destroy', interface=name, route=str(route))
def get_peers(self, name: str) -> List[str]:
resp = self._run('peer.query', interface=name)
return resp['peers']
def get_peer_info(self, name: str, peer: str) -> O[WireguardPeerInfo]:
resp = self._run('peer.info', interface=name, peer=peer)
if not resp:
return None
return WireguardPeerInfo(
datetime.fromtimestamp(resp['latest_handshake']) if resp['latest_handshake'] > 0 else None,
[(ipaddress.ip_address(x.split(':', 1)[0].strip('[]')), int(x.split(':', 1)[1])) for x in resp['endpoints']],
[ipaddress.ip_network(x) for x in resp['allowed_ips']],
resp['sent_bytes'],
resp['recv_bytes'],
)
def get_config(self, name: str) -> str:
resp = self._run('interface.info', interface=name)
return resp['config']
def set_config(self, name: str, config: str, sync: bool = True) -> None:
self._run('interface.configure', interface=name, config=config, sync=sync)
def gen_preshared_key(self) -> str:
resp = self._run('key.create', type='preshared')
return resp['key']
def gen_private_key(self) -> str:
resp = self._run('key.create', type='private')
return resp['key']
def get_public_key(self, private_key: str) -> str:
resp = self._run('key.convert', type='public', value=private_key)
return resp['key']
@dataclass
class WireguardPeer:
conn: WireguardConnection
interface: str
name: str
def _filter_list(self, cmd: str) -> O[List[str]]:
for line in self.conn._run_wg('show', self.interface, cmd).splitlines():
peer, *parts = line.split()
if peer == self.name:
return parts
return None
def latest_handshake(self) -> O[datetime]:
info = self.conn.get_peer_info(self.interface, self.name)
if not info:
return None
return info.latest_handshake
def last_handshake(self) -> O[datetime]:
parts = self._filter_list('latest-handshakes')
if not parts:
def latest_endpoint(self) -> O[Tuple[IPAddress, int]]:
info = self.conn.get_peer_info(self.interface, self.name)
if not info or not info.endpoints:
return None
ts = int(parts[0])
if not ts:
return None
return datetime.fromtimestamp(ts)
def last_endpoint(self) -> O[Tuple[U[ipaddress.IPv4Address, ipaddress.IPv6Address], int]]:
parts = self._filter_list('endpoints')
if not parts:
return None
if parts[0] == '(none)':
return None
addr, port = parts[0].rsplit(':', maxsplit=1)
return (ipaddress.ip_address(addr.strip('[]')), int(port))
return info.endpoints[-1]
def get_total_transfer(self) -> O[Tuple[int, int]]:
parts = self._filter_list('transfer')
if not parts:
info = self.conn.get_peer_info(self.interface, self.name)
if not info:
return None
return (int(parts[0]), int(parts[1]))
return (info.recv_bytes, info.sent_bytes)
@dataclass
class WireguardInterface:
@ -206,10 +422,10 @@ class WireguardInterface:
name: str
def list_peers(self) -> List[WireguardPeer]:
return [WireguardPeer(self.conn, self.name, x) for x in self.conn._run_wg('show', self.name, 'peers').split()]
return [WireguardPeer(self.conn, self.name, x) for x in self.conn.get_peers(self.name)]
def get_config(self) -> str:
return self.conn._run_wg('getconf', self.name)
return self.conn.get_config(self.name)
def sync(self, addresses: List[IPInterface], routes: List[IPNetwork], mtu: O[int] = None) -> None:
routes += [addr.network for addr in addresses]
@ -234,13 +450,13 @@ class WireguardInterface:
self.conn.add_route(self.name, wanted_routes[route])
def set_config(self, config: str) -> None:
self.conn._run_wg('setconf', self.name, '/dev/stdin', stdin=config)
return self.conn.set_config(self.name, config, sync=False)
def sync_config(self, config: str) -> None:
self.conn._run_wg('syncconf', self.name, '/dev/stdin', stdin=config)
return self.conn.set_config(self.name, config, sync=True)
def clear_config(self) -> None:
self.conn._run_wg('setconf', self.name, '/dev/stdin', stdin='')
return self.conn.clear_config(self.name)
def delete(self) -> None:
self.conn.delete_all_routes(self.name)
@ -253,7 +469,7 @@ class WireguardHost:
conn: WireguardConnection
def list_interfaces(self) -> List[str]:
return self.conn._run_wg('show', 'interfaces').split()
return self.conn.get_interfaces()
def get_interface(self, name: str) -> O[WireguardInterface]:
for interface in self.list_interfaces():
@ -267,20 +483,22 @@ class WireguardHost:
return self.get_interface(name)
def gen_preshared_key(self) -> str:
return self.conn._run_wg('genpsk')
return self.conn.gen_preshared_key()
def gen_private_key(self) -> str:
return self.conn._run_wg('genkey')
return self.conn.gen_private_key()
def get_public_key(self, privkey: str) -> str:
return self.conn._run_wg('pubkey', stdin=privkey)
return self.conn.get_public_key(privkey)
class WireguardHostType(enum.Enum):
Unsupported = None
Linux = 'linux'
Remote = 'remote'
def which_connection(type: WireguardHostType) -> Type[WireguardConnectionBase]:
return {
WireguardHostType.Linux: WireguardLinuxConnection,
WireguardHostType.Remote: WireguardRemoteConnection,
}.get(type, WireguardConnectionBase)