commit 8599ec36e7623dc3acbc0e05939f48d309a684dc Author: Shiz Date: Sun Dec 5 20:06:26 2021 +0100 epoch diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c070f26 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +__pycache__ +*.pyc + +/configs +/items +/templates diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7f7afbf --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +jinja2 diff --git a/weegee/__init__.py b/weegee/__init__.py new file mode 100644 index 0000000..0691b1c --- /dev/null +++ b/weegee/__init__.py @@ -0,0 +1,350 @@ +from logging import getLogger +from dataclasses import dataclass +from typing import Optional as O, Any +from .dazy import Instance, Config, Meta, Item, Template +from .wireguard import WireguardConnection, WireguardPeer +from .desc import WEEGEE_SERVER, WEEGEE_CLIENT, WEEGEE_HOST, WEEGEE_CONFIG, WEEGEE_SERVER_CONF, WEEGEE_CLIENT_CONF + +logger = getLogger(__name__) + + +@dataclass +class WeegeeBase: + BASE = None + context: 'WeegeeContext' + item: Item + meta: Meta = None + resolved: Item = None + + def __post_init__(self) -> None: + if not self.meta: + self.meta = Meta.load(self.context.instance, self.BASE.get_name()) + if not self.resolved: + self.resolved = self.item.resolve(self.meta) + + @classmethod + def get_name(cls, name: str) -> str: + return f'{cls.BASE.item_prefix}/{name}' + + @classmethod + def get_meta(cls, context: 'WeegeeContext') -> Meta: + return Meta.load(context.instance, cls.BASE.get_name()) + + @classmethod + def create(cls, context: 'WeegeeContext', config: Config) -> 'WeegeeBase': + meta = cls.get_meta(context) + if not config.matches(meta.config): + raise TypeError('internal error') + return cls(context, Item(config), meta) + + @classmethod + def load(cls, context: 'WeegeeContext', name: str) -> 'WeegeeBase': + return cls(context, Item.load(context.instance, name)) + + @property + def name(self) -> str: + return self.item_name[len(self.get_name('')):] + + @property + def item_name(self) -> str: + return self.item.name + + def save(self) -> None: + self.item.save() + + def __eq__(self, other: Any) -> bool: + return isinstance(other, self.__class__) and self.item.name == other.item.name + + +@dataclass +class WeegeeHost(WeegeeBase): + BASE = WEEGEE_HOST + conn: WireguardConnection = None + + @classmethod + def create(cls, ctx: 'WeegeeContext', name: str, host: O[str] = None, user: O[str] = None, elevate_user: O[str] = None, automanage: bool = False, autosync: bool = False) -> 'WeegeeHost': + confname = cls.get_name(name) + return super().create(ctx, Config.parse(ctx.instance, confname, [ + f'automanage = {1 if automanage else 0}', + f'autosync = {1 if autosync else 0}', + f'host = {repr(host) if host else ""}', + f'user = {repr(user) if user else ""}', + f'elevate_user = {repr(elevate_user) if elevate_user else ""}', + ])) + + @property + def automanage(self) -> bool: + return self.resolved.config['automanage'] == 1 + + @property + def autosync(self) -> bool: + return self.resolved.config['autosync'] == 1 + + @property + def host(self) -> O[str]: + return self.resolved.config['host'] + + @property + def user(self) -> O[str]: + return self.resolved.config['user'] + + @property + def elevate_user(self) -> O[str]: + return self.resolved.config['elevate_user'] + + def __post_init__(self) -> None: + super().__post_init__() + self.conn = WireguardConnection(self.host, self.user, self.elevate_user) + + def get_peers(self, name: str) -> O[list[WireguardPeer]]: + interface = self.conn.get_interface(name) + if not interface: + return None + return interface.list_peers() + + def sync_config(self, name: str, config: str) -> None: + interface = self.conn.get_interface(name) + if not interface: + if not self.automanage: + return + interface = self.conn.create_interface(name, config) + else: + interface.sync_config(config) + + def remove_config(self, name: str) -> None: + interface = self.conn.get_interface(name) + if interface: + interface.sync_config('') + if self.automanage: + interface.delete() + +@dataclass +class WeegeeClient(WeegeeBase): + BASE = WEEGEE_CLIENT + + @classmethod + def get_name(cls, server: 'WeegeeServer', name: str) -> str: + return f'{cls.BASE.item_prefix}/{server.resolved.config["name"]}/{name}' + + @property + def name(self) -> str: + return self.item_name[len(self.get_name(self.server, '')):] + + @classmethod + def create(cls, ctx: 'WeegeeContext', server: 'WeegeeServer', name: str, private_key: O[str] = None, public_key: O[str] = None, preshared_key: O[str] = None, addresses: list[str] = [], hosts: O[list[WeegeeHost]] = None) -> 'WeegeeClient': + hosts = hosts or ctx.get_config().default_client_hosts + for serv in [ctx.get_local_host()] + hosts: + try: + private_key = private_key or serv.conn.gen_private_key() + public_key = public_key or ctx.wg.get_public_key(private_key) + preshared_key = preshared_key or ctx.wg.gen_preshared_key() + break + except Exception as e: + logger.warn(f'could not generate keypair on host {serv.item_name}: {e!r}') + else: + logger.critical('could not generate public/private keypair automatically: please pass explicitly!') + raise ValueError() + + confname = cls.get_name(server, name) + config = Config.parse(ctx.instance, confname, [ + f'hosts = [{", ".join(h.item_name for h in hosts)}]', + f'server = {server.item_name}', + f'name = {name!r}', + f'public_key = {public_key!r}', + f'private_key = {private_key!r}', + f'preshared_key = {preshared_key!r}', + f'addresses = [{", ".join(addresses)}]', + ]) + return super().create(ctx, config) + + @property + def interface(self) -> str: + return self.resolved.config['interface'] + + @property + def public_key(self) -> str: + return self.resolved.config['public_key'] + + @property + def hosts(self) -> list[WeegeeHost]: + return [WeegeeHost(self.context, x) for x in self.resolved.config['hosts']] + + @property + def server(self) -> 'WeegeeServer': + return WeegeeServer(self.context, self.resolved.config['server']) + + def save(self) -> None: + super().save() + self.sync(auto=True) + + def sync(self, auto=False, log=None) -> None: + if not log: + log = set() + if self.item_name in log: + return + log.add(self.item_name) + logger.info(f'syncing: {self.item_name}') + + for host in self.hosts: + if auto and not host.autosync: + continue + host.sync_config(self.interface, self.gen_config()) + self.server.sync(auto=auto, log=log) + + def gen_config(self) -> str: + template = Template.load(self.context.instance, WEEGEE_CLIENT_CONF.get_name()) + config = WEEGEE_CLIENT_CONF.make_config(self.context.instance, + client=self.item, + ) + return template.render(config) + +@dataclass +class WeegeeServer(WeegeeBase): + BASE = WEEGEE_SERVER + + def get_clients(self) -> list[WeegeeClient]: + return [WeegeeClient.load(self.context, x) for x in Item.filter(self.context.instance, WeegeeClient.get_name(self, '*'))] + + def get_client(self, name: str) -> O[WeegeeClient]: + return WeegeeClient.load(self.context, WeegeeClient.get_name(self, name)) + + @classmethod + def create(cls, ctx: 'WeegeeContext', name: str, interface: str, host: tuple[str, int], private_key: O[str] = None, public_key: O[str] = None, addresses: list[str] = [], routed_addresses: list[str] = [], hosts: O[list[WeegeeHost]] = None) -> 'WeegeeServer': + hosts = hosts or ctx.get_config().default_server_hosts + for serv in [ctx.get_local_host()] + hosts: + try: + private_key = private_key or serv.conn.gen_private_key() + public_key = public_key or serv.conn.get_public_key(private_key) + break + except Exception as e: + logger.warn(f'could not generate keypair on host {serv.item_name}: {e!r}') + else: + logger.critical('could not generate public/private keypair automatically: please pass explicitly!') + raise ValueError() + + confname = cls.get_name(name) + config = Config.parse(ctx.instance, confname, [ + f'hosts = [{", ".join(h.item_name for h in hosts)}]', + f'name = {name!r}', + f'interface = {interface!r}', + f'public_key = {public_key!r}', + f'private_key = {private_key!r}', + f'addresses = [{", ".join(addresses)}]', + f'routed_addresses = [{", ".join(routed_addresses)}]', + f'host = {host[0]!r}', + f'port = {host[1]}', + ]) + return super().create(ctx, config) + + @property + def interface(self): + return self.resolved.config['interface'] + + @property + def hosts(self) -> list[WeegeeHost]: + return [WeegeeHost(self.context, x) for x in self.resolved.config['hosts']] + + def save(self) -> None: + super().save() + self.sync(auto=True) + + def sync(self, auto=False, log=None) -> None: + if not log: + log = set() + if self.item_name in log: + return + log.add(self.item_name) + logger.info(f'syncing: {self.item_name}') + + for host in self.hosts: + if auto and not host.autosync: + continue + host.sync_config(self.interface, self.gen_config()) + for client in self.get_clients(): + client.sync(auto=auto, log=log) + + def gen_config(self) -> str: + clients = self.get_clients() + template = Template.load(self.context.instance, WEEGEE_SERVER_CONF.get_name()) + config = WEEGEE_SERVER_CONF.make_config(self.context.instance, + server=self.item, + clients=[x.item for x in clients], + ) + return template.render(config) + + +@dataclass +class WeegeeConfig: + context: 'WeegeeContext' + item: Item + + @classmethod + def get_name(cls) -> str: + return WEEGEE_CONFIG.get_name() + + @classmethod + def load(cls, context: 'WeegeeContext', name: str) -> 'WeegeeConfig': + return cls(context, Meta.load(context.instance, name)) + + def save(self) -> None: + self.item.save() + + @property + def default_server_hosts(self) -> list[WeegeeHost]: + return [WeegeeHost(self.context, item) for item in self.item.config['default_server_hosts']] + + @default_server_hosts.setter + def default_server_hosts(self, value: list[WeegeeHost]) -> None: + self.item.config['default_server_hosts'] = [x.item for x in value] + + @property + def default_client_hosts(self) -> list[WeegeeHost]: + return [WeegeeHost(self.context, item) for item in self.item.config['default_client_hosts']] + + @default_client_hosts.setter + def default_client_hosts(self, value: list[WeegeeHost]) -> None: + self.item.config['default_client_hosts'] = [x.item for x in value] + + +@dataclass +class WeegeeContext: + LOCAL_HOST_NAME = 'local' + instance: Instance + + def setup(self) -> None: + logger.info('setup: metas') + for wmeta in (WEEGEE_HOST, WEEGEE_SERVER, WEEGEE_CLIENT, WEEGEE_CONFIG): + logger.info(' ' + wmeta.name) + Meta(Config.parse(self.instance, wmeta.get_name(), wmeta.spec)).save() + logger.info('setup: templates') + for wtemp in (WEEGEE_SERVER_CONF, WEEGEE_CLIENT_CONF): + logger.info(' ' + wtemp.name) + Template(wtemp.get_name(), wtemp.template, self.instance).save() + + logger.info('setup: items') + logger.info(f' {WeegeeHost.get_name(self.LOCAL_HOST_NAME)}') + localhost = WeegeeHost.create(self, self.LOCAL_HOST_NAME) + localhost.save() + + logger.info('setup: config') + config = self.get_config() + config.default_server_hosts = [localhost] + config.save() + + def get_local_host(self) -> WeegeeHost: + return WeegeeHost.load(self, WeegeeHost.get_name(self.LOCAL_HOST_NAME)) + + def get_config(self) -> WeegeeConfig: + return WeegeeConfig.load(self, WeegeeConfig.get_name()) + + def get_servers(self) -> list[WeegeeServer]: + return [WeegeeServer.load(self, x) for x in Item.filter(self.instance, WeegeeServer.get_name('*'))] + + def get_server(self, name: str) -> O[WeegeeServer]: + return WeegeeServer.load(self, WeegeeServer.get_name(name)) + + def get_hosts(self) -> list[WeegeeHost]: + return [WeegeeHost.load(self, x) for x in Item.filter(self.instance, WeegeeHost.get_name('*'))] + + def get_host(self, name: str) -> O[WeegeeHost]: + return WeegeeServer.load(self, WeegeeHost.get_name(name)) diff --git a/weegee/__main__.py b/weegee/__main__.py new file mode 100644 index 0000000..607fd86 --- /dev/null +++ b/weegee/__main__.py @@ -0,0 +1,184 @@ +import sys +import argparse +import logging +from typing import Optional as O + +logging.basicConfig(level=logging.DEBUG) + +from .dazy import Instance +from . import WeegeeContext, WeegeeHost, WeegeeServer, WeegeeClient + + +def main(): + parser = argparse.ArgumentParser() + parser.set_defaults(func=None) + parser.add_argument('-d', '--base-dir', default='.') + parser.add_argument('--yes-i-want-to-destroy-this', action='store_true', default=False) + + commands = parser.add_subparsers(title='commands') + + def do_status(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]: + for server in ctx.get_servers(): + print(f'{server.name}:') + for host in server.hosts: + print(f' @ {host.name}:') + peers = {p.name: p for p in host.get_peers(server.interface)} + for client in server.get_clients(): + if client.public_key not in peers: + print(f' {client.name}: client not found in peer list!') + else: + peer = peers[client.public_key] + tx, rx = peer.get_total_transfer() + endpoint = peer.last_endpoint() + print(f' {client.name}: last handshake {peer.last_handshake() or "never"}{" from " + endpoint if endpoint else ""}, {tx} sent, {rx} received') + + status = commands.add_parser('status') + status.set_defaults(func=do_status) + + def do_sync(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]: + log = set() + for server in ctx.get_servers(): + server.sync(log=log) + + sync = commands.add_parser('sync') + sync.set_defaults(func=do_sync) + + + # System commands + + system = commands.add_parser('system') + system_commands = system.add_subparsers(title='system commands') + + def do_setup(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]: + ctx.setup() + + setup = system_commands.add_parser('setup') + setup.set_defaults(func=do_setup) + + def do_configure(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]: + config = ctx.get_config() + config.autosync = args.autosync + config.automanage = args.automanage + config.save() + + configure = system_commands.add_parser('configure') + configure.set_defaults(func=do_configure) + configure.add_argument('--autosync', action='store_true', default=False) + configure.add_argument('--automanage', action='store_true', default=False) + + def do_migrate(arser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]: + pass + migrate = system_commands.add_parser('migrate') + migrate.set_defaults(func=do_migrate) + + + # Host commands + + host = commands.add_parser('host') + host_commands = host.add_subparsers(title='host commands') + + def do_add_host(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]: + host = WeegeeHost.create(ctx, + args.name, host=args.host, user=args.user, elevate_user=args.elevate_user, + autosync=args.auto_sync, automanage=args.auto_manage, + ) + host.save() + + add_host = host_commands.add_parser('create') + add_host.add_argument('-H', '--host', help='remote host') + add_host.add_argument('-u', '--user', help='username') + add_host.add_argument('-U', '--elevate_user', help='username to elevate privileges to') + add_host.add_argument('-a', '--auto-sync', action='store_true', default=False, help='whether to auto-synchronize config') + add_host.add_argument('-A', '--auto-manage', action='store_true', default=False, help='whether to auto-synchronize interfaces') + add_host.add_argument('name', help='host name') + add_host.set_defaults(func=do_add_host) + + # Server commands + + server = commands.add_parser('server') + server_commands = server.add_subparsers(title='server commands') + + def do_add_server(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]: + interface = args.interface or f'wg-{args.name}' + server = WeegeeServer.create(ctx, + args.name, interface, (args.host, args.port), + private_key=args.private_key, public_key=args.public_key, + addresses=args.address, routed_addresses=args.routed_address, + ) + server.save() + + add_server = server_commands.add_parser('create') + add_server.add_argument('-a', '--address', action='append', help='interface address(es)') + add_server.add_argument('-A', '--routed-address', action='append', help='routed address(es)') + add_server.add_argument('-k', '--public-key', help='public key (optional)') + add_server.add_argument('-K', '--private-key', help='private key (optional)') + add_server.add_argument('-i', '--interface', help='interface name (optional)') + add_server.add_argument('name', help='server name') + add_server.add_argument('host', help='endpoint host') + add_server.add_argument('port', type=int, help='listen port') + add_server.set_defaults(func=do_add_server) + + def do_del_server(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]: + if not args.yes_i_want_to_destroy_this: + parser.error('please pass --yes-i-want-to-destroy-this if you really want to remove this server') + del_server = server_commands.add_parser('destroy') + del_server.add_argument('name', help='server name') + del_server.set_defaults(func=do_del_server) + + def do_conf_server(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]: + server = ctx.get_server(args.name) + print(server.gen_config()) + + conf_server = server_commands.add_parser('config') + conf_server.add_argument('name', help='server name') + conf_server.set_defaults(func=do_conf_server) + + + # Client commands + + client = commands.add_parser('client') + client_commands = client.add_subparsers(title='client commands') + + def do_add_client(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]: + server = ctx.get_server(args.server) + client = WeegeeClient.create(ctx, server, + args.name, + private_key=args.private_key, public_key=args.public_key, preshared_key=args.preshared_key, + addresses=args.address, + ) + client.save() + + add_client = client_commands.add_parser('create') + add_client.add_argument('-a', '--address', action='append', help='interface address(es)') + add_client.add_argument('-k', '--public-key', help='public key (optional)') + add_client.add_argument('-K', '--private-key', help='private key (optional)') + add_client.add_argument('-p', '--preshared-key', help='preshared key (optional)') + add_client.add_argument('name', help='client name') + add_client.add_argument('server', help='server name') + add_client.set_defaults(func=do_add_client) + + def do_del_client(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]: + pass + del_client = client_commands.add_parser('destroy') + del_client.set_defaults(func=do_del_client) + + def do_conf_client(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]: + server = ctx.get_server(args.server) + client = server.get_client(args.name) + print(client.gen_config()) + + conf_client = client_commands.add_parser('config') + conf_client.add_argument('name', help='name') + conf_client.add_argument('server', help='server name') + conf_client.set_defaults(func=do_conf_client) + + + args = parser.parse_args() + if not args.func: + parser.error('a subcommand must be provided') + + instance = Instance(args.base_dir) + sys.exit(args.func(parser, args, WeegeeContext(instance))) + + +main() diff --git a/weegee/dazy.py b/weegee/dazy.py new file mode 100644 index 0000000..3972b6f --- /dev/null +++ b/weegee/dazy.py @@ -0,0 +1,634 @@ +import os +from enum import Enum +from dataclasses import dataclass +from collections import UserString +from ast import literal_eval +import ipaddress +from fnmatch import fnmatch +import jinja2 +import jinja2.meta +from types import SimpleNamespace + +from io import StringIO +from typing import Optional as O, Union as U, Any, Iterable, Iterator, TypeVar, Generic + + + +def stripsplit(s: str, sep: str, maxsplit: int = -1) -> list[str]: + return [x.strip() for x in s.split(sep, maxsplit=maxsplit)] + +def quotesplit(s: str, sep: str, maxsplit: int = - 1) -> list[str]: + parts = [] + pos = -1 + if maxsplit < 0: + maxsplit = len(s) + + while len(parts) < maxsplit: + npos = s.find(sep, pos + 1) + if npos < 0: + break + pos = npos + if pos > 0 and s[pos - 1] == '\\': + s = s[:pos - 1] + s[pos:] + continue + parts.append(s[:pos]) + s = s[pos + 1:] + pos = -1 + + parts.append(s) + return parts + +def quotejoin(sep: str, a: list[str]) -> str: + return sep.join(x.replace(sep, '\\' + sep) for x in a) + +def indent(s, amount=0, skip=0) -> int: + spacing = ' ' * amount + return '\n'.join((spacing if i >= skip else '') + l for i, l in enumerate(s.splitlines())) + + +class Instance: + def __init__(self, conf_dir: str, data_dir: O[str] = None) -> None: + self.conf_dir = conf_dir + self.data_dir = data_dir or conf_dir + + +class Int(int): + @classmethod + def parse(cls, value: str) -> 'Int': + return cls(value) + +class Str(UserString): + @classmethod + def parse(cls, value: str) -> 'Str': + val = literal_eval(value) + if not isinstance(val, str): + raise TypeError(value) + return cls(val) + +@dataclass +class IPAddress: + address: U[ipaddress.IPv4Address, ipaddress.IPv6Address] + + @classmethod + def parse(cls, s: str) -> 'IPAddress': + return cls(ipaddress.ip_address(s)) + + def __getattr__(self, name: str) -> Any: + return getattr(self.address, name) + + def __repr__(self) -> str: + return str(self.address) + +@dataclass +class IPNetwork: + network: U[ipaddress.IPv4Network, ipaddress.IPv6Network] + + @classmethod + def parse(cls, s: str) -> 'IPNetwork': + return cls(ipaddress.ip_network(s)) + + def __getattr__(self, name: str) -> Any: + return getattr(self.network, name) + + def __repr__(self) -> str: + return str(self.network) + +@dataclass +class IPInterface: + interface: U[ipaddress.IPv4Interface, ipaddress.IPv6Interface] + + @classmethod + def parse(cls, s: str) -> 'IPInterface': + return cls(ipaddress.ip_interface(s)) + + def __getattr__(self, name: str) -> Any: + return getattr(self.interface, name) + + def __repr__(self) -> str: + return str(self.interface) + + +T = TypeVar('T') + +class Type(Generic[T]): + @classmethod + def parse(cls, input: str) -> 'Type': + if input.startswith('?'): + return OptionalType.parse(input) + if input.startswith('['): + return ArrType.parse(input) + if input.startswith('{'): + return MapType.parse(input) + if input.startswith('@'): + return RefType.parse(input) + if input.startswith('*'): + return GlobType.parse(input) + return LitType.parse(input) + + def dump(self) -> str: + raise NotImplementedError + + def match_value(self, input: Any, instance: Instance) -> bool: + return False + + def parse_value(self, input: str, instance: Instance) -> T: + raise TypeError + + def dump_value(self, input: T) -> str: + return repr(input) + + def template_value(self, input: T) -> T: + return input + +@dataclass +class OptionalType(Generic[T], Type[O[T]]): + subtype: Type[T] + + @classmethod + def parse(cls, input: str) -> 'OptionalType[T]': + if input[0] != '?': + raise ParseError() + return cls(Type.parse(input[1:])) + + def dump(self) -> str: + return f'?{self.subtype.dump()}' + + def match_value(self, input: Any, instance: Instance) -> bool: + return input is None or self.subtype.match_value(input, instance) + + def parse_value(self, input: str, instance: Instance) -> O[T]: + if not input: + return None + return self.subtype.parse_value(input, instance) + + def dump_value(self, input: O[T]) -> str: + if input is None: + return '' + return self.subtype.dump_value(input) + + def template_value(self, input: O[T]) -> O[T]: + if input is None: + return None + return self.subtype.template_value(input) + +@dataclass +class RefType(Type['Item']): + name: str + + @classmethod + def parse(cls, input: str) -> 'RefType': + if input[0] != '@': + raise ParseError() + return cls(input[1:]) + + def dump(self) -> str: + return f'@{self.name}' + + def match_value(self, input: Any, instance: Instance) -> bool: + meta = Meta.load(instance, self.name) + return isinstance(input, Item) and input.is_complete(meta) + + def parse_value(self, input: str, instance: Instance) -> 'Item': + meta = Meta.load(instance, self.name) + return Item.load(instance, input) + + def dump_value(self, input: 'Item') -> str: + return input.name + + def template_value(self, input: 'Item') -> 'Config': + return input.config.template() + +@dataclass +class GlobType(Type[dict[str, 'Item']]): + name: str + + @classmethod + def parse(cls, input: str) -> 'GlobType': + if input[0] != '*': + raise ParseError() + return cls(input[1:]) + + def dump(self) -> str: + s = f'*{self.name}' + return s + + def match_value(self, input: Any, instance: Instance) -> bool: + meta = Meta.load(instance, self.name) + return isinstance(input, list) and all(isinstance(x, Item) and fnmatch(x.name, input) and x.is_complete(meta) for x in input) + + def parse_value(self, input: str, instance: Instance) -> dict[str, 'Item']: + meta = Meta.load(instance, self.name) + out = {} + for x in Item.filter(instance, input): + item = Item.load(instance, x) + if not item.is_complete(meta): + continue + out[x] = item + return out + + def dump_value(self, input: dict[str, 'Item']) -> str: + return '' + + def template_value(self, input: dict[str, 'Item']) -> dict[str, 'Item']: + return {x.name: x.config.template() for x in input} + +@dataclass +class LitType(Generic[T], Type[T]): + SUBTYPES = { + 'int': Int, + 'str': Str, + 'ipaddr': IPAddress, + 'ipnet': IPNetwork, + 'ipintf': IPInterface, + } + name: str + subtype: T + + @classmethod + def parse(cls, input: str) -> 'LitType[T]': + if input not in cls.SUBTYPES: + raise ParseError() + return cls(name=input, subtype=cls.SUBTYPES[input]) + + def dump(self) -> str: + return self.name + + def match_value(self, input: Any, instance: Instance) -> bool: + return isinstance(input, self.subtype) + + def parse_value(self, input: str, instance: Instance) -> T: + return self.subtype.parse(input) + + def dump_value(self, input: T) -> str: + return repr(input) + +@dataclass +class ArrType(Generic[T], Type[list[T]]): + subtype: Type[T] + + @classmethod + def parse(cls, input: str) -> 'ArrType': + if input[0] != '[' or input[-1] != ']': + raise ParseError() + return cls(subtype=Type.parse(input[1:-1])) + + def dump(self) -> str: + return f'[{self.subtype.dump()}]' + + def match_value(self, input: Any, instance: Instance) -> bool: + return isinstance(input, list) and all(self.subtype.match_value(x, instance) for x in input) + + def parse_value(self, input: str, instance: Instance) -> list[T]: + if input[0] != '[' or input[-1] != ']': + raise ParseError() + input = input[1:-1].strip() + if not input: + return [] + return [self.subtype.parse_value(x.strip(), instance) for x in quotesplit(input, ',')] + + def dump_value(self, input: list[T]) -> str: + return '[' + quotejoin(', ', (self.subtype.dump_value(x) for x in input)) + ']' + + def template_value(self, input: list[T]) -> list[T]: + return [self.subtype.template_value(x) for x in input] + +V = TypeVar('V') + +@dataclass +class MapType(Generic[T, V], Type[dict[T, V]]): + keytype: Type[T] + valtype: Type[V] + + @classmethod + def parse(cls, input: str) -> 'MapType[T, V]': + if input[0] != '{' or input[-1] != '}': + raise ParseError() + kt, vt = quotesplit(input[1:-1], ':', maxsplit=1) + return cls(keytype=Type.parse(kt.strip()), valtype=Type.parse(vt.strip())) + + def dump(self) -> str: + kt = self.keytype.dump().replace(':', '\\:') + vt = self.valtype.dump() + return f'{{{kt}: {vt}}}' + + def match_value(self, input: Any, instance: Instance) -> bool: + return isinstance(input, dict) and all(self.keytype.match_value(k, instance) and self.valtype.match_value(v, instance) for (k, v) in input.items()) + + def parse_value(self, input: str, instance: Instance) -> tuple[dict[T, V], str]: + if input[0] != '{' or input[-1] != '}': + raise ParseError() + out = {} + input[1:-1].strip() + if input: + items = quotesplit(input, ',') + for item in items: + k, v = quotesplit(item, ':', maxsplit=1) + out[self.keytype.parse_value(k.strip(), instance)] = self.valtype.parse_value(v.strip(), instance) + return out + + def dump_value(self, input: dict[T, V]) -> str: + return '{' + quotejoin(', ', (self.keytype.dump_value(k).replace(':', '\\:') + ': ' + self.valtype.dump_value(v) for (k, v) in input.items())) + '}' + + def template_value(self, input: dict[T, V]) -> dict[T, V]: + return {self.keytype.template_value(k): self.valtype.template_value(v) for (k, v) in input.items()} + + +NoValue = object() + +@dataclass +class Config: + name: str + resolved: dict[str, tuple[Type, O[Any]]] + unresolved: dict[str, str] + instance: 'Instance' + kind: O[str] = None + + @classmethod + def parse(cls, instance: Instance, name: str, lines: Iterable[str]) -> 'Config': + resolved = {} + unresolved = {} + for line in lines: + line, *_ = stripsplit(line, '#', maxsplit=1) + line, *parts = stripsplit(line, '=', maxsplit=1) + valstr = parts[0] if parts else None + vname, *parts = stripsplit(line, ':', maxsplit=1) + typestr = parts[0] if parts else None + if typestr is not None: + vtype = Type.parse(typestr) + if valstr is not None: + vval = vtype.parse_value(valstr, instance) + else: + vval = NoValue + resolved[vname] = (vtype, vval) + else: + unresolved[vname] = valstr + + return cls(name=name, resolved=resolved, unresolved=unresolved, instance=instance) + + @classmethod + def make(cls, _instance: Instance, _name: str, **args: dict[str, tuple[Type, O[Any]]]) -> 'Config': + return cls(name=_name, resolved=args, unresolved={}, instance=_instance) + + def dump(self) -> list[str]: + out = [] + for name, (vtype, vval) in self.resolved.items(): + line = f'{name}: {vtype.dump()}' + if vval is not NoValue: + line += f' = {vtype.dump_value(vval)}' + out.append(line) + for name, vval in self.unresolved.items(): + line = f'{name} = {vval}' + out.append(line) + return out + + def template(self) -> SimpleNamespace: + resolved = {} + for name, (vtype, vval) in self.resolved.items(): + resolved[name] = vtype.template_value(vval) if vval is not NoValue else vval + return SimpleNamespace(**resolved) + + def matches(self, other: 'Config') -> bool: + for vname, (vtype, vval) in other.resolved.items(): + if vname in self.unresolved: + vval = vtype.parse_value(self.unresolved[vname], self.instance) + elif vname in self.resolved: + vval = self[vname] + if vval is NoValue: + return False + return True + + + @property + def type_complete(self) -> bool: + return not self.unresolved + + @property + def value_complete(self) -> bool: + return self.type_complete and all(vval is not NoValue for _, vval in self.resolved.values()) + + + def __getitem__(self, name: str) -> Any: + _, value = self.resolved[name] + if value is NoValue: + raise KeyError(name) + return value + + def __setitem__(self, name: str, value: Any) -> None: + vtype, _ = self.resolved[name] + if not vtype.match_value(value, self.instance): + raise TypeError(f'{name}: {value} must match {vtype.dump()}') + self.resolved[name] = (vtype, value) + + def __contains__(self, name: str) -> bool: + if name not in self.resolved: + return False + _, value = self.resolved[name] + return value is not NoValue + + def __iter__(self) -> Iterator[str]: + return iter(self.resolved) + + def __repr__(self) -> str: + value = self.name + if self.kind: + value += f': {self.kind}' + value += ' {\n' + for name, (vtype, vval) in self.resolved.items(): + valstr = f' {name}: {vtype.dump()}' + if vval is not NoValue: + valstr += f' = {indent(str(vval), amount=2, skip=1)}' + value += valstr + '\n' + for name, vval in self.unresolved.items(): + value += f' {name}: ?{indent(str(vval), amount=2, skip=1)}\n' + value += '}' + return value + + +@dataclass +class Meta: + config: Config + + @classmethod + def load(cls, instance: 'Instance', name: str) -> 'Meta': + with open(f'{instance.conf_dir}/configs/{name}.conf', 'r') as f: + config = Config.parse(instance, name, f.readlines()) + return cls(config=config) + + def save(self) -> None: + path = f'{self.config.instance.conf_dir}/configs/{self.name}.conf' + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'w') as f: + f.write('\n'.join(self.config.dump()) + '\n') + + @property + def name(self) -> str: + return self.config.name + + @property + def is_complete(self) -> bool: + return self.config.type_complete + + def check_complete(self) -> None: + if not self.is_complete: + raise ValueError(f'Incomplete config for meta: {self.config}') + +@dataclass +class Item: + config: Config + + @classmethod + def filter(cls, instance: Instance, pattern: str) -> str: + basedir = f'{instance.data_dir}/items/' + allfiles = [] + for root, _, files in os.walk(basedir): + for f in files: + if not f.endswith('.conf'): + continue + base = os.path.join(root[len(basedir):], f[:-len('.conf')]) + if not fnmatch(base, pattern): + continue + allfiles.append(base) + return allfiles + + @classmethod + def load(cls, instance: Instance, name: str) -> None: + with open(f'{instance.data_dir}/items/{name}.conf', 'r') as f: + config = Config.parse(instance, name, f) + return cls(config=config) + + def save(self) -> None: + path = f'{self.config.instance.data_dir}/items/{self.name}.conf' + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'w') as f: + f.write('\n'.join(self.config.dump()) + '\n') + + @property + def name(self) -> str: + return self.config.name + + def resolve(self, meta: Meta) -> O['Item']: + c = Config(name=self.name, resolved={}, unresolved={}, instance=self.config.instance, kind=meta.name) + for name, (vtype, vval) in meta.config.resolved.items(): + c.resolved[name] = (vtype, NoValue) + if name in self.config.resolved: + vval = self.config[name] + elif name in self.config.unresolved: + vval = vtype.parse_value(self.config.unresolved[name], self.config.instance) + if vval is not NoValue: + c[name] = vval + return self.__class__(c) + + def is_complete(self, meta: Meta) -> bool: + return self.resolve(meta).config.value_complete + + def check_complete(self, meta: Meta) -> None: + if not self.is_complete(meta): + raise ValueError(f'incomplete config {self.config.name} for {meta.name}: {self.config}') + + def __str__(self) -> str: + return str(self.config) + +@dataclass +class Template: + name: str + source: str + instance: Instance + template: O[jinja2.Template] = None + + def __post_init__(self) -> None: + if not self.template: + self.template = jinja2.Template(self.source) + + @classmethod + def load(cls, instance: Instance, name: str) -> None: + env = jinja2.Environment() + with open(f'{instance.conf_dir}/templates/{name}.tmpl', 'r') as f: + source = f.read() + return cls(name=name, source=source, instance=instance) + + def save(self) -> None: + path = f'{self.instance.conf_dir}/templates/{self.name}.tmpl' + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'w') as f: + f.write(self.source) + + def get_variables(self) -> set[str]: + return jinja2.meta.find_undeclared_variables(jinja2.Environment().parse(self.source)) + + def render(self, config: Config) -> str: + templated = config.template() + return self.template.render(**templated.__dict__) + + +if __name__ == '__main__': + import sys + import argparse + + parser = argparse.ArgumentParser() + parser.set_defaults(func=None) + parser.add_argument('-d', '--base-dir', default='.', help='Base directory for configuration') + + commands = parser.add_subparsers(title='commands') + + def split_meta(name: str, instance: Instance) -> tuple[str, O[Meta]]: + if ':' in name: + name, metaname = name.split(':', maxsplit=1) + meta = Meta.load(instance, metaname) + else: + meta = None + return name, meta + + + def do_new_meta(parser: argparse.ArgumentParser, args: argparse.Namespace, instance: Instance) -> None: + meta = Meta(config=Config.parse(instance, args.name, args.spec)) + meta.save() + + new_meta = commands.add_parser('new-meta') + new_meta.add_argument('name') + new_meta.add_argument('spec', nargs='*') + new_meta.set_defaults(func=do_new_meta) + + def do_new_item(parser: argparse.ArgumentParser, args: argparse.Namespace, instance: Instance) -> None: + name, meta = split_meta(args.name, instance) + item = Item(config=Config.parse(instance, name, args.spec)) + if meta: + meta.check_complete() + item.check_complete(meta) + item.save() + + new_item = commands.add_parser('new') + new_item.add_argument('name') + new_item.add_argument('spec', nargs='*') + new_item.set_defaults(func=do_new_item) + + def do_check_item(parser: argparse.ArgumentParser, args: argparse.Namespace, instance: Instance) -> None: + name, meta = split_meta(args.name, instance) + item = Item.load(instance, name) + if meta: + meta.check_complete() + return 0 if item.is_complete(meta) else 1 + else: + return 0 + + check_item = commands.add_parser('check') + check_item.add_argument('name') + check_item.set_defaults(func=do_check_item) + + def do_template(parser: argparse.ArgumentParser, args: argparse.Namespace, instance: Instance) -> None: + template = Template.load(instance, args.name) + config = Config.parse(instance, '', args.spec) + missing = template.get_variables() - set(config) + if missing: + parser.error(f'missing variables: {", ".join(missing)}') + print(template.render(config)) + + template_item = commands.add_parser('template') + template_item.add_argument('name') + template_item.add_argument('spec', nargs='*') + template_item.set_defaults(func=do_template) + + args = parser.parse_args() + if not args.func: + parser.error('a subcommand must be provided') + instance = Instance(args.base_dir) + sys.exit(args.func(parser, args, instance)) diff --git a/weegee/desc.py b/weegee/desc.py new file mode 100644 index 0000000..750a911 --- /dev/null +++ b/weegee/desc.py @@ -0,0 +1,132 @@ +from dataclasses import dataclass +from typing import Union as U + +from .dazy import Instance, Config, RefType, ArrType, Meta + + +@dataclass +class WeegeeDesc: + name: str + version: int + + def get_name(self) -> str: + return f'{self.name}@{self.version}' + +@dataclass +class WeegeeMeta(WeegeeDesc): + spec: list[str] + item_prefix: str = '' + + +WEEGEE_HOST = WeegeeMeta( + name='wg/host', + version=1, + spec=[ + 'autosync: int = 0', + 'automanage: int = 0', + 'host: ?str = ', + 'user: ?str = ', + 'elevate_user: ?str = ', + ], + item_prefix='wg/host', +) +WEEGEE_SERVER = WeegeeMeta( + name='wg/server', + version=1, + spec=[ + f'hosts: [@{WEEGEE_HOST.get_name()}]', + 'interface: str', + 'public_key: str', + 'private_key: str', + 'addresses: [ipintf]', + 'routed_addresses: [ipintf]', + 'host: str', + 'port: int', + ], + item_prefix='wg/server', +) +WEEGEE_CLIENT = WeegeeMeta( + name='wg/client', + version=1, + spec=[ + f'hosts: [@{WEEGEE_HOST.get_name()}]', + f'server: @{WEEGEE_SERVER.get_name()}', + 'interface: str = "wg0"', + 'public_key: str', + 'private_key: str', + 'preshared_key: str', + 'addresses: [ipintf]', + ], + item_prefix='wg/client', +) +WEEGEE_CONFIG = WeegeeMeta( + name='wg/config', + version=1, + spec=[ + f'default_server_hosts: [@{WEEGEE_HOST.get_name()}] = []', + f'default_client_hosts: [@{WEEGEE_HOST.get_name()}] = []', + ], +) + + +@dataclass +class WeegeeTemplate(WeegeeDesc): + template: str + variables: dict[str, U[WeegeeMeta, list[WeegeeMeta]]] + + def make_config(self, instance: Instance, **kwargs) -> 'Config': + args = {} + for k, v in kwargs.items(): + vtype = self.variables[k] + if isinstance(vtype, list): + tname = vtype[0].get_name() + ttype = ArrType(RefType(tname)) + vval = [x.resolve(Meta.load(instance, tname)) for x in v] + else: + tname = vtype.get_name() + ttype = RefType(tname) + vval = v.resolve(Meta.load(instance, tname)) + args[k] = (ttype, vval) + return Config.make(instance, '', **args) + +WEEGEE_SERVER_CONF = WeegeeTemplate( + name='wg/server-conf', + version=1, + template=""" +[Interface] +Address = {{server.addresses | join(', ')}} +ListenPort = {{server.port}} +PrivateKey = {{server.private_key}} + +PostUp = iptables -A FORWARD -i %i -j ACCEPT +PostDown = iptables -D FORWARD -i %i -j ACCEPT + +{% for client in clients -%} +[Peer] +# Client: {{client.name}} +PublicKey = {{client.public_key}} +PresharedKey = {{client.preshared_key}} +AllowedIPs = {{client.addresses | join(', ')}} + +{% endfor %} +""".strip(), + variables={'server': WEEGEE_SERVER, 'clients': [WEEGEE_CLIENT]}, +) + +WEEGEE_CLIENT_CONF = WeegeeTemplate( + name='wg/client-conf', + version=1, + template=""" +[Interface] +PrivateKey = {{client.private_ke}} +Address = {{client.addresses | join(', ')}} + +[Peer] +PublicKey = {{client.server.public_key}} +PresharedKey = {{client.preshared_key}} +AllowedIPs = {{client.server.routed_addresses | join(', ')}} +Endpoint = {{client.server.host}}:{{client.server.port}} +PersistentKeepalive = 30 +""".strip(), + variables={'client': WEEGEE_CLIENT}, +) diff --git a/weegee/wireguard.py b/weegee/wireguard.py new file mode 100644 index 0000000..521633f --- /dev/null +++ b/weegee/wireguard.py @@ -0,0 +1,125 @@ +from dataclasses import dataclass +from datetime import datetime +import shlex +import ipaddress +from typing import Optional as O, Union as U +import subprocess +from logging import getLogger + +logger = getLogger(__name__) + + +@dataclass +class WireguardPeer: + conn: 'WireguardConnection' + interface: str + name: str + + def _filter_list(self, cmd: str) -> O[list[str]]: + for line in self.conn._run_wg('show', self.interface, cmd).splitlines(): + peer, *parts = line.split() + if peer == self.name: + return parts + + def last_handshake(self) -> O[datetime]: + parts = self._filter_list('latest-handshakes') + if not parts: + return None + ts = int(parts[0]) + if not ts: + return None + return datetime.fromtimestamp(ts) + + def last_endpoint(self) -> O[tuple[U[ipaddress.IPv4Address, ipaddress.IPv6Address], int]]: + parts = self._filter_list('endpoints') + if not parts: + return None + if parts[0] == '(none)': + return None + addr, port = parts[0].rsplit(':', maxsplit=1) + return (ipaddress.ip_address(addr.strip('[]')), int(port)) + + def get_total_transfer(self) -> O[tuple[int, int]]: + parts = self._filter_list('transfer') + if not parts: + return None + return (int(parts[0]), int(parts[1])) + +@dataclass +class WireguardInterface: + conn: 'WireguardConnection' + name: str + + def list_peers(self) -> list[WireguardPeer]: + return [WireguardPeer(self.conn, self.name, x) for x in self.conn._run_wg('show', self.name, 'peers').split()] + + def get_config(self) -> str: + return self.conn._run_wg('getconf', self.name) + + def set_config(self, config: str) -> None: + config = self.conn._run_wg_quick('strip', self.name, config) + self.conn._run_wg('setconf', self.name, '/dev/stdin', stdin=config) + + def sync_config(self, config: str) -> None: + config = self.conn._run_wg_quick('strip', self.name, config) + self.conn._run_wg('syncconf', self.name, '/dev/stdin', stdin=config) + + def delete(self) -> None: + self.conn._run_wg_quick('down', self.name) + +@dataclass +class WireguardConnection: + host: O[str] + user: O[str] + elevate_user: O[str] + + def _run(self, *args, shell=False, stdin=None, **kwargs) -> str: + cmd = [] + if self.host: + cmd += ['ssh', f'{self.user}@{self.host}' if self.user else str(self.host)] + if self.elevate_user: + cmd += ['sudo', '-i', '-u', self.elevate_user] + if shell: + cmd += ['sh', '-c', shlex.quote('set -e; ' + args[0])] + else: + cmd += list(args) + logger.debug(f'executing: {shlex.join(cmd)}') + p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='ascii', **kwargs) + stdout, _ = p.communicate(stdin) + return stdout.strip() + + def _run_wg(self, *args, **kwargs) -> str: + return self._run('wg', *args, **kwargs) + + def _run_wg_quick(self, command, interface, config=None, **kwargs) -> str: + cmd = '' + stdin = None + if config: + cmd += f'tmpdir="$(mktemp -d)"; tmpfile="$tmpdir/{shlex.quote(str(interface) + ".conf")}"; trap \'rm -rf "$tmpdir"\' EXIT; cat > "$tmpfile"; intf="$tmpfile"; ' + stdin = config + else: + cmd += f'intf={shlex.quote(interface)}; ' + cmd += f'exec wg-quick {shlex.quote(command)} "$intf"' + return self._run(cmd, shell=True, stdin=stdin, **kwargs) + + def list_interfaces(self) -> list[WireguardInterface]: + return [WireguardInterface(self, x) for x in self._run_wg('show', 'interfaces').split()] + + def get_interface(self, name: str) -> O[WireguardInterface]: + for interface in self.list_interfaces(): + if interface.name == name: + return interface + return None + + def create_interface(self, name: str, config: str) -> WireguardInterface: + self._run_wg_quick('up', name, config) + return self.get_interface(name) + + def gen_preshared_key(self) -> str: + return self._run_wg('genpsk') + + def gen_private_key(self) -> str: + return self._run_wg('genkey') + + def get_public_key(self, privkey: str) -> str: + return self._run_wg('pubkey', stdin=privkey)