remote: remote config prototype
This commit is contained in:
parent
aa4c2eb60b
commit
7ee35a990e
|
@ -14,6 +14,7 @@ from .core import (
|
|||
do_discover_interface, do_config_interface, config_interface,
|
||||
do_sync_interface, sync_interface, sync_all_interfaces,
|
||||
)
|
||||
from .remote import WeegeeRemoteServer, REMOTE_COMMAND_SETS
|
||||
from .extra import WeegeeServer, WeegeeClient
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -12,9 +12,11 @@ logging.basicConfig(level=logging.DEBUG)
|
|||
from .dazy import Item, export_items, import_items
|
||||
from .wireguard import WireguardHostType, WireguardConfigFormat
|
||||
from . import (
|
||||
REMOTE_COMMAND_SETS,
|
||||
WeegeeContext, WeegeeConfig,
|
||||
WeegeeHook, WeegeeHost, WeegeePublicInterface, WeegeeInterface, WeegeePeer, WeegeeConnection,
|
||||
WeegeeServer, WeegeeClient,
|
||||
WeegeeRemoteServer,
|
||||
find_interface_other_peers, config_interface,
|
||||
sync_all_interfaces,
|
||||
setup,
|
||||
|
@ -94,12 +96,12 @@ def main():
|
|||
transfer = wg_peer.get_total_transfer()
|
||||
rx = siify(transfer[0]) if transfer else 'nothing'
|
||||
tx = siify(transfer[1]) if transfer else 'nothing'
|
||||
handshake = wg_peer.last_handshake()
|
||||
last_handshake = timestampify(handshake) if handshake else 'never'
|
||||
endpoint = wg_peer.last_endpoint()
|
||||
last_endpoint = f'{endpoint[0]}:{endpoint[1]}' if endpoint else ''
|
||||
handshake = wg_peer.latest_handshake()
|
||||
latest_handshake = timestampify(handshake) if handshake else 'never'
|
||||
endpoint = wg_peer.latest_endpoint()
|
||||
latest_endpoint = f'{endpoint[0]}:{endpoint[1]}' if endpoint else ''
|
||||
print(f' {peer.name} ({", ".join(str(x) for x in peer.interface.addresses)}, {wg_peer.name})')
|
||||
print(f' last handshake {last_handshake}{f" from {last_endpoint}" if last_endpoint else ""}, {tx} sent, {rx} received')
|
||||
print(f' latest handshake {latest_handshake}{f" from {latest_endpoint}" if latest_endpoint else ""}, {tx} sent, {rx} received')
|
||||
for wg_peer in wg_peers.values():
|
||||
print(f' ({wg_peer.name}): found but not expected')
|
||||
success = False
|
||||
|
@ -133,6 +135,41 @@ def main():
|
|||
import_cmd.add_argument('file', type=argparse.FileType('r', encoding='utf-8'), nargs='?', default=sys.stdin)
|
||||
import_cmd.set_defaults(func=do_import, parser=import_cmd)
|
||||
|
||||
def do_remote(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> None:
|
||||
if not args.host:
|
||||
config = ctx.get_config()
|
||||
args.host = config.default_remote_host or WeegeeHost.LOCAL_HOST_NAME
|
||||
|
||||
commands = set()
|
||||
for c in args.command:
|
||||
if c.startswith('@'):
|
||||
c = c[1:]
|
||||
if c not in REMOTE_COMMAND_SETS:
|
||||
parser.error(f'unknown command set: @{c}')
|
||||
commands.update(REMOTE_COMMAND_SETS[c])
|
||||
else:
|
||||
commands.add(c)
|
||||
|
||||
host = WeegeeHost.load(ctx, args.host)
|
||||
interfaces = set(args.interface)
|
||||
addresses = set(args.address)
|
||||
routes = set(args.route)
|
||||
server = WeegeeRemoteServer(host.conn.conn,
|
||||
allowed_commands=commands or None,
|
||||
allowed_interfaces=interfaces or None,
|
||||
allowed_addresses=addresses or None,
|
||||
allowed_routes=routes or None,
|
||||
)
|
||||
server.serve(sys.stdin, sys.stdout)
|
||||
|
||||
remote = commands.add_parser('remote', help='remote server')
|
||||
remote.add_argument('-c', '--command', type=str, action='append', default=[], help='allowed command(s) or @-prefixed command set(s)')
|
||||
remote.add_argument('-i', '--interface', type=str, action='append', default=[], help='allowed interface(s)')
|
||||
remote.add_argument('-a', '--address', type=ipaddress.ip_interface, action='append', default=[], help='allowed interface address(es)')
|
||||
remote.add_argument('-r', '--route', type=ipaddress.ip_network, action='append', default=[], help='allowed interface route(s)')
|
||||
remote.add_argument('host', help='host', nargs='?', default=None)
|
||||
remote.set_defaults(func=do_remote, parser=remote)
|
||||
|
||||
|
||||
# System commands
|
||||
|
||||
|
|
|
@ -228,9 +228,8 @@ class WeegeeHost(WeegeeBase):
|
|||
'configure',
|
||||
]
|
||||
|
||||
@property
|
||||
def config_format(self) -> O[WireguardConfigFormat]:
|
||||
return self.conn.conn.CONFIG_FORMAT
|
||||
def get_config_format(self, name: str) -> O[WireguardConfigFormat]:
|
||||
return self.conn.conn.get_config_format(name)
|
||||
|
||||
def get_hooks(self, name: str, **kwargs) -> WeegeeHookRunner:
|
||||
return WeegeeHookRunner([WeegeeHook(self.context, x) for x in getattr(self, f'{name}_hooks')], self.conn.conn, kwargs)
|
||||
|
@ -264,18 +263,12 @@ class WeegeeHookedWireguardConnection(WireguardConnection):
|
|||
host: WeegeeHost
|
||||
inner: WireguardConnection
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.CONFIG_FORMAT = self.inner.CONFIG_FORMAT
|
||||
|
||||
def _run(self, *args, **kwargs) -> str:
|
||||
return self.inner._run(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _get_family(obj: U[IPAddress, IPInterface, IPNetwork]) -> str:
|
||||
return f'ipv{obj.version}'
|
||||
|
||||
def has_interface(self, name: str) -> bool:
|
||||
return self.inner.has_interface(name)
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self.inner, name)
|
||||
|
||||
def create_interface(self, name: str) -> None:
|
||||
with self.host.get_hooks('interface_add', i=name):
|
||||
|
@ -285,18 +278,6 @@ class WeegeeHookedWireguardConnection(WireguardConnection):
|
|||
with self.host.get_hooks('interface_del', i=name):
|
||||
return self.inner.destroy_interface(name)
|
||||
|
||||
def set_mtu(self, name: str, mtu: int) -> None:
|
||||
return self.inner.set_mtu(name, mtu)
|
||||
|
||||
def set_up(self, name: str) -> None:
|
||||
return self.inner.set_up(name)
|
||||
|
||||
def set_down(self, name: str) -> None:
|
||||
return self.inner.set_down(name)
|
||||
|
||||
def get_addresses(self, name: str) -> List[IPInterface]:
|
||||
return self.inner.get_addresses(name)
|
||||
|
||||
def add_address(self, name: str, address: IPInterface) -> None:
|
||||
family_hook = f'address_{self._get_family(address)}_add'
|
||||
with self.host.get_hooks('address_add', i=name, a=str(address)), self.host.get_hooks(family_hook, i=name, a=str(address)):
|
||||
|
@ -311,9 +292,6 @@ class WeegeeHookedWireguardConnection(WireguardConnection):
|
|||
for address in self.get_addresses(name):
|
||||
self.delete_address(name, address)
|
||||
|
||||
def get_routes(self, name: str) -> List[IPNetwork]:
|
||||
return self.inner.get_routes(name)
|
||||
|
||||
def add_route(self, name: str, route: IPNetwork) -> None:
|
||||
family_hook = f'route_{self._get_family(route)}_add'
|
||||
with self.host.get_hooks('route_add', i=name, r=str(route)), self.host.get_hooks(family_hook, i=name, r=str(route)):
|
||||
|
@ -601,7 +579,7 @@ def do_sync_interface(interface: WeegeeInterface, peers: Set[WeegeePeer], connec
|
|||
for host in interface.hosts:
|
||||
if auto and not host.autosync:
|
||||
continue
|
||||
config = do_config_interface(interface, host.config_format, peers, connections, host)
|
||||
config = do_config_interface(interface, host.get_config_format(interface.interface_name), peers, connections, host)
|
||||
if not host.sync_interface(interface.interface_name, interface.mtu, interface.addresses, routes, config):
|
||||
raise ValueError(f'host {host.name} failed to sync interface!')
|
||||
return other_peers
|
||||
|
|
|
@ -102,8 +102,9 @@ WEEGEE_CONFIG = WeegeeMeta(
|
|||
name='wg/config',
|
||||
version=1,
|
||||
spec=[
|
||||
f'default_server_hosts: [str] = []',
|
||||
f'default_client_hosts: [str] = []',
|
||||
'default_server_hosts: [str] = []',
|
||||
'default_client_hosts: [str] = []',
|
||||
'default_remote_host: ?str = ',
|
||||
'log_level: str = "info"',
|
||||
'meta_path: str = "."',
|
||||
'data_path: ?str = ',
|
||||
|
|
|
@ -0,0 +1,269 @@
|
|||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import enum
|
||||
import ipaddress
|
||||
import inspect
|
||||
from typing import Optional as O, Tuple, Dict, Set, Any
|
||||
|
||||
from .wireguard import WireguardConnection, IPInterface, IPNetwork
|
||||
|
||||
|
||||
class WeegeeRemoteErrorCode(enum.Enum):
|
||||
InvalidCommand = -1
|
||||
InvalidParameter = -2
|
||||
NotFound = -3
|
||||
NotAllowed = -4
|
||||
Conflict = -5
|
||||
InternalError = -10
|
||||
|
||||
class WeegeeRemoteError(Exception):
|
||||
code: WeegeeRemoteErrorCode
|
||||
|
||||
def __init__(self, code: WeegeeRemoteErrorCode, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
|
||||
|
||||
REMOTE_PROTOCOL_VERSION = 0
|
||||
REMOTE_COMMAND_SETS = {}
|
||||
REMOTE_COMMAND_SETS['status'] = {'key.convert', 'interface.query', 'interface.info', 'peer.query', 'peer.info'}
|
||||
REMOTE_COMMAND_SETS['sync'] = REMOTE_COMMAND_SETS['status'] | {'key.create', 'interface.set', 'interface.configure'}
|
||||
REMOTE_COMMAND_SETS['manage'] = REMOTE_COMMAND_SETS['sync'] | {
|
||||
'interface.create', 'interface.destroy',
|
||||
'address.query', 'address.create', 'address.destroy',
|
||||
'route.query', 'route.create', 'route.destroy',
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeegeeRemoteServer:
|
||||
conn: WireguardConnection
|
||||
allowed_commands: O[Set[str]] = None
|
||||
allowed_interfaces: O[Set[str]] = None
|
||||
allowed_addresses: O[Set[IPInterface]] = None
|
||||
allowed_routes: O[Set[IPNetwork]] = None
|
||||
handlers: Dict[str, Any] = field(init=False)
|
||||
handler_sigs: Dict[str, inspect.Signature] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.handlers = {
|
||||
'interface.query': self.handle_interface_query,
|
||||
'interface.create': self.handle_interface_create,
|
||||
'interface.destroy': self.handle_interface_destroy,
|
||||
'interface.set': self.handle_interface_set,
|
||||
'interface.info': self.handle_interface_info,
|
||||
'interface.configure': self.handle_interface_configure,
|
||||
'address.query': self.handle_address_query,
|
||||
'address.create': self.handle_address_create,
|
||||
'address.destroy': self.handle_address_destroy,
|
||||
'route.query': self.handle_route_query,
|
||||
'route.create': self.handle_route_create,
|
||||
'route.destroy': self.handle_route_destroy,
|
||||
'peer.query': self.handle_peer_query,
|
||||
'peer.info': self.handle_peer_info,
|
||||
'key.create': self.handle_key_create,
|
||||
'key.convert': self.handle_key_convert,
|
||||
}
|
||||
self.handler_sigs = {name: inspect.signature(fn) for name, fn in self.handlers.items()}
|
||||
|
||||
|
||||
def _check_interface(self, interface: str) -> None:
|
||||
if self.allowed_interfaces is not None and interface not in self.allowed_interfaces:
|
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control interface: {interface}')
|
||||
|
||||
def _check_existing_interface(self, interface: str) -> None:
|
||||
self._check_interface(interface)
|
||||
if not self.conn.has_interface(interface):
|
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'interface does not exist: {interface}')
|
||||
|
||||
def _check_address(self, address: IPInterface) -> None:
|
||||
if self.allowed_addresse is not None and not any(address in a for a in self.allowed_addresses):
|
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control address: {address}')
|
||||
|
||||
def _check_route(self, route: IPNetwork) -> None:
|
||||
if self.allowed_routes is not None and not any(route in r for r in self.allowed_routes):
|
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control route: {route}')
|
||||
|
||||
|
||||
def handle_interface_query(self, interface: O[str] = None) -> Dict[str, Any]:
|
||||
interfaces = self.conn.get_interfaces()
|
||||
return {'interfaces': [i for i in (interfaces or []) if interface in (i, None)]}
|
||||
|
||||
def handle_interface_create(self, interface: str) -> Dict[str, Any]:
|
||||
self._check_interface(interface)
|
||||
if self.conn.has_interface(interface):
|
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'interface exists: {interface}')
|
||||
self.conn.create_interface(interface)
|
||||
return {}
|
||||
|
||||
def handle_interface_destroy(self, interface: str) -> Dict[str, Any]:
|
||||
self._check_existing_interface(interface)
|
||||
self.conn.destroy_interface(interface)
|
||||
return {}
|
||||
|
||||
def handle_interface_set(self, interface: str, up: O[bool] = None, mtu: O[int] = None) -> Dict[str, Any]:
|
||||
self._check_existing_interface(interface)
|
||||
if up is not None:
|
||||
if up:
|
||||
self.conn.set_up(interface)
|
||||
else:
|
||||
self.conn.set_down(interface)
|
||||
if mtu is not None:
|
||||
self.conn.set_mtu(interface, mtu)
|
||||
return {}
|
||||
|
||||
def handle_interface_info(self, interface: str) -> Dict[str, Any]:
|
||||
self._check_existing_interface(interface)
|
||||
return {'config': self.conn.get_config(interface), 'config_format': self.conn.get_config_format(interface).value}
|
||||
|
||||
def handle_interface_configure(self, interface: str, config: O[str] = None, sync: O[bool] = True) -> Dict[str, Any]:
|
||||
self._check_existing_interface(interface)
|
||||
if config is not None:
|
||||
self.conn.set_config(name, config, sync)
|
||||
return {}
|
||||
|
||||
def handle_address_query(self, interface: str, address: O[str] = None) -> Dict[str, Any]:
|
||||
self._check_existing_interface(interface)
|
||||
if address:
|
||||
address = ipaddress.ip_interface(address)
|
||||
return {'addresses': [str(x) for x in self.conn.get_addresses(interface) if address in (None, x)]}
|
||||
|
||||
def handle_address_create(self, interface: str, address: str) -> Dict[str, Any]:
|
||||
self._check_existing_interface(interface)
|
||||
address = ipaddress.ip_interface(address)
|
||||
self._check_address(address)
|
||||
if address in self.conn.get_addresses(interface):
|
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'address exists on interface: {address}')
|
||||
self.conn.add_address(interface, address)
|
||||
return {}
|
||||
|
||||
def handle_address_destroy(self, interface: str, address: str) -> Dict[str, Any]:
|
||||
self._check_existing_interface(interface)
|
||||
address = ipaddress.ip_interface(address)
|
||||
self._check_address(address)
|
||||
if address not in self.conn.get_addresses(interface):
|
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'address does not exist on interface: {address}')
|
||||
self.conn.delete_address(interface, address)
|
||||
return {}
|
||||
|
||||
def handle_route_query(self, interface: str, route: O[str] = None) -> Dict[str, Any]:
|
||||
self._check_existing_interface(interface)
|
||||
if route:
|
||||
route = ipaddress.ip_network(route)
|
||||
return {'routes': [str(x) for x in self.conn.get_routes(interface) if route in (None, x)]}
|
||||
|
||||
def handle_route_create(self, interface: str, route: str) -> Dict[str, Any]:
|
||||
self._check_existing_interface(interface)
|
||||
route = ipaddress.ip_network(route)
|
||||
self._check_route(route)
|
||||
if route in self.conn.get_routes(interface):
|
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'route exists on interface: {route}')
|
||||
self.conn.add_route(interface, route)
|
||||
return {}
|
||||
|
||||
def handle_route_destroy(self, interface: str, route: str) -> Dict[str, Any]:
|
||||
self._check_existing_interface(interface)
|
||||
route = ipaddress.ip_network(route)
|
||||
self._check_route(route)
|
||||
if route not in self.conn.get_addresses(interface):
|
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'route does not exist on interface: {route}')
|
||||
self.conn.delete_route(interface, route)
|
||||
return {}
|
||||
|
||||
def handle_peer_query(self, interface: str, peer: O[str] = None) -> Dict[str, Any]:
|
||||
self._check_existing_interface(interface)
|
||||
return {'peers': [x for x in self.conn.get_peers(interface) if peer in (None, x)]}
|
||||
|
||||
def handle_peer_info(self, interface: str, peer: str) -> Dict[str, Any]:
|
||||
self._check_existing_interface(interface)
|
||||
info = self.conn.get_peer_info(interface, peer)
|
||||
if not info:
|
||||
return None
|
||||
return {
|
||||
'latest_handshake': int(info.latest_handshake.timestamp()),
|
||||
'endpoints': [f'{ip}:{port}' for (ip, port) in info.endpoints],
|
||||
'allowed_ips': [str(x) for x in info.allowed_ips],
|
||||
'sent_bytes': info.sent_bytes,
|
||||
'recv_bytes': info.recv_bytes,
|
||||
}
|
||||
|
||||
def handle_key_create(self, type: str) -> Dict[str, Any]:
|
||||
if type == 'preshared':
|
||||
return {'key': self.conn.gen_preshared_key()}
|
||||
if type == 'private':
|
||||
return {'key': self.conn.gen_private_key()}
|
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.InvalidParameter, f'unknown key type: {type}')
|
||||
|
||||
def handle_key_convert(self, type: str, value: str) -> Dict[str, Any]:
|
||||
if type == 'public':
|
||||
return {'key': self.conn.get_public_key(value)}
|
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.InvalidParameter, f'unknown key type: {type}')
|
||||
|
||||
|
||||
def _parse_req(self, data: str) -> O[Tuple[str, Dict[str, Any]]]:
|
||||
d = json.loads(data)
|
||||
if d['type'] != 'request':
|
||||
return None
|
||||
return d['command'], d['parameters']
|
||||
|
||||
def _make_banner(self) -> str:
|
||||
return json.dumps({
|
||||
'type': 'banner',
|
||||
'version': f'weegee/0.1',
|
||||
'protocol': REMOTE_PROTOCOL_VERSION,
|
||||
})
|
||||
|
||||
def _make_resp(self, command: str, error: bool, /, **kwargs: Any) -> str:
|
||||
d = {
|
||||
'type': 'response',
|
||||
'command': command,
|
||||
}
|
||||
if error:
|
||||
d['ok'] = False
|
||||
d['error'] = kwargs
|
||||
else:
|
||||
d['ok'] = True
|
||||
d['values'] = kwargs
|
||||
return json.dumps(d)
|
||||
|
||||
def serve(self, infile, outfile) -> None:
|
||||
outfile.write(self._make_banner() + '\n')
|
||||
while True:
|
||||
line = infile.readline()
|
||||
if not line:
|
||||
break
|
||||
req = self._parse_req(line)
|
||||
if not req:
|
||||
break
|
||||
command, params = req
|
||||
|
||||
if command not in self.handlers:
|
||||
data = {'code': WeegeeRemoteErrorCode.InvalidCommand.value, 'msg': f'invalid command: {command}'}
|
||||
error = True
|
||||
elif self.allowed_commands is not None and command not in self.allowed_commands:
|
||||
data = {'code': WeegeeRemoteErrorCode.NotAllowed.value, 'msg': f'not allowed to request command: {command}'}
|
||||
error = True
|
||||
else:
|
||||
signature = self.handler_sigs[command]
|
||||
try:
|
||||
signature.bind(**params)
|
||||
except TypeError as e:
|
||||
data = {'code': WeegeeRemoteErrorCode.InvalidParameter.value, 'msg': str(e)}
|
||||
error = True
|
||||
else:
|
||||
try:
|
||||
data = self.handlers[command](**params)
|
||||
error = False
|
||||
except WeegeeRemoteError as e:
|
||||
data = {'code': e.code.value, 'msg': str(e)}
|
||||
error = True
|
||||
except NotImplementedError as e:
|
||||
data = {'code': WeegeeRemoteErrorCode.InvalidCommand.value, 'msg': f'unimplemented command: {command}'}
|
||||
error = True
|
||||
except (ValueError, ipaddress.AddressValueError, ipaddress.NetmaskValueError):
|
||||
data = {'code': WeegeeRemoteErrorCode.InvalidParameter.value, 'msg': f'invalid parameter: {e}'}
|
||||
error = True
|
||||
except Exception as e:
|
||||
data = {'code': WeegeeRemoteErrorCode.InternalError.value, 'msg': str(e)}
|
||||
error = True
|
||||
outfile.write(self._make_resp(command, error, **data) + '\n')
|
|
@ -4,6 +4,7 @@ from datetime import datetime
|
|||
import shlex
|
||||
import ipaddress
|
||||
import enum
|
||||
import json
|
||||
from typing import Optional as O, Union as U, Type, Tuple, List
|
||||
import subprocess
|
||||
from logging import getLogger
|
||||
|
@ -16,6 +17,14 @@ IPNetwork = U[ipaddress.IPv4Network, ipaddress.IPv6Network]
|
|||
IPInterface = U[ipaddress.IPv4Interface, ipaddress.IPv6Interface]
|
||||
|
||||
|
||||
@dataclass
|
||||
class WireguardPeerInfo:
|
||||
latest_handshake: O[datetime]
|
||||
endpoints: List[Tuple[IPAddress, int]]
|
||||
allowed_ips: List[IPNetwork]
|
||||
sent_bytes: int
|
||||
recv_bytes: int
|
||||
|
||||
class WireguardConfigFormat(enum.Enum):
|
||||
WG = 'wg'
|
||||
WGQuick = 'wg-quick'
|
||||
|
@ -23,13 +32,8 @@ class WireguardConfigFormat(enum.Enum):
|
|||
|
||||
|
||||
class WireguardConnection:
|
||||
CONFIG_FORMAT: O[WireguardConfigFormat] = None
|
||||
|
||||
def _run(self, *args, shell=False, stdin=None, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _run_wg(self, *args, **kwargs) -> str:
|
||||
return self._run('wg', *args, **kwargs)
|
||||
def get_interfaces(self) -> O[List[str]]:
|
||||
return None
|
||||
|
||||
def has_interface(self, name: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
@ -40,6 +44,9 @@ class WireguardConnection:
|
|||
def destroy_interface(self, name: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_config_format(self, name: str) -> WireguardConfigFormat:
|
||||
return None
|
||||
|
||||
def set_mtu(self, name: str, mtu: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -50,7 +57,7 @@ class WireguardConnection:
|
|||
raise NotImplementedError
|
||||
|
||||
def get_addresses(self, name: str) -> List[IPInterface]:
|
||||
raise NotADirectoryError
|
||||
raise NotImplementedError
|
||||
|
||||
def add_address(self, name: str, address: IPInterface) -> None:
|
||||
raise NotImplementedError
|
||||
|
@ -75,6 +82,29 @@ class WireguardConnection:
|
|||
for route in self.get_routes(name):
|
||||
self.delete_route(name, route)
|
||||
|
||||
def gen_preshared_key(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def gen_private_key(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_public_key(self, private_key: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_peers(self, name: str) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_peer_info(self, name: str, peer: str) -> O[WireguardPeerInfo]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_config(self, name: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_config(self, name: str, config: str, sync: bool = True) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def clear_config(self, name: str) -> None:
|
||||
return self.set_config(name, '', sync=False)
|
||||
|
||||
@dataclass
|
||||
class WireguardConnectionBase(WireguardConnection):
|
||||
|
@ -82,7 +112,7 @@ class WireguardConnectionBase(WireguardConnection):
|
|||
user: O[str]
|
||||
elevate_user: O[str]
|
||||
|
||||
def _run(self, *args, shell=False, stdin=None, **kwargs) -> str:
|
||||
def _open(self, *args, shell=False, **kwargs) -> str:
|
||||
cmd = []
|
||||
logger.debug(f'#{"[" + self.host + "]" if self.host else ""} {args[0] if shell else shlex.join(args)}')
|
||||
if self.host:
|
||||
|
@ -97,20 +127,27 @@ class WireguardConnectionBase(WireguardConnection):
|
|||
else:
|
||||
cmd += list(args)
|
||||
p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='ascii', **kwargs)
|
||||
|
||||
def _run(self, *args, stdin=None, **kwargs) -> str:
|
||||
p = self._open(*args, **kwargs)
|
||||
stdout, stderr = p.communicate(stdin)
|
||||
if p.returncode:
|
||||
raise subprocess.CalledProcessError(p.returncode, cmd, output=stdout, stderr=stderr)
|
||||
return stdout.strip()
|
||||
|
||||
class WireguardLinuxConnection(WireguardConnectionBase):
|
||||
CONFIG_FORMAT = WireguardConfigFormat.WG
|
||||
def _run_wg(self, *args, **kwargs) -> str:
|
||||
return self._run('wg', *args, **kwargs)
|
||||
|
||||
class WireguardLinuxConnection(WireguardConnectionBase):
|
||||
def _run_ip(self, *args) -> str:
|
||||
return self._run('ip', *args)
|
||||
|
||||
def _run_ip46(self, *args) -> str:
|
||||
return (self._run_ip('-4', *args) + '\n' + self._run_ip('-6', *args)).strip()
|
||||
|
||||
def get_interfaces(self) -> List[str]:
|
||||
return self._run_wg('show', 'interfaces').split()
|
||||
|
||||
def has_interface(self, name: str) -> bool:
|
||||
try:
|
||||
self._run_ip('link', 'show', name)
|
||||
|
@ -118,6 +155,9 @@ class WireguardLinuxConnection(WireguardConnectionBase):
|
|||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
def get_config_format(self, name: str) -> WireguardConfigFormat:
|
||||
return WireguardConfigFormat.WG
|
||||
|
||||
def create_interface(self, name: str) -> None:
|
||||
self._run_ip('link', 'add', name, 'type', 'wireguard')
|
||||
|
||||
|
@ -163,42 +203,218 @@ class WireguardLinuxConnection(WireguardConnectionBase):
|
|||
def delete_route(self, name: str, route: IPNetwork) -> None:
|
||||
self._run_ip('route', 'del', str(route), 'dev', name)
|
||||
|
||||
def get_peers(self, name: str) -> List[str]:
|
||||
return self._run_wg('show', name, 'peers').split()
|
||||
|
||||
def get_peer_info(self, name: str, peer: str) -> O[WireguardPeerInfo]:
|
||||
for line in self.conn._run_wg('show', name, 'dump').splitlines():
|
||||
line_peer, *parts = line.split()
|
||||
if line_peer != peer:
|
||||
continue
|
||||
|
||||
psk, endpoints_str, allowed_ips_str, latest_handshake, rx, tx, keepalive, *_ = parts
|
||||
if endpoints_str == '(none)':
|
||||
endpoints = []
|
||||
else:
|
||||
endpoints = [x.split(':', maxsplit=1) for x in endpoints_str.split(',')]
|
||||
if allowed_ips_str == '(none)':
|
||||
allowed_ips = []
|
||||
else:
|
||||
allowed_ips = allowed_ips_str.split(',')
|
||||
latest_handshake_ts = int(latest_handshake)
|
||||
return WireguardPeerInfo(
|
||||
datetime.fromtimestamp(latest_handshake_ts) if latest_handshake_ts > 0 else None,
|
||||
[(ipaddress.ip_address(ip.strip('[]')), port) for (ip, port) in endpoints],
|
||||
[ipaddress.ip_network(x) for x in allowed_ips],
|
||||
int(tx),
|
||||
int(rx),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def get_config(self, name: str) -> str:
|
||||
return self._run_wg('getconf', name)
|
||||
|
||||
def set_config(self, name: str, config: str, sync: bool = True) -> None:
|
||||
if sync:
|
||||
self._run_wg('syncconf', name, '/dev/stdin', stdin=config)
|
||||
else:
|
||||
self._run_wg('setconf', name, '/dev/stdin', stdin=config)
|
||||
|
||||
def gen_preshared_key(self) -> str:
|
||||
return self._run_wg('genpsk')
|
||||
|
||||
def gen_private_key(self) -> str:
|
||||
return self._run_wg('genkey')
|
||||
|
||||
def get_public_key(self, private_key: str) -> str:
|
||||
return self._run_wg('pubkey', stdin=private_key)
|
||||
|
||||
class WireguardRemoteConnection(WireguardConnectionBase):
|
||||
process: subprocess.Popen
|
||||
PROTOCOL_VERSION = 0
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.process = None
|
||||
|
||||
def _make_req(self, name: str, /, **kwargs) -> str:
|
||||
return json.dumps({
|
||||
'type': 'request',
|
||||
'command': name,
|
||||
'parameters': kwargs,
|
||||
})
|
||||
|
||||
def _parse_banner(self, data: str) -> O[str]:
|
||||
d = json.loads(data)
|
||||
if d['type'] != 'banner':
|
||||
return None
|
||||
return d['version'], d['protocol']
|
||||
|
||||
def _parse_resp(self, name: str, data: str) -> O[Dict[str, Any]]:
|
||||
try:
|
||||
d = json.loads(data)
|
||||
except json.JSONDecodeError as e:
|
||||
raise subprocess.CalledProcessError(255, '', stderr=repr(e))
|
||||
if d['type'] != 'response' or d['command'] != command:
|
||||
return None
|
||||
if not d['ok']:
|
||||
raise subprocess.CalledProcessError(d['error']['code'], '', stderr=d['error']['msg'])
|
||||
return d['values']
|
||||
|
||||
def _run(self, name: str, /, **kwargs) -> Dict[str, Any]:
|
||||
first = True
|
||||
while i < 10:
|
||||
if not first and self.process:
|
||||
self.process.stdin.close()
|
||||
self.process.terminate()
|
||||
self.process = None
|
||||
first = False
|
||||
self.process = self._open('weegee', 'remote', encoding='utf-8', timeout=5)
|
||||
line = p.stdout.readline()
|
||||
if not line:
|
||||
continue
|
||||
banner, protocol = self._parse_banner(line)
|
||||
if banner is None:
|
||||
continue
|
||||
if protocol != self.PROTOCOL_VERSION:
|
||||
raise subprocess.CalledProcessError(255, '', stderr=f'protocol version mismatch: expected {self.PROTOCOL_VERSION}, got {protocol}')
|
||||
logger.debug('connected to remote %r', banner)
|
||||
p.stdin.write(self._make_req(name, **kwargs) + '\n')
|
||||
line = p.stdout.readline()
|
||||
if not line:
|
||||
continue
|
||||
resp = self._parse_resp(name, line)
|
||||
if resp is None:
|
||||
continue
|
||||
break
|
||||
return resp
|
||||
|
||||
def get_interfaces(self) -> O[List[str]]:
|
||||
resp = self._run('interface.query')
|
||||
return resp['interfaces']
|
||||
|
||||
def has_interface(self, name: str) -> bool:
|
||||
resp = self._run('interface.query', interface=name)
|
||||
return bool(resp['interfaces'])
|
||||
|
||||
def get_config_format(self, name: str) -> WireguardConfigFormat:
|
||||
resp = self._run('interface.info', interface=name)
|
||||
return WireguardConfigFormat(resp['config_format'])
|
||||
|
||||
def create_interface(self, name: str) -> None:
|
||||
self._run('interface.create', interface=name)
|
||||
|
||||
def destroy_interface(self, name: str) -> None:
|
||||
self._run('interface.destroy', interface=name)
|
||||
|
||||
def set_mtu(self, name: str, mtu: int) -> None:
|
||||
self._run('interface.set', interface=name, mtu=mtu)
|
||||
|
||||
def set_up(self, name: str) -> None:
|
||||
self._run('interface.set', interface=name, up=True)
|
||||
|
||||
def set_down(self, name: str) -> None:
|
||||
self._run('interface.set', interface=name, up=False)
|
||||
|
||||
def get_addresses(self, name: str) -> List[IPInterface]:
|
||||
resp = self._run('address.query', interface=name)
|
||||
return [ipaddress.ip_interface(x) for x in resp['addresses']]
|
||||
|
||||
def add_address(self, name: str, address: IPInterface) -> None:
|
||||
self._run('address.create', interface=name, address=str(address))
|
||||
|
||||
def delete_address(self, name: str, address: IPInterface) -> None:
|
||||
self._run('address.destroy', interface=name, address=str(address))
|
||||
|
||||
def get_routes(self, name: str) -> List[IPNetwork]:
|
||||
resp = self._run('route.query', interface=name)
|
||||
return [ipaddress.ip_network(x) for x in resp['routes']]
|
||||
|
||||
def add_route(self, name: str, route: IPNetwork) -> None:
|
||||
self._run('route.create', interface=name, route=str(route))
|
||||
|
||||
def delete_route(self, name: str, route: IPNetwork) -> None:
|
||||
self._run('route.destroy', interface=name, route=str(route))
|
||||
|
||||
def get_peers(self, name: str) -> List[str]:
|
||||
resp = self._run('peer.query', interface=name)
|
||||
return resp['peers']
|
||||
|
||||
def get_peer_info(self, name: str, peer: str) -> O[WireguardPeerInfo]:
|
||||
resp = self._run('peer.info', interface=name, peer=peer)
|
||||
if not resp:
|
||||
return None
|
||||
return WireguardPeerInfo(
|
||||
datetime.fromtimestamp(resp['latest_handshake']) if resp['latest_handshake'] > 0 else None,
|
||||
[(ipaddress.ip_address(x.split(':', 1)[0].strip('[]')), int(x.split(':', 1)[1])) for x in resp['endpoints']],
|
||||
[ipaddress.ip_network(x) for x in resp['allowed_ips']],
|
||||
resp['sent_bytes'],
|
||||
resp['recv_bytes'],
|
||||
)
|
||||
|
||||
def get_config(self, name: str) -> str:
|
||||
resp = self._run('interface.info', interface=name)
|
||||
return resp['config']
|
||||
|
||||
def set_config(self, name: str, config: str, sync: bool = True) -> None:
|
||||
self._run('interface.configure', interface=name, config=config, sync=sync)
|
||||
|
||||
def gen_preshared_key(self) -> str:
|
||||
resp = self._run('key.create', type='preshared')
|
||||
return resp['key']
|
||||
|
||||
def gen_private_key(self) -> str:
|
||||
resp = self._run('key.create', type='private')
|
||||
return resp['key']
|
||||
|
||||
def get_public_key(self, private_key: str) -> str:
|
||||
resp = self._run('key.convert', type='public', value=private_key)
|
||||
return resp['key']
|
||||
|
||||
@dataclass
|
||||
class WireguardPeer:
|
||||
conn: WireguardConnection
|
||||
interface: str
|
||||
name: str
|
||||
|
||||
def _filter_list(self, cmd: str) -> O[List[str]]:
|
||||
for line in self.conn._run_wg('show', self.interface, cmd).splitlines():
|
||||
peer, *parts = line.split()
|
||||
if peer == self.name:
|
||||
return parts
|
||||
def latest_handshake(self) -> O[datetime]:
|
||||
info = self.conn.get_peer_info(self.interface, self.name)
|
||||
if not info:
|
||||
return None
|
||||
return info.latest_handshake
|
||||
|
||||
def last_handshake(self) -> O[datetime]:
|
||||
parts = self._filter_list('latest-handshakes')
|
||||
if not parts:
|
||||
def latest_endpoint(self) -> O[Tuple[IPAddress, int]]:
|
||||
info = self.conn.get_peer_info(self.interface, self.name)
|
||||
if not info or not info.endpoints:
|
||||
return None
|
||||
ts = int(parts[0])
|
||||
if not ts:
|
||||
return None
|
||||
return datetime.fromtimestamp(ts)
|
||||
|
||||
def last_endpoint(self) -> O[Tuple[U[ipaddress.IPv4Address, ipaddress.IPv6Address], int]]:
|
||||
parts = self._filter_list('endpoints')
|
||||
if not parts:
|
||||
return None
|
||||
if parts[0] == '(none)':
|
||||
return None
|
||||
addr, port = parts[0].rsplit(':', maxsplit=1)
|
||||
return (ipaddress.ip_address(addr.strip('[]')), int(port))
|
||||
return info.endpoints[-1]
|
||||
|
||||
def get_total_transfer(self) -> O[Tuple[int, int]]:
|
||||
parts = self._filter_list('transfer')
|
||||
if not parts:
|
||||
info = self.conn.get_peer_info(self.interface, self.name)
|
||||
if not info:
|
||||
return None
|
||||
return (int(parts[0]), int(parts[1]))
|
||||
return (info.recv_bytes, info.sent_bytes)
|
||||
|
||||
@dataclass
|
||||
class WireguardInterface:
|
||||
|
@ -206,10 +422,10 @@ class WireguardInterface:
|
|||
name: str
|
||||
|
||||
def list_peers(self) -> List[WireguardPeer]:
|
||||
return [WireguardPeer(self.conn, self.name, x) for x in self.conn._run_wg('show', self.name, 'peers').split()]
|
||||
return [WireguardPeer(self.conn, self.name, x) for x in self.conn.get_peers(self.name)]
|
||||
|
||||
def get_config(self) -> str:
|
||||
return self.conn._run_wg('getconf', self.name)
|
||||
return self.conn.get_config(self.name)
|
||||
|
||||
def sync(self, addresses: List[IPInterface], routes: List[IPNetwork], mtu: O[int] = None) -> None:
|
||||
routes += [addr.network for addr in addresses]
|
||||
|
@ -234,13 +450,13 @@ class WireguardInterface:
|
|||
self.conn.add_route(self.name, wanted_routes[route])
|
||||
|
||||
def set_config(self, config: str) -> None:
|
||||
self.conn._run_wg('setconf', self.name, '/dev/stdin', stdin=config)
|
||||
return self.conn.set_config(self.name, config, sync=False)
|
||||
|
||||
def sync_config(self, config: str) -> None:
|
||||
self.conn._run_wg('syncconf', self.name, '/dev/stdin', stdin=config)
|
||||
return self.conn.set_config(self.name, config, sync=True)
|
||||
|
||||
def clear_config(self) -> None:
|
||||
self.conn._run_wg('setconf', self.name, '/dev/stdin', stdin='')
|
||||
return self.conn.clear_config(self.name)
|
||||
|
||||
def delete(self) -> None:
|
||||
self.conn.delete_all_routes(self.name)
|
||||
|
@ -253,7 +469,7 @@ class WireguardHost:
|
|||
conn: WireguardConnection
|
||||
|
||||
def list_interfaces(self) -> List[str]:
|
||||
return self.conn._run_wg('show', 'interfaces').split()
|
||||
return self.conn.get_interfaces()
|
||||
|
||||
def get_interface(self, name: str) -> O[WireguardInterface]:
|
||||
for interface in self.list_interfaces():
|
||||
|
@ -267,20 +483,22 @@ class WireguardHost:
|
|||
return self.get_interface(name)
|
||||
|
||||
def gen_preshared_key(self) -> str:
|
||||
return self.conn._run_wg('genpsk')
|
||||
return self.conn.gen_preshared_key()
|
||||
|
||||
def gen_private_key(self) -> str:
|
||||
return self.conn._run_wg('genkey')
|
||||
return self.conn.gen_private_key()
|
||||
|
||||
def get_public_key(self, privkey: str) -> str:
|
||||
return self.conn._run_wg('pubkey', stdin=privkey)
|
||||
return self.conn.get_public_key(privkey)
|
||||
|
||||
|
||||
class WireguardHostType(enum.Enum):
|
||||
Unsupported = None
|
||||
Linux = 'linux'
|
||||
Remote = 'remote'
|
||||
|
||||
def which_connection(type: WireguardHostType) -> Type[WireguardConnectionBase]:
|
||||
return {
|
||||
WireguardHostType.Linux: WireguardLinuxConnection,
|
||||
WireguardHostType.Remote: WireguardRemoteConnection,
|
||||
}.get(type, WireguardConnectionBase)
|
||||
|
|
Loading…
Reference in New Issue