505 lines
18 KiB
Python
505 lines
18 KiB
Python
from __future__ import annotations

import enum
import ipaddress
import json
import shlex
import subprocess
from dataclasses import dataclass
from datetime import datetime
from logging import getLogger
from typing import Any, Dict, List, Optional as O, Tuple, Type, Union as U
|
|
|
|
# Module-level logger, named after this module.
logger = getLogger(__name__)


# Aliases covering both the IPv4 and IPv6 flavour of each ipaddress type.
IPAddress = U[
    ipaddress.IPv4Address,
    ipaddress.IPv6Address,
]
IPNetwork = U[
    ipaddress.IPv4Network,
    ipaddress.IPv6Network,
]
IPInterface = U[
    ipaddress.IPv4Interface,
    ipaddress.IPv6Interface,
]
|
|
|
|
|
|
@dataclass
class WireguardPeerInfo:
    """Runtime status of one peer as reported by a connection backend."""

    # Time of the most recent handshake; None when no handshake is recorded.
    latest_handshake: O[datetime]
    # Known (address, port) endpoints of the peer.
    endpoints: List[Tuple[IPAddress, int]]
    # Networks the peer is allowed to use.
    allowed_ips: List[IPNetwork]
    # Transfer counters, in bytes.
    sent_bytes: int
    recv_bytes: int
|
|
|
|
class WireguardConfigFormat(enum.Enum):
    """Configuration dialects a connection backend may report."""

    WG = 'wg'
    WGQuick = 'wg-quick'
    OpenWrt = 'openwrt'
|
|
|
|
|
|
class WireguardConnection:
    """Interface implemented by every WireGuard connection backend.

    Unless stated otherwise, methods raise ``NotImplementedError`` here and
    are expected to be overridden by concrete backends.  The ``name``
    parameter always identifies the interface being operated on.
    """

    def get_interfaces(self) -> O[List[str]]:
        """Return the interface names; this base implementation returns ``None``."""
        return None

    def has_interface(self, name: str) -> bool:
        """Tell whether interface *name* exists."""
        raise NotImplementedError

    def create_interface(self, name: str) -> None:
        """Create interface *name*."""
        raise NotImplementedError

    def destroy_interface(self, name: str) -> None:
        """Remove interface *name*."""
        raise NotImplementedError

    def get_config_format(self, name: str) -> O[WireguardConfigFormat]:
        """Return the configuration format of *name*.

        This base implementation returns ``None`` (it does not raise), so the
        return type is optional.
        """
        return None

    def set_mtu(self, name: str, mtu: int) -> None:
        """Set the MTU of *name*."""
        raise NotImplementedError

    def set_up(self, name: str) -> None:
        """Bring *name* up."""
        raise NotImplementedError

    def set_down(self, name: str) -> None:
        """Bring *name* down."""
        raise NotImplementedError

    def get_addresses(self, name: str) -> List[IPInterface]:
        """Return the addresses assigned to *name*."""
        raise NotImplementedError

    def add_address(self, name: str, address: IPInterface) -> None:
        """Assign *address* to *name*."""
        raise NotImplementedError

    def delete_address(self, name: str, address: IPInterface) -> None:
        """Remove *address* from *name*."""
        raise NotImplementedError

    def delete_all_addresses(self, name: str) -> None:
        """Remove every address reported by :meth:`get_addresses` from *name*."""
        for address in self.get_addresses(name):
            self.delete_address(name, address)

    def get_routes(self, name: str) -> List[IPNetwork]:
        """Return the routes attached to *name*."""
        raise NotImplementedError

    def add_route(self, name: str, route: IPNetwork) -> None:
        """Add *route* via *name*."""
        raise NotImplementedError

    def delete_route(self, name: str, route: IPNetwork) -> None:
        """Remove *route* from *name*."""
        raise NotImplementedError

    def delete_all_routes(self, name: str) -> None:
        """Remove every route reported by :meth:`get_routes` from *name*."""
        for route in self.get_routes(name):
            self.delete_route(name, route)

    def gen_preshared_key(self) -> str:
        """Generate a new preshared key."""
        raise NotImplementedError

    def gen_private_key(self) -> str:
        """Generate a new private key."""
        raise NotImplementedError

    def get_public_key(self, private_key: str) -> str:
        """Derive the public key belonging to *private_key*."""
        raise NotImplementedError

    def get_peers(self, name: str) -> List[str]:
        """Return the peer identifiers configured on *name*."""
        raise NotImplementedError

    def get_peer_info(self, name: str, peer: str) -> O[WireguardPeerInfo]:
        """Return status information for *peer* on *name*, if available."""
        raise NotImplementedError

    def get_config(self, name: str) -> str:
        """Return the current configuration of *name*."""
        raise NotImplementedError

    def set_config(self, name: str, config: str, sync: bool = True) -> None:
        """Apply *config* to *name*; *sync* selects the backend's sync mode."""
        raise NotImplementedError

    def clear_config(self, name: str) -> None:
        """Apply an empty configuration to *name* with ``sync=False``."""
        return self.set_config(name, '', sync=False)
|
|
|
|
@dataclass
class WireguardConnectionBase(WireguardConnection):
    """Shared command-execution plumbing for process-based backends.

    Commands run locally, or through ``ssh`` when ``host`` is set, and are
    wrapped in ``sudo -i -u <elevate_user>`` when ``elevate_user`` is set.
    """

    # Remote host to reach over ssh; falsy means "run locally".
    host: O[str]
    # Optional ssh login name.
    user: O[str]
    # Optional user to switch to with sudo before running the command.
    elevate_user: O[str]

    def _open(self, *args, shell=False, **kwargs) -> subprocess.Popen:
        """Start a command and return the running ``subprocess.Popen`` object.

        :param args: the argument vector, or with ``shell=True`` a single
            shell script in ``args[0]`` (run as ``sh -c 'set -e; ...'``).
        :param shell: interpret ``args[0]`` as a shell script.
        :param kwargs: forwarded to ``subprocess.Popen``; ``encoding``
            defaults to ``'ascii'`` when not supplied.
        """
        host_tag = '[' + self.host + ']' if self.host else ''
        logger.debug('#%s %s', host_tag, args[0] if shell else shlex.join(args))

        cmd = []
        if self.host:
            cmd += ['ssh', f'{self.user}@{self.host}' if self.user else self.host]
        if self.elevate_user:
            cmd += ['sudo', '-i', '-u', self.elevate_user]

        argv = ['sh', '-c', 'set -e; ' + args[0]] if shell else list(args)
        if self.host:
            # ssh joins its arguments with spaces and hands the result to the
            # remote shell, so every word must be quoted to survive intact.
            argv = [shlex.quote(word) for word in argv]
        cmd += argv

        # Let callers override the encoding without a duplicate-keyword error.
        kwargs.setdefault('encoding', 'ascii')
        return subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            **kwargs,
        )

    def _run(self, *args, stdin=None, **kwargs) -> str:
        """Run a command to completion and return its stripped stdout.

        :param stdin: optional text fed to the command's standard input.
        :raises subprocess.CalledProcessError: on a non-zero exit status.
        """
        p = self._open(*args, **kwargs)
        stdout, stderr = p.communicate(stdin)
        if p.returncode:
            raise subprocess.CalledProcessError(p.returncode, p.args, output=stdout, stderr=stderr)
        return stdout.strip()

    def _run_wg(self, *args, **kwargs) -> str:
        """Run the ``wg`` tool with *args* via :meth:`_run`."""
        return self._run('wg', *args, **kwargs)
|
|
|
|
class WireguardLinuxConnection(WireguardConnectionBase):
    """Backend that drives the ``ip`` and ``wg`` command-line tools."""

    def _run_ip(self, *args) -> str:
        """Run ``ip`` with *args* and return its output."""
        return self._run('ip', *args)

    def _run_ip46(self, *args) -> str:
        """Run ``ip -4`` and ``ip -6`` with *args* and join both outputs."""
        return (self._run_ip('-4', *args) + '\n' + self._run_ip('-6', *args)).strip()

    @staticmethod
    def _parse_endpoint(text: str) -> Tuple[IPAddress, int]:
        """Split ``addr:port`` / ``[v6addr]:port`` into ``(address, port)``."""
        # Split on the LAST colon: IPv6 addresses contain colons themselves.
        address, port = text.rsplit(':', 1)
        return ipaddress.ip_address(address.strip('[]')), int(port)

    def get_interfaces(self) -> List[str]:
        """Return the interface names printed by ``wg show interfaces``."""
        return self._run_wg('show', 'interfaces').split()

    def has_interface(self, name: str) -> bool:
        """Return whether ``ip link show <name>`` succeeds."""
        try:
            self._run_ip('link', 'show', name)
        except subprocess.CalledProcessError:
            return False
        return True

    def get_config_format(self, name: str) -> WireguardConfigFormat:
        """Always ``WireguardConfigFormat.WG`` for this backend."""
        return WireguardConfigFormat.WG

    def create_interface(self, name: str) -> None:
        """Create *name* as a link of type ``wireguard``."""
        self._run_ip('link', 'add', name, 'type', 'wireguard')

    def destroy_interface(self, name: str) -> None:
        """Delete the link *name*."""
        self._run_ip('link', 'del', name)

    def set_mtu(self, name: str, mtu: int) -> None:
        """Set the MTU of *name*."""
        self._run_ip('link', 'set', name, 'mtu', str(mtu))

    def set_up(self, name: str) -> None:
        """Bring *name* up."""
        self._run_ip('link', 'set', name, 'up')

    def set_down(self, name: str) -> None:
        """Bring *name* down."""
        self._run_ip('link', 'set', name, 'down')

    def get_addresses(self, name: str) -> List[IPInterface]:
        """Return the ``inet``/``inet6`` addresses listed by ``ip addr show``."""
        addresses = []
        for line in self._run_ip46('addr', 'show', 'dev', name).splitlines():
            fields = line.split()
            # Only address lines are of interest; the second word is the address.
            if len(fields) >= 2 and fields[0] in ('inet', 'inet6'):
                addresses.append(ipaddress.ip_interface(fields[1]))
        return addresses

    def add_address(self, name: str, address: IPInterface) -> None:
        """Assign *address* to *name*."""
        self._run_ip('addr', 'add', str(address), 'dev', name)

    def delete_address(self, name: str, address: IPInterface) -> None:
        """Remove *address* from *name*."""
        self._run_ip('addr', 'del', str(address), 'dev', name)

    def get_routes(self, name: str) -> List[IPNetwork]:
        """Return the destinations listed by ``ip route list dev <name>``.

        IPv4 routes come first, then IPv6.  The families are queried
        separately so that the ``default`` keyword can be mapped to the
        matching all-zero network instead of failing to parse.
        """
        routes = []
        for family, default_net in (('-4', '0.0.0.0/0'), ('-6', '::/0')):
            for line in self._run_ip(family, 'route', 'list', 'dev', name).splitlines():
                fields = line.split()
                if not fields:
                    continue
                destination = default_net if fields[0] == 'default' else fields[0]
                routes.append(ipaddress.ip_network(destination))
        return routes

    def add_route(self, name: str, route: IPNetwork) -> None:
        """Add *route* via *name*."""
        self._run_ip('route', 'add', str(route), 'dev', name)

    def delete_route(self, name: str, route: IPNetwork) -> None:
        """Remove *route* from *name*."""
        self._run_ip('route', 'del', str(route), 'dev', name)

    def get_peers(self, name: str) -> List[str]:
        """Return the peers printed by ``wg show <name> peers``."""
        return self._run_wg('show', name, 'peers').split()

    def get_peer_info(self, name: str, peer: str) -> O[WireguardPeerInfo]:
        """Look up *peer* in ``wg show <name> dump``.

        :returns: the peer's status, or ``None`` when no dump line starts
            with *peer*.
        """
        for line in self._run_wg('show', name, 'dump').splitlines():
            fields = line.split()
            # A peer line carries the key plus seven more columns; shorter
            # lines (such as the interface's own line) cannot match.
            if len(fields) < 8 or fields[0] != peer:
                continue

            endpoints_str, allowed_ips_str, latest_handshake, rx, tx = fields[2:7]
            endpoints = [] if endpoints_str == '(none)' else [
                self._parse_endpoint(x) for x in endpoints_str.split(',')
            ]
            allowed_ips = [] if allowed_ips_str == '(none)' else [
                ipaddress.ip_network(x) for x in allowed_ips_str.split(',')
            ]
            latest_handshake_ts = int(latest_handshake)
            return WireguardPeerInfo(
                datetime.fromtimestamp(latest_handshake_ts) if latest_handshake_ts > 0 else None,
                endpoints,
                allowed_ips,
                int(tx),
                int(rx),
            )

        return None

    def get_config(self, name: str) -> str:
        """Return the output of ``wg getconf <name>``."""
        return self._run_wg('getconf', name)

    def set_config(self, name: str, config: str, sync: bool = True) -> None:
        """Feed *config* to ``wg syncconf`` (``sync=True``) or ``wg setconf``."""
        subcommand = 'syncconf' if sync else 'setconf'
        self._run_wg(subcommand, name, '/dev/stdin', stdin=config)

    def gen_preshared_key(self) -> str:
        """Return the output of ``wg genpsk``."""
        return self._run_wg('genpsk')

    def gen_private_key(self) -> str:
        """Return the output of ``wg genkey``."""
        return self._run_wg('genkey')

    def get_public_key(self, private_key: str) -> str:
        """Return the output of ``wg pubkey`` fed with *private_key* on stdin."""
        return self._run_wg('pubkey', stdin=private_key)
|
|
|
|
class WireguardRemoteConnection(WireguardConnectionBase):
    """Backend that talks line-delimited JSON to a ``weegee remote`` helper.

    Each request spawns the helper, checks its banner, sends one request
    line and reads one response line.
    NOTE(review): whether the helper can serve several requests per process
    is not visible here, so the process is not reused between calls.
    """

    # Helper process of the request currently in flight, if any.
    process: O[subprocess.Popen]
    PROTOCOL_VERSION = 0
    # Number of times a request is attempted before giving up.
    MAX_ATTEMPTS = 10

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.process = None

    def _make_req(self, name: str, /, **kwargs) -> str:
        """Serialise command *name* with its parameters as a JSON request."""
        return json.dumps({
            'type': 'request',
            'command': name,
            'parameters': kwargs,
        })

    def _parse_banner(self, data: str) -> O[Tuple[Any, Any]]:
        """Parse a banner line into ``(version, protocol)``.

        :returns: ``None`` when the message is not of type ``banner``.
        :raises subprocess.CalledProcessError: when *data* is not valid JSON
            (same convention as :meth:`_parse_resp`).
        """
        try:
            d = json.loads(data)
        except json.JSONDecodeError as e:
            raise subprocess.CalledProcessError(255, '', stderr=repr(e)) from e
        if d['type'] != 'banner':
            return None
        return d['version'], d['protocol']

    def _parse_resp(self, name: str, data: str) -> O[Dict[str, Any]]:
        """Parse the response line for command *name*.

        :returns: the response's ``values``, or ``None`` when the message is
            not a response to *name*.
        :raises subprocess.CalledProcessError: on invalid JSON, or when the
            response reports a failure (using its error code and message).
        """
        try:
            d = json.loads(data)
        except json.JSONDecodeError as e:
            raise subprocess.CalledProcessError(255, '', stderr=repr(e)) from e
        if d['type'] != 'response' or d['command'] != name:
            return None
        if not d['ok']:
            raise subprocess.CalledProcessError(d['error']['code'], '', stderr=d['error']['msg'])
        return d['values']

    def _close(self) -> None:
        """Shut down the current helper process, if there is one."""
        proc, self.process = self.process, None
        if proc is None:
            return
        for stream in (proc.stdin, proc.stdout, proc.stderr):
            if stream is None:
                continue
            try:
                stream.close()
            except OSError:
                # The helper may already be gone; closing is best effort.
                pass
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()

    def _run(self, name: str, /, **kwargs) -> Dict[str, Any]:
        """Send command *name* to the helper and return the response values.

        The exchange is retried up to ``MAX_ATTEMPTS`` times when the helper
        closes its output, sends something other than a banner, or answers
        with a message that is not the response to *name*.

        :raises subprocess.CalledProcessError: on a protocol version
            mismatch, an error response, malformed JSON, or when every
            attempt failed.
        """
        request = self._make_req(name, **kwargs) + '\n'
        for _ in range(self.MAX_ATTEMPTS):
            self._close()
            # Popen has no timeout parameter, so none is passed; reads block
            # until the helper answers or closes its output.
            proc = self.process = self._open('weegee', 'remote', encoding='utf-8')
            try:
                line = proc.stdout.readline()
                if not line:
                    continue
                banner_info = self._parse_banner(line)
                if banner_info is None:
                    continue
                banner, protocol = banner_info
                if protocol != self.PROTOCOL_VERSION:
                    raise subprocess.CalledProcessError(255, '', stderr=f'protocol version mismatch: expected {self.PROTOCOL_VERSION}, got {protocol}')
                logger.debug('connected to remote %r', banner)

                proc.stdin.write(request)
                # The pipe is buffered: without a flush the helper never
                # sees the request and the read below would wait forever.
                proc.stdin.flush()
                line = proc.stdout.readline()
                if not line:
                    continue
                resp = self._parse_resp(name, line)
                if resp is None:
                    continue
                return resp
            except BrokenPipeError:
                # The helper exited before the request was written; retry.
                continue
            finally:
                self._close()
        raise subprocess.CalledProcessError(255, '', stderr=f'no valid response to {name!r} after {self.MAX_ATTEMPTS} attempts')

    @staticmethod
    def _parse_endpoint(text: str) -> Tuple[IPAddress, int]:
        """Split ``addr:port`` / ``[v6addr]:port`` into ``(address, port)``."""
        # Split on the LAST colon: IPv6 addresses contain colons themselves.
        address, port = text.rsplit(':', 1)
        return ipaddress.ip_address(address.strip('[]')), int(port)

    def get_interfaces(self) -> O[List[str]]:
        """Return the ``interfaces`` of an ``interface.query`` request."""
        resp = self._run('interface.query')
        return resp['interfaces']

    def has_interface(self, name: str) -> bool:
        """Return whether ``interface.query`` for *name* lists anything."""
        resp = self._run('interface.query', interface=name)
        return bool(resp['interfaces'])

    def get_config_format(self, name: str) -> WireguardConfigFormat:
        """Return the ``config_format`` reported by ``interface.info``."""
        resp = self._run('interface.info', interface=name)
        return WireguardConfigFormat(resp['config_format'])

    def create_interface(self, name: str) -> None:
        """Send ``interface.create`` for *name*."""
        self._run('interface.create', interface=name)

    def destroy_interface(self, name: str) -> None:
        """Send ``interface.destroy`` for *name*."""
        self._run('interface.destroy', interface=name)

    def set_mtu(self, name: str, mtu: int) -> None:
        """Send ``interface.set`` with the new *mtu*."""
        self._run('interface.set', interface=name, mtu=mtu)

    def set_up(self, name: str) -> None:
        """Send ``interface.set`` with ``up=True``."""
        self._run('interface.set', interface=name, up=True)

    def set_down(self, name: str) -> None:
        """Send ``interface.set`` with ``up=False``."""
        self._run('interface.set', interface=name, up=False)

    def get_addresses(self, name: str) -> List[IPInterface]:
        """Return the ``addresses`` of an ``address.query`` request."""
        resp = self._run('address.query', interface=name)
        return [ipaddress.ip_interface(x) for x in resp['addresses']]

    def add_address(self, name: str, address: IPInterface) -> None:
        """Send ``address.create`` for *address*."""
        self._run('address.create', interface=name, address=str(address))

    def delete_address(self, name: str, address: IPInterface) -> None:
        """Send ``address.destroy`` for *address*."""
        self._run('address.destroy', interface=name, address=str(address))

    def get_routes(self, name: str) -> List[IPNetwork]:
        """Return the ``routes`` of a ``route.query`` request."""
        resp = self._run('route.query', interface=name)
        return [ipaddress.ip_network(x) for x in resp['routes']]

    def add_route(self, name: str, route: IPNetwork) -> None:
        """Send ``route.create`` for *route*."""
        self._run('route.create', interface=name, route=str(route))

    def delete_route(self, name: str, route: IPNetwork) -> None:
        """Send ``route.destroy`` for *route*."""
        self._run('route.destroy', interface=name, route=str(route))

    def get_peers(self, name: str) -> List[str]:
        """Return the ``peers`` of a ``peer.query`` request."""
        resp = self._run('peer.query', interface=name)
        return resp['peers']

    def get_peer_info(self, name: str, peer: str) -> O[WireguardPeerInfo]:
        """Build a ``WireguardPeerInfo`` from a ``peer.info`` response.

        :returns: ``None`` when the response carries no values.
        """
        resp = self._run('peer.info', interface=name, peer=peer)
        if not resp:
            return None
        return WireguardPeerInfo(
            datetime.fromtimestamp(resp['latest_handshake']) if resp['latest_handshake'] > 0 else None,
            [self._parse_endpoint(x) for x in resp['endpoints']],
            [ipaddress.ip_network(x) for x in resp['allowed_ips']],
            resp['sent_bytes'],
            resp['recv_bytes'],
        )

    def get_config(self, name: str) -> str:
        """Return the ``config`` reported by ``interface.info``."""
        resp = self._run('interface.info', interface=name)
        return resp['config']

    def set_config(self, name: str, config: str, sync: bool = True) -> None:
        """Send ``interface.configure`` with *config* and the *sync* flag."""
        self._run('interface.configure', interface=name, config=config, sync=sync)

    def gen_preshared_key(self) -> str:
        """Return the ``key`` of a ``key.create`` request of type ``preshared``."""
        resp = self._run('key.create', type='preshared')
        return resp['key']

    def gen_private_key(self) -> str:
        """Return the ``key`` of a ``key.create`` request of type ``private``."""
        resp = self._run('key.create', type='private')
        return resp['key']

    def get_public_key(self, private_key: str) -> str:
        """Return the ``key`` of a ``key.convert`` request for *private_key*."""
        resp = self._run('key.convert', type='public', value=private_key)
        return resp['key']
|
|
|
|
@dataclass
class WireguardPeer:
    """Handle for one peer of an interface, resolved through a connection."""

    conn: WireguardConnection
    interface: str
    name: str

    def latest_handshake(self) -> O[datetime]:
        """Return the peer's latest handshake time, or ``None`` if unknown."""
        peer_info = self.conn.get_peer_info(self.interface, self.name)
        return peer_info.latest_handshake if peer_info else None

    def latest_endpoint(self) -> O[Tuple[IPAddress, int]]:
        """Return the last listed endpoint, or ``None`` when there is none."""
        peer_info = self.conn.get_peer_info(self.interface, self.name)
        if peer_info and peer_info.endpoints:
            return peer_info.endpoints[-1]
        return None

    def get_total_transfer(self) -> O[Tuple[int, int]]:
        """Return ``(recv_bytes, sent_bytes)``, or ``None`` without peer info."""
        peer_info = self.conn.get_peer_info(self.interface, self.name)
        if peer_info:
            return (peer_info.recv_bytes, peer_info.sent_bytes)
        return None
|
|
|
|
@dataclass
class WireguardInterface:
    """Handle for one interface, operated through a connection backend."""

    conn: WireguardConnection
    name: str

    def list_peers(self) -> List[WireguardPeer]:
        """Return a ``WireguardPeer`` for every peer the connection reports."""
        return [WireguardPeer(self.conn, self.name, x) for x in self.conn.get_peers(self.name)]

    def get_config(self) -> str:
        """Return the interface's current configuration."""
        return self.conn.get_config(self.name)

    def sync(self, addresses: List[IPInterface], routes: List[IPNetwork], mtu: O[int] = None) -> None:
        """Make the interface's addresses and routes match the wanted state.

        :param addresses: addresses the interface should end up with.
        :param routes: routes the interface should end up with; the network
            of every wanted address is added implicitly.
        :param mtu: when not ``None``, the MTU to set first.

        Entries present on the interface but not wanted are deleted, then
        missing ones are added.  The caller's lists are left unmodified.
        """
        # Build a new list instead of `routes += ...`, which would append to
        # the list object owned by the caller.
        routes = list(routes) + [addr.network for addr in addresses]

        if mtu is not None:
            self.conn.set_mtu(self.name, mtu)

        # Entries are compared by their string form.
        wanted_addresses = {str(x): x for x in addresses}
        interface_addresses = {str(x): x for x in self.conn.get_addresses(self.name)}
        for key, address in interface_addresses.items():
            if key not in wanted_addresses:
                self.conn.delete_address(self.name, address)
        for key, address in wanted_addresses.items():
            if key not in interface_addresses:
                self.conn.add_address(self.name, address)

        wanted_routes = {str(x): x for x in routes}
        interface_routes = {str(x): x for x in self.conn.get_routes(self.name)}
        for key, route in interface_routes.items():
            if key not in wanted_routes:
                self.conn.delete_route(self.name, route)
        for key, route in wanted_routes.items():
            if key not in interface_routes:
                self.conn.add_route(self.name, route)

    def set_config(self, config: str) -> None:
        """Apply *config* with ``sync=False``."""
        return self.conn.set_config(self.name, config, sync=False)

    def sync_config(self, config: str) -> None:
        """Apply *config* with ``sync=True``."""
        return self.conn.set_config(self.name, config, sync=True)

    def clear_config(self) -> None:
        """Clear the interface's configuration."""
        return self.conn.clear_config(self.name)

    def delete(self) -> None:
        """Remove all routes and addresses, bring the interface down, destroy it."""
        self.conn.delete_all_routes(self.name)
        self.conn.delete_all_addresses(self.name)
        self.conn.set_down(self.name)
        self.conn.destroy_interface(self.name)
|
|
|
|
@dataclass
class WireguardHost:
    """Entry point for managing the interfaces and keys of one host."""

    conn: WireguardConnection

    def list_interfaces(self) -> List[str]:
        """Return the interface names known to the connection.

        A connection that reports ``None`` yields an empty list, so the
        result can always be iterated.
        """
        return self.conn.get_interfaces() or []

    def get_interface(self, name: str) -> O[WireguardInterface]:
        """Return a handle for *name*, or ``None`` when it is not listed."""
        if name in self.list_interfaces():
            return WireguardInterface(self.conn, name)
        return None

    def create_interface(self, name: str, mtu: O[int] = None) -> O[WireguardInterface]:
        """Create interface *name*, bring it up and return its handle.

        :param mtu: when not ``None``, the MTU applied before the interface
            is brought up.
        :returns: the handle from :meth:`get_interface`, which is ``None``
            when the new interface is not listed afterwards.
        """
        self.conn.create_interface(name)
        if mtu is not None:
            self.conn.set_mtu(name, mtu)
        self.conn.set_up(name)
        return self.get_interface(name)

    def gen_preshared_key(self) -> str:
        """Generate a preshared key through the connection."""
        return self.conn.gen_preshared_key()

    def gen_private_key(self) -> str:
        """Generate a private key through the connection."""
        return self.conn.gen_private_key()

    def get_public_key(self, privkey: str) -> str:
        """Derive the public key for *privkey* through the connection."""
        return self.conn.get_public_key(privkey)
|
|
|
|
|
|
class WireguardHostType(enum.Enum):
    """Kinds of host a connection class can be selected for."""

    Unsupported = None
    Linux = 'linux'
    Remote = 'remote'
|
|
|
|
def which_connection(type: WireguardHostType) -> Type[WireguardConnectionBase]:
    """Return the connection class for host type *type*.

    Host types without a dedicated class map to ``WireguardConnectionBase``.
    """
    connection_classes = {
        WireguardHostType.Linux: WireguardLinuxConnection,
        WireguardHostType.Remote: WireguardRemoteConnection,
    }
    try:
        return connection_classes[type]
    except KeyError:
        return WireguardConnectionBase
|