This commit is contained in:
Shiz 2021-12-05 20:06:26 +01:00
commit 8599ec36e7
7 changed files with 1432 additions and 0 deletions

6
.gitignore vendored Normal file
View File

@ -0,0 +1,6 @@
__pycache__
*.pyc
/configs
/items
/templates

1
requirements.txt Normal file
View File

@ -0,0 +1 @@
jinja2

350
weegee/__init__.py Normal file
View File

@ -0,0 +1,350 @@
from logging import getLogger
from dataclasses import dataclass
from typing import Optional as O, Any
from .dazy import Instance, Config, Meta, Item, Template
from .wireguard import WireguardConnection, WireguardPeer
from .desc import WEEGEE_SERVER, WEEGEE_CLIENT, WEEGEE_HOST, WEEGEE_CONFIG, WEEGEE_SERVER_CONF, WEEGEE_CLIENT_CONF
logger = getLogger(__name__)
@dataclass
class WeegeeBase:
BASE = None
context: 'WeegeeContext'
item: Item
meta: Meta = None
resolved: Item = None
def __post_init__(self) -> None:
if not self.meta:
self.meta = Meta.load(self.context.instance, self.BASE.get_name())
if not self.resolved:
self.resolved = self.item.resolve(self.meta)
@classmethod
def get_name(cls, name: str) -> str:
return f'{cls.BASE.item_prefix}/{name}'
@classmethod
def get_meta(cls, context: 'WeegeeContext') -> Meta:
return Meta.load(context.instance, cls.BASE.get_name())
@classmethod
def create(cls, context: 'WeegeeContext', config: Config) -> 'WeegeeBase':
meta = cls.get_meta(context)
if not config.matches(meta.config):
raise TypeError('internal error')
return cls(context, Item(config), meta)
@classmethod
def load(cls, context: 'WeegeeContext', name: str) -> 'WeegeeBase':
return cls(context, Item.load(context.instance, name))
@property
def name(self) -> str:
return self.item_name[len(self.get_name('')):]
@property
def item_name(self) -> str:
return self.item.name
def save(self) -> None:
self.item.save()
def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and self.item.name == other.item.name
@dataclass
class WeegeeHost(WeegeeBase):
BASE = WEEGEE_HOST
conn: WireguardConnection = None
@classmethod
def create(cls, ctx: 'WeegeeContext', name: str, host: O[str] = None, user: O[str] = None, elevate_user: O[str] = None, automanage: bool = False, autosync: bool = False) -> 'WeegeeHost':
confname = cls.get_name(name)
return super().create(ctx, Config.parse(ctx.instance, confname, [
f'automanage = {1 if automanage else 0}',
f'autosync = {1 if autosync else 0}',
f'host = {repr(host) if host else ""}',
f'user = {repr(user) if user else ""}',
f'elevate_user = {repr(elevate_user) if elevate_user else ""}',
]))
@property
def automanage(self) -> bool:
return self.resolved.config['automanage'] == 1
@property
def autosync(self) -> bool:
return self.resolved.config['autosync'] == 1
@property
def host(self) -> O[str]:
return self.resolved.config['host']
@property
def user(self) -> O[str]:
return self.resolved.config['user']
@property
def elevate_user(self) -> O[str]:
return self.resolved.config['elevate_user']
def __post_init__(self) -> None:
super().__post_init__()
self.conn = WireguardConnection(self.host, self.user, self.elevate_user)
def get_peers(self, name: str) -> O[list[WireguardPeer]]:
interface = self.conn.get_interface(name)
if not interface:
return None
return interface.list_peers()
def sync_config(self, name: str, config: str) -> None:
interface = self.conn.get_interface(name)
if not interface:
if not self.automanage:
return
interface = self.conn.create_interface(name, config)
else:
interface.sync_config(config)
def remove_config(self, name: str) -> None:
interface = self.conn.get_interface(name)
if interface:
interface.sync_config('')
if self.automanage:
interface.delete()
@dataclass
class WeegeeClient(WeegeeBase):
BASE = WEEGEE_CLIENT
@classmethod
def get_name(cls, server: 'WeegeeServer', name: str) -> str:
return f'{cls.BASE.item_prefix}/{server.resolved.config["name"]}/{name}'
@property
def name(self) -> str:
return self.item_name[len(self.get_name(self.server, '')):]
@classmethod
def create(cls, ctx: 'WeegeeContext', server: 'WeegeeServer', name: str, private_key: O[str] = None, public_key: O[str] = None, preshared_key: O[str] = None, addresses: list[str] = [], hosts: O[list[WeegeeHost]] = None) -> 'WeegeeClient':
hosts = hosts or ctx.get_config().default_client_hosts
for serv in [ctx.get_local_host()] + hosts:
try:
private_key = private_key or serv.conn.gen_private_key()
public_key = public_key or ctx.wg.get_public_key(private_key)
preshared_key = preshared_key or ctx.wg.gen_preshared_key()
break
except Exception as e:
logger.warn(f'could not generate keypair on host {serv.item_name}: {e!r}')
else:
logger.critical('could not generate public/private keypair automatically: please pass explicitly!')
raise ValueError()
confname = cls.get_name(server, name)
config = Config.parse(ctx.instance, confname, [
f'hosts = [{", ".join(h.item_name for h in hosts)}]',
f'server = {server.item_name}',
f'name = {name!r}',
f'public_key = {public_key!r}',
f'private_key = {private_key!r}',
f'preshared_key = {preshared_key!r}',
f'addresses = [{", ".join(addresses)}]',
])
return super().create(ctx, config)
@property
def interface(self) -> str:
return self.resolved.config['interface']
@property
def public_key(self) -> str:
return self.resolved.config['public_key']
@property
def hosts(self) -> list[WeegeeHost]:
return [WeegeeHost(self.context, x) for x in self.resolved.config['hosts']]
@property
def server(self) -> 'WeegeeServer':
return WeegeeServer(self.context, self.resolved.config['server'])
def save(self) -> None:
super().save()
self.sync(auto=True)
def sync(self, auto=False, log=None) -> None:
if not log:
log = set()
if self.item_name in log:
return
log.add(self.item_name)
logger.info(f'syncing: {self.item_name}')
for host in self.hosts:
if auto and not host.autosync:
continue
host.sync_config(self.interface, self.gen_config())
self.server.sync(auto=auto, log=log)
def gen_config(self) -> str:
template = Template.load(self.context.instance, WEEGEE_CLIENT_CONF.get_name())
config = WEEGEE_CLIENT_CONF.make_config(self.context.instance,
client=self.item,
)
return template.render(config)
@dataclass
class WeegeeServer(WeegeeBase):
BASE = WEEGEE_SERVER
def get_clients(self) -> list[WeegeeClient]:
return [WeegeeClient.load(self.context, x) for x in Item.filter(self.context.instance, WeegeeClient.get_name(self, '*'))]
def get_client(self, name: str) -> O[WeegeeClient]:
return WeegeeClient.load(self.context, WeegeeClient.get_name(self, name))
@classmethod
def create(cls, ctx: 'WeegeeContext', name: str, interface: str, host: tuple[str, int], private_key: O[str] = None, public_key: O[str] = None, addresses: list[str] = [], routed_addresses: list[str] = [], hosts: O[list[WeegeeHost]] = None) -> 'WeegeeServer':
hosts = hosts or ctx.get_config().default_server_hosts
for serv in [ctx.get_local_host()] + hosts:
try:
private_key = private_key or serv.conn.gen_private_key()
public_key = public_key or serv.conn.get_public_key(private_key)
break
except Exception as e:
logger.warn(f'could not generate keypair on host {serv.item_name}: {e!r}')
else:
logger.critical('could not generate public/private keypair automatically: please pass explicitly!')
raise ValueError()
confname = cls.get_name(name)
config = Config.parse(ctx.instance, confname, [
f'hosts = [{", ".join(h.item_name for h in hosts)}]',
f'name = {name!r}',
f'interface = {interface!r}',
f'public_key = {public_key!r}',
f'private_key = {private_key!r}',
f'addresses = [{", ".join(addresses)}]',
f'routed_addresses = [{", ".join(routed_addresses)}]',
f'host = {host[0]!r}',
f'port = {host[1]}',
])
return super().create(ctx, config)
@property
def interface(self):
return self.resolved.config['interface']
@property
def hosts(self) -> list[WeegeeHost]:
return [WeegeeHost(self.context, x) for x in self.resolved.config['hosts']]
def save(self) -> None:
super().save()
self.sync(auto=True)
def sync(self, auto=False, log=None) -> None:
if not log:
log = set()
if self.item_name in log:
return
log.add(self.item_name)
logger.info(f'syncing: {self.item_name}')
for host in self.hosts:
if auto and not host.autosync:
continue
host.sync_config(self.interface, self.gen_config())
for client in self.get_clients():
client.sync(auto=auto, log=log)
def gen_config(self) -> str:
clients = self.get_clients()
template = Template.load(self.context.instance, WEEGEE_SERVER_CONF.get_name())
config = WEEGEE_SERVER_CONF.make_config(self.context.instance,
server=self.item,
clients=[x.item for x in clients],
)
return template.render(config)
@dataclass
class WeegeeConfig:
context: 'WeegeeContext'
item: Item
@classmethod
def get_name(cls) -> str:
return WEEGEE_CONFIG.get_name()
@classmethod
def load(cls, context: 'WeegeeContext', name: str) -> 'WeegeeConfig':
return cls(context, Meta.load(context.instance, name))
def save(self) -> None:
self.item.save()
@property
def default_server_hosts(self) -> list[WeegeeHost]:
return [WeegeeHost(self.context, item) for item in self.item.config['default_server_hosts']]
@default_server_hosts.setter
def default_server_hosts(self, value: list[WeegeeHost]) -> None:
self.item.config['default_server_hosts'] = [x.item for x in value]
@property
def default_client_hosts(self) -> list[WeegeeHost]:
return [WeegeeHost(self.context, item) for item in self.item.config['default_client_hosts']]
@default_client_hosts.setter
def default_client_hosts(self, value: list[WeegeeHost]) -> None:
self.item.config['default_client_hosts'] = [x.item for x in value]
@dataclass
class WeegeeContext:
LOCAL_HOST_NAME = 'local'
instance: Instance
def setup(self) -> None:
logger.info('setup: metas')
for wmeta in (WEEGEE_HOST, WEEGEE_SERVER, WEEGEE_CLIENT, WEEGEE_CONFIG):
logger.info(' ' + wmeta.name)
Meta(Config.parse(self.instance, wmeta.get_name(), wmeta.spec)).save()
logger.info('setup: templates')
for wtemp in (WEEGEE_SERVER_CONF, WEEGEE_CLIENT_CONF):
logger.info(' ' + wtemp.name)
Template(wtemp.get_name(), wtemp.template, self.instance).save()
logger.info('setup: items')
logger.info(f' {WeegeeHost.get_name(self.LOCAL_HOST_NAME)}')
localhost = WeegeeHost.create(self, self.LOCAL_HOST_NAME)
localhost.save()
logger.info('setup: config')
config = self.get_config()
config.default_server_hosts = [localhost]
config.save()
def get_local_host(self) -> WeegeeHost:
return WeegeeHost.load(self, WeegeeHost.get_name(self.LOCAL_HOST_NAME))
def get_config(self) -> WeegeeConfig:
return WeegeeConfig.load(self, WeegeeConfig.get_name())
def get_servers(self) -> list[WeegeeServer]:
return [WeegeeServer.load(self, x) for x in Item.filter(self.instance, WeegeeServer.get_name('*'))]
def get_server(self, name: str) -> O[WeegeeServer]:
return WeegeeServer.load(self, WeegeeServer.get_name(name))
def get_hosts(self) -> list[WeegeeHost]:
return [WeegeeHost.load(self, x) for x in Item.filter(self.instance, WeegeeHost.get_name('*'))]
def get_host(self, name: str) -> O[WeegeeHost]:
return WeegeeServer.load(self, WeegeeHost.get_name(name))

184
weegee/__main__.py Normal file
View File

@ -0,0 +1,184 @@
import sys
import argparse
import logging
from typing import Optional as O
logging.basicConfig(level=logging.DEBUG)
from .dazy import Instance
from . import WeegeeContext, WeegeeHost, WeegeeServer, WeegeeClient
def main():
parser = argparse.ArgumentParser()
parser.set_defaults(func=None)
parser.add_argument('-d', '--base-dir', default='.')
parser.add_argument('--yes-i-want-to-destroy-this', action='store_true', default=False)
commands = parser.add_subparsers(title='commands')
def do_status(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]:
for server in ctx.get_servers():
print(f'{server.name}:')
for host in server.hosts:
print(f' @ {host.name}:')
peers = {p.name: p for p in host.get_peers(server.interface)}
for client in server.get_clients():
if client.public_key not in peers:
print(f' {client.name}: client not found in peer list!')
else:
peer = peers[client.public_key]
tx, rx = peer.get_total_transfer()
endpoint = peer.last_endpoint()
print(f' {client.name}: last handshake {peer.last_handshake() or "never"}{" from " + endpoint if endpoint else ""}, {tx} sent, {rx} received')
status = commands.add_parser('status')
status.set_defaults(func=do_status)
def do_sync(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]:
log = set()
for server in ctx.get_servers():
server.sync(log=log)
sync = commands.add_parser('sync')
sync.set_defaults(func=do_sync)
# System commands
system = commands.add_parser('system')
system_commands = system.add_subparsers(title='system commands')
def do_setup(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]:
ctx.setup()
setup = system_commands.add_parser('setup')
setup.set_defaults(func=do_setup)
def do_configure(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]:
config = ctx.get_config()
config.autosync = args.autosync
config.automanage = args.automanage
config.save()
configure = system_commands.add_parser('configure')
configure.set_defaults(func=do_configure)
configure.add_argument('--autosync', action='store_true', default=False)
configure.add_argument('--automanage', action='store_true', default=False)
def do_migrate(arser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]:
pass
migrate = system_commands.add_parser('migrate')
migrate.set_defaults(func=do_migrate)
# Host commands
host = commands.add_parser('host')
host_commands = host.add_subparsers(title='host commands')
def do_add_host(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]:
host = WeegeeHost.create(ctx,
args.name, host=args.host, user=args.user, elevate_user=args.elevate_user,
autosync=args.auto_sync, automanage=args.auto_manage,
)
host.save()
add_host = host_commands.add_parser('create')
add_host.add_argument('-H', '--host', help='remote host')
add_host.add_argument('-u', '--user', help='username')
add_host.add_argument('-U', '--elevate_user', help='username to elevate privileges to')
add_host.add_argument('-a', '--auto-sync', action='store_true', default=False, help='whether to auto-synchronize config')
add_host.add_argument('-A', '--auto-manage', action='store_true', default=False, help='whether to auto-synchronize interfaces')
add_host.add_argument('name', help='host name')
add_host.set_defaults(func=do_add_host)
# Server commands
server = commands.add_parser('server')
server_commands = server.add_subparsers(title='server commands')
def do_add_server(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]:
interface = args.interface or f'wg-{args.name}'
server = WeegeeServer.create(ctx,
args.name, interface, (args.host, args.port),
private_key=args.private_key, public_key=args.public_key,
addresses=args.address, routed_addresses=args.routed_address,
)
server.save()
add_server = server_commands.add_parser('create')
add_server.add_argument('-a', '--address', action='append', help='interface address(es)')
add_server.add_argument('-A', '--routed-address', action='append', help='routed address(es)')
add_server.add_argument('-k', '--public-key', help='public key (optional)')
add_server.add_argument('-K', '--private-key', help='private key (optional)')
add_server.add_argument('-i', '--interface', help='interface name (optional)')
add_server.add_argument('name', help='server name')
add_server.add_argument('host', help='endpoint host')
add_server.add_argument('port', type=int, help='listen port')
add_server.set_defaults(func=do_add_server)
def do_del_server(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]:
if not args.yes_i_want_to_destroy_this:
parser.error('please pass --yes-i-want-to-destroy-this if you really want to remove this server')
del_server = server_commands.add_parser('destroy')
del_server.add_argument('name', help='server name')
del_server.set_defaults(func=do_del_server)
def do_conf_server(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]:
server = ctx.get_server(args.name)
print(server.gen_config())
conf_server = server_commands.add_parser('config')
conf_server.add_argument('name', help='server name')
conf_server.set_defaults(func=do_conf_server)
# Client commands
client = commands.add_parser('client')
client_commands = client.add_subparsers(title='client commands')
def do_add_client(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]:
server = ctx.get_server(args.server)
client = WeegeeClient.create(ctx, server,
args.name,
private_key=args.private_key, public_key=args.public_key, preshared_key=args.preshared_key,
addresses=args.address,
)
client.save()
add_client = client_commands.add_parser('create')
add_client.add_argument('-a', '--address', action='append', help='interface address(es)')
add_client.add_argument('-k', '--public-key', help='public key (optional)')
add_client.add_argument('-K', '--private-key', help='private key (optional)')
add_client.add_argument('-p', '--preshared-key', help='preshared key (optional)')
add_client.add_argument('name', help='client name')
add_client.add_argument('server', help='server name')
add_client.set_defaults(func=do_add_client)
def do_del_client(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]:
pass
del_client = client_commands.add_parser('destroy')
del_client.set_defaults(func=do_del_client)
def do_conf_client(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> O[int]:
server = ctx.get_server(args.server)
client = server.get_client(args.name)
print(client.gen_config())
conf_client = client_commands.add_parser('config')
conf_client.add_argument('name', help='name')
conf_client.add_argument('server', help='server name')
conf_client.set_defaults(func=do_conf_client)
args = parser.parse_args()
if not args.func:
parser.error('a subcommand must be provided')
instance = Instance(args.base_dir)
sys.exit(args.func(parser, args, WeegeeContext(instance)))
main()

634
weegee/dazy.py Normal file
View File

@ -0,0 +1,634 @@
import os
from enum import Enum
from dataclasses import dataclass
from collections import UserString
from ast import literal_eval
import ipaddress
from fnmatch import fnmatch
import jinja2
import jinja2.meta
from types import SimpleNamespace
from io import StringIO
from typing import Optional as O, Union as U, Any, Iterable, Iterator, TypeVar, Generic
def stripsplit(s: str, sep: str, maxsplit: int = -1) -> list[str]:
return [x.strip() for x in s.split(sep, maxsplit=maxsplit)]
def quotesplit(s: str, sep: str, maxsplit: int = - 1) -> list[str]:
parts = []
pos = -1
if maxsplit < 0:
maxsplit = len(s)
while len(parts) < maxsplit:
npos = s.find(sep, pos + 1)
if npos < 0:
break
pos = npos
if pos > 0 and s[pos - 1] == '\\':
s = s[:pos - 1] + s[pos:]
continue
parts.append(s[:pos])
s = s[pos + 1:]
pos = -1
parts.append(s)
return parts
def quotejoin(sep: str, a: list[str]) -> str:
return sep.join(x.replace(sep, '\\' + sep) for x in a)
def indent(s, amount=0, skip=0) -> int:
spacing = ' ' * amount
return '\n'.join((spacing if i >= skip else '') + l for i, l in enumerate(s.splitlines()))
class Instance:
def __init__(self, conf_dir: str, data_dir: O[str] = None) -> None:
self.conf_dir = conf_dir
self.data_dir = data_dir or conf_dir
class Int(int):
@classmethod
def parse(cls, value: str) -> 'Int':
return cls(value)
class Str(UserString):
@classmethod
def parse(cls, value: str) -> 'Str':
val = literal_eval(value)
if not isinstance(val, str):
raise TypeError(value)
return cls(val)
@dataclass
class IPAddress:
address: U[ipaddress.IPv4Address, ipaddress.IPv6Address]
@classmethod
def parse(cls, s: str) -> 'IPAddress':
return cls(ipaddress.ip_address(s))
def __getattr__(self, name: str) -> Any:
return getattr(self.address, name)
def __repr__(self) -> str:
return str(self.address)
@dataclass
class IPNetwork:
network: U[ipaddress.IPv4Network, ipaddress.IPv6Network]
@classmethod
def parse(cls, s: str) -> 'IPNetwork':
return cls(ipaddress.ip_network(s))
def __getattr__(self, name: str) -> Any:
return getattr(self.network, name)
def __repr__(self) -> str:
return str(self.network)
@dataclass
class IPInterface:
interface: U[ipaddress.IPv4Interface, ipaddress.IPv6Interface]
@classmethod
def parse(cls, s: str) -> 'IPInterface':
return cls(ipaddress.ip_interface(s))
def __getattr__(self, name: str) -> Any:
return getattr(self.interface, name)
def __repr__(self) -> str:
return str(self.interface)
T = TypeVar('T')
class Type(Generic[T]):
@classmethod
def parse(cls, input: str) -> 'Type':
if input.startswith('?'):
return OptionalType.parse(input)
if input.startswith('['):
return ArrType.parse(input)
if input.startswith('{'):
return MapType.parse(input)
if input.startswith('@'):
return RefType.parse(input)
if input.startswith('*'):
return GlobType.parse(input)
return LitType.parse(input)
def dump(self) -> str:
raise NotImplementedError
def match_value(self, input: Any, instance: Instance) -> bool:
return False
def parse_value(self, input: str, instance: Instance) -> T:
raise TypeError
def dump_value(self, input: T) -> str:
return repr(input)
def template_value(self, input: T) -> T:
return input
@dataclass
class OptionalType(Generic[T], Type[O[T]]):
subtype: Type[T]
@classmethod
def parse(cls, input: str) -> 'OptionalType[T]':
if input[0] != '?':
raise ParseError()
return cls(Type.parse(input[1:]))
def dump(self) -> str:
return f'?{self.subtype.dump()}'
def match_value(self, input: Any, instance: Instance) -> bool:
return input is None or self.subtype.match_value(input, instance)
def parse_value(self, input: str, instance: Instance) -> O[T]:
if not input:
return None
return self.subtype.parse_value(input, instance)
def dump_value(self, input: O[T]) -> str:
if input is None:
return ''
return self.subtype.dump_value(input)
def template_value(self, input: O[T]) -> O[T]:
if input is None:
return None
return self.subtype.template_value(input)
@dataclass
class RefType(Type['Item']):
name: str
@classmethod
def parse(cls, input: str) -> 'RefType':
if input[0] != '@':
raise ParseError()
return cls(input[1:])
def dump(self) -> str:
return f'@{self.name}'
def match_value(self, input: Any, instance: Instance) -> bool:
meta = Meta.load(instance, self.name)
return isinstance(input, Item) and input.is_complete(meta)
def parse_value(self, input: str, instance: Instance) -> 'Item':
meta = Meta.load(instance, self.name)
return Item.load(instance, input)
def dump_value(self, input: 'Item') -> str:
return input.name
def template_value(self, input: 'Item') -> 'Config':
return input.config.template()
@dataclass
class GlobType(Type[dict[str, 'Item']]):
name: str
@classmethod
def parse(cls, input: str) -> 'GlobType':
if input[0] != '*':
raise ParseError()
return cls(input[1:])
def dump(self) -> str:
s = f'*{self.name}'
return s
def match_value(self, input: Any, instance: Instance) -> bool:
meta = Meta.load(instance, self.name)
return isinstance(input, list) and all(isinstance(x, Item) and fnmatch(x.name, input) and x.is_complete(meta) for x in input)
def parse_value(self, input: str, instance: Instance) -> dict[str, 'Item']:
meta = Meta.load(instance, self.name)
out = {}
for x in Item.filter(instance, input):
item = Item.load(instance, x)
if not item.is_complete(meta):
continue
out[x] = item
return out
def dump_value(self, input: dict[str, 'Item']) -> str:
return '<unimplemented>'
def template_value(self, input: dict[str, 'Item']) -> dict[str, 'Item']:
return {x.name: x.config.template() for x in input}
@dataclass
class LitType(Generic[T], Type[T]):
SUBTYPES = {
'int': Int,
'str': Str,
'ipaddr': IPAddress,
'ipnet': IPNetwork,
'ipintf': IPInterface,
}
name: str
subtype: T
@classmethod
def parse(cls, input: str) -> 'LitType[T]':
if input not in cls.SUBTYPES:
raise ParseError()
return cls(name=input, subtype=cls.SUBTYPES[input])
def dump(self) -> str:
return self.name
def match_value(self, input: Any, instance: Instance) -> bool:
return isinstance(input, self.subtype)
def parse_value(self, input: str, instance: Instance) -> T:
return self.subtype.parse(input)
def dump_value(self, input: T) -> str:
return repr(input)
@dataclass
class ArrType(Generic[T], Type[list[T]]):
subtype: Type[T]
@classmethod
def parse(cls, input: str) -> 'ArrType':
if input[0] != '[' or input[-1] != ']':
raise ParseError()
return cls(subtype=Type.parse(input[1:-1]))
def dump(self) -> str:
return f'[{self.subtype.dump()}]'
def match_value(self, input: Any, instance: Instance) -> bool:
return isinstance(input, list) and all(self.subtype.match_value(x, instance) for x in input)
def parse_value(self, input: str, instance: Instance) -> list[T]:
if input[0] != '[' or input[-1] != ']':
raise ParseError()
input = input[1:-1].strip()
if not input:
return []
return [self.subtype.parse_value(x.strip(), instance) for x in quotesplit(input, ',')]
def dump_value(self, input: list[T]) -> str:
return '[' + quotejoin(', ', (self.subtype.dump_value(x) for x in input)) + ']'
def template_value(self, input: list[T]) -> list[T]:
return [self.subtype.template_value(x) for x in input]
V = TypeVar('V')
@dataclass
class MapType(Generic[T, V], Type[dict[T, V]]):
keytype: Type[T]
valtype: Type[V]
@classmethod
def parse(cls, input: str) -> 'MapType[T, V]':
if input[0] != '{' or input[-1] != '}':
raise ParseError()
kt, vt = quotesplit(input[1:-1], ':', maxsplit=1)
return cls(keytype=Type.parse(kt.strip()), valtype=Type.parse(vt.strip()))
def dump(self) -> str:
kt = self.keytype.dump().replace(':', '\\:')
vt = self.valtype.dump()
return f'{{{kt}: {vt}}}'
def match_value(self, input: Any, instance: Instance) -> bool:
return isinstance(input, dict) and all(self.keytype.match_value(k, instance) and self.valtype.match_value(v, instance) for (k, v) in input.items())
def parse_value(self, input: str, instance: Instance) -> tuple[dict[T, V], str]:
if input[0] != '{' or input[-1] != '}':
raise ParseError()
out = {}
input[1:-1].strip()
if input:
items = quotesplit(input, ',')
for item in items:
k, v = quotesplit(item, ':', maxsplit=1)
out[self.keytype.parse_value(k.strip(), instance)] = self.valtype.parse_value(v.strip(), instance)
return out
def dump_value(self, input: dict[T, V]) -> str:
return '{' + quotejoin(', ', (self.keytype.dump_value(k).replace(':', '\\:') + ': ' + self.valtype.dump_value(v) for (k, v) in input.items())) + '}'
def template_value(self, input: dict[T, V]) -> dict[T, V]:
return {self.keytype.template_value(k): self.valtype.template_value(v) for (k, v) in input.items()}
NoValue = object()
@dataclass
class Config:
name: str
resolved: dict[str, tuple[Type, O[Any]]]
unresolved: dict[str, str]
instance: 'Instance'
kind: O[str] = None
@classmethod
def parse(cls, instance: Instance, name: str, lines: Iterable[str]) -> 'Config':
resolved = {}
unresolved = {}
for line in lines:
line, *_ = stripsplit(line, '#', maxsplit=1)
line, *parts = stripsplit(line, '=', maxsplit=1)
valstr = parts[0] if parts else None
vname, *parts = stripsplit(line, ':', maxsplit=1)
typestr = parts[0] if parts else None
if typestr is not None:
vtype = Type.parse(typestr)
if valstr is not None:
vval = vtype.parse_value(valstr, instance)
else:
vval = NoValue
resolved[vname] = (vtype, vval)
else:
unresolved[vname] = valstr
return cls(name=name, resolved=resolved, unresolved=unresolved, instance=instance)
@classmethod
def make(cls, _instance: Instance, _name: str, **args: dict[str, tuple[Type, O[Any]]]) -> 'Config':
return cls(name=_name, resolved=args, unresolved={}, instance=_instance)
def dump(self) -> list[str]:
out = []
for name, (vtype, vval) in self.resolved.items():
line = f'{name}: {vtype.dump()}'
if vval is not NoValue:
line += f' = {vtype.dump_value(vval)}'
out.append(line)
for name, vval in self.unresolved.items():
line = f'{name} = {vval}'
out.append(line)
return out
def template(self) -> SimpleNamespace:
resolved = {}
for name, (vtype, vval) in self.resolved.items():
resolved[name] = vtype.template_value(vval) if vval is not NoValue else vval
return SimpleNamespace(**resolved)
def matches(self, other: 'Config') -> bool:
for vname, (vtype, vval) in other.resolved.items():
if vname in self.unresolved:
vval = vtype.parse_value(self.unresolved[vname], self.instance)
elif vname in self.resolved:
vval = self[vname]
if vval is NoValue:
return False
return True
@property
def type_complete(self) -> bool:
return not self.unresolved
@property
def value_complete(self) -> bool:
return self.type_complete and all(vval is not NoValue for _, vval in self.resolved.values())
def __getitem__(self, name: str) -> Any:
_, value = self.resolved[name]
if value is NoValue:
raise KeyError(name)
return value
def __setitem__(self, name: str, value: Any) -> None:
vtype, _ = self.resolved[name]
if not vtype.match_value(value, self.instance):
raise TypeError(f'{name}: {value} must match {vtype.dump()}')
self.resolved[name] = (vtype, value)
def __contains__(self, name: str) -> bool:
if name not in self.resolved:
return False
_, value = self.resolved[name]
return value is not NoValue
def __iter__(self) -> Iterator[str]:
return iter(self.resolved)
def __repr__(self) -> str:
value = self.name
if self.kind:
value += f': {self.kind}'
value += ' {\n'
for name, (vtype, vval) in self.resolved.items():
valstr = f' {name}: {vtype.dump()}'
if vval is not NoValue:
valstr += f' = {indent(str(vval), amount=2, skip=1)}'
value += valstr + '\n'
for name, vval in self.unresolved.items():
value += f' {name}: ?{indent(str(vval), amount=2, skip=1)}\n'
value += '}'
return value
@dataclass
class Meta:
config: Config
@classmethod
def load(cls, instance: 'Instance', name: str) -> 'Meta':
with open(f'{instance.conf_dir}/configs/{name}.conf', 'r') as f:
config = Config.parse(instance, name, f.readlines())
return cls(config=config)
def save(self) -> None:
path = f'{self.config.instance.conf_dir}/configs/{self.name}.conf'
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'w') as f:
f.write('\n'.join(self.config.dump()) + '\n')
@property
def name(self) -> str:
return self.config.name
@property
def is_complete(self) -> bool:
return self.config.type_complete
def check_complete(self) -> None:
if not self.is_complete:
raise ValueError(f'Incomplete config for meta: {self.config}')
@dataclass
class Item:
config: Config
@classmethod
def filter(cls, instance: Instance, pattern: str) -> str:
basedir = f'{instance.data_dir}/items/'
allfiles = []
for root, _, files in os.walk(basedir):
for f in files:
if not f.endswith('.conf'):
continue
base = os.path.join(root[len(basedir):], f[:-len('.conf')])
if not fnmatch(base, pattern):
continue
allfiles.append(base)
return allfiles
@classmethod
def load(cls, instance: Instance, name: str) -> None:
with open(f'{instance.data_dir}/items/{name}.conf', 'r') as f:
config = Config.parse(instance, name, f)
return cls(config=config)
def save(self) -> None:
path = f'{self.config.instance.data_dir}/items/{self.name}.conf'
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'w') as f:
f.write('\n'.join(self.config.dump()) + '\n')
@property
def name(self) -> str:
return self.config.name
def resolve(self, meta: Meta) -> O['Item']:
c = Config(name=self.name, resolved={}, unresolved={}, instance=self.config.instance, kind=meta.name)
for name, (vtype, vval) in meta.config.resolved.items():
c.resolved[name] = (vtype, NoValue)
if name in self.config.resolved:
vval = self.config[name]
elif name in self.config.unresolved:
vval = vtype.parse_value(self.config.unresolved[name], self.config.instance)
if vval is not NoValue:
c[name] = vval
return self.__class__(c)
def is_complete(self, meta: Meta) -> bool:
return self.resolve(meta).config.value_complete
def check_complete(self, meta: Meta) -> None:
if not self.is_complete(meta):
raise ValueError(f'incomplete config {self.config.name} for {meta.name}: {self.config}')
def __str__(self) -> str:
return str(self.config)
@dataclass
class Template:
name: str
source: str
instance: Instance
template: O[jinja2.Template] = None
def __post_init__(self) -> None:
if not self.template:
self.template = jinja2.Template(self.source)
@classmethod
def load(cls, instance: Instance, name: str) -> None:
env = jinja2.Environment()
with open(f'{instance.conf_dir}/templates/{name}.tmpl', 'r') as f:
source = f.read()
return cls(name=name, source=source, instance=instance)
def save(self) -> None:
path = f'{self.instance.conf_dir}/templates/{self.name}.tmpl'
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'w') as f:
f.write(self.source)
def get_variables(self) -> set[str]:
return jinja2.meta.find_undeclared_variables(jinja2.Environment().parse(self.source))
def render(self, config: Config) -> str:
templated = config.template()
return self.template.render(**templated.__dict__)
if __name__ == '__main__':
import sys
import argparse
parser = argparse.ArgumentParser()
parser.set_defaults(func=None)
parser.add_argument('-d', '--base-dir', default='.', help='Base directory for configuration')
commands = parser.add_subparsers(title='commands')
def split_meta(name: str, instance: Instance) -> tuple[str, O[Meta]]:
if ':' in name:
name, metaname = name.split(':', maxsplit=1)
meta = Meta.load(instance, metaname)
else:
meta = None
return name, meta
def do_new_meta(parser: argparse.ArgumentParser, args: argparse.Namespace, instance: Instance) -> None:
meta = Meta(config=Config.parse(instance, args.name, args.spec))
meta.save()
new_meta = commands.add_parser('new-meta')
new_meta.add_argument('name')
new_meta.add_argument('spec', nargs='*')
new_meta.set_defaults(func=do_new_meta)
def do_new_item(parser: argparse.ArgumentParser, args: argparse.Namespace, instance: Instance) -> None:
name, meta = split_meta(args.name, instance)
item = Item(config=Config.parse(instance, name, args.spec))
if meta:
meta.check_complete()
item.check_complete(meta)
item.save()
new_item = commands.add_parser('new')
new_item.add_argument('name')
new_item.add_argument('spec', nargs='*')
new_item.set_defaults(func=do_new_item)
def do_check_item(parser: argparse.ArgumentParser, args: argparse.Namespace, instance: Instance) -> None:
name, meta = split_meta(args.name, instance)
item = Item.load(instance, name)
if meta:
meta.check_complete()
return 0 if item.is_complete(meta) else 1
else:
return 0
check_item = commands.add_parser('check')
check_item.add_argument('name')
check_item.set_defaults(func=do_check_item)
def do_template(parser: argparse.ArgumentParser, args: argparse.Namespace, instance: Instance) -> None:
template = Template.load(instance, args.name)
config = Config.parse(instance, '<vars>', args.spec)
missing = template.get_variables() - set(config)
if missing:
parser.error(f'missing variables: {", ".join(missing)}')
print(template.render(config))
template_item = commands.add_parser('template')
template_item.add_argument('name')
template_item.add_argument('spec', nargs='*')
template_item.set_defaults(func=do_template)
args = parser.parse_args()
if not args.func:
parser.error('a subcommand must be provided')
instance = Instance(args.base_dir)
sys.exit(args.func(parser, args, instance))

132
weegee/desc.py Normal file
View File

@ -0,0 +1,132 @@
from dataclasses import dataclass
from typing import Union as U
from .dazy import Instance, Config, RefType, ArrType, Meta
@dataclass
class WeegeeDesc:
name: str
version: int
def get_name(self) -> str:
return f'{self.name}@{self.version}'
@dataclass
class WeegeeMeta(WeegeeDesc):
spec: list[str]
item_prefix: str = ''
WEEGEE_HOST = WeegeeMeta(
name='wg/host',
version=1,
spec=[
'autosync: int = 0',
'automanage: int = 0',
'host: ?str = ',
'user: ?str = ',
'elevate_user: ?str = ',
],
item_prefix='wg/host',
)
WEEGEE_SERVER = WeegeeMeta(
name='wg/server',
version=1,
spec=[
f'hosts: [@{WEEGEE_HOST.get_name()}]',
'interface: str',
'public_key: str',
'private_key: str',
'addresses: [ipintf]',
'routed_addresses: [ipintf]',
'host: str',
'port: int',
],
item_prefix='wg/server',
)
WEEGEE_CLIENT = WeegeeMeta(
name='wg/client',
version=1,
spec=[
f'hosts: [@{WEEGEE_HOST.get_name()}]',
f'server: @{WEEGEE_SERVER.get_name()}',
'interface: str = "wg0"',
'public_key: str',
'private_key: str',
'preshared_key: str',
'addresses: [ipintf]',
],
item_prefix='wg/client',
)
WEEGEE_CONFIG = WeegeeMeta(
name='wg/config',
version=1,
spec=[
f'default_server_hosts: [@{WEEGEE_HOST.get_name()}] = []',
f'default_client_hosts: [@{WEEGEE_HOST.get_name()}] = []',
],
)
@dataclass
class WeegeeTemplate(WeegeeDesc):
template: str
variables: dict[str, U[WeegeeMeta, list[WeegeeMeta]]]
def make_config(self, instance: Instance, **kwargs) -> 'Config':
args = {}
for k, v in kwargs.items():
vtype = self.variables[k]
if isinstance(vtype, list):
tname = vtype[0].get_name()
ttype = ArrType(RefType(tname))
vval = [x.resolve(Meta.load(instance, tname)) for x in v]
else:
tname = vtype.get_name()
ttype = RefType(tname)
vval = v.resolve(Meta.load(instance, tname))
args[k] = (ttype, vval)
return Config.make(instance, '<vars>', **args)
WEEGEE_SERVER_CONF = WeegeeTemplate(
name='wg/server-conf',
version=1,
template="""
[Interface]
Address = {{server.addresses | join(', ')}}
ListenPort = {{server.port}}
PrivateKey = {{server.private_key}}
PostUp = iptables -A FORWARD -i %i -j ACCEPT
PostDown = iptables -D FORWARD -i %i -j ACCEPT
{% for client in clients -%}
[Peer]
# Client: {{client.name}}
PublicKey = {{client.public_key}}
PresharedKey = {{client.preshared_key}}
AllowedIPs = {{client.addresses | join(', ')}}
{% endfor %}
""".strip(),
variables={'server': WEEGEE_SERVER, 'clients': [WEEGEE_CLIENT]},
)
WEEGEE_CLIENT_CONF = WeegeeTemplate(
name='wg/client-conf',
version=1,
template="""
[Interface]
PrivateKey = {{client.private_ke}}
Address = {{client.addresses | join(', ')}}
[Peer]
PublicKey = {{client.server.public_key}}
PresharedKey = {{client.preshared_key}}
AllowedIPs = {{client.server.routed_addresses | join(', ')}}
Endpoint = {{client.server.host}}:{{client.server.port}}
PersistentKeepalive = 30
""".strip(),
variables={'client': WEEGEE_CLIENT},
)

125
weegee/wireguard.py Normal file
View File

@ -0,0 +1,125 @@
from dataclasses import dataclass
from datetime import datetime
import shlex
import ipaddress
from typing import Optional as O, Union as U
import subprocess
from logging import getLogger
logger = getLogger(__name__)
@dataclass
class WireguardPeer:
conn: 'WireguardConnection'
interface: str
name: str
def _filter_list(self, cmd: str) -> O[list[str]]:
for line in self.conn._run_wg('show', self.interface, cmd).splitlines():
peer, *parts = line.split()
if peer == self.name:
return parts
def last_handshake(self) -> O[datetime]:
parts = self._filter_list('latest-handshakes')
if not parts:
return None
ts = int(parts[0])
if not ts:
return None
return datetime.fromtimestamp(ts)
def last_endpoint(self) -> O[tuple[U[ipaddress.IPv4Address, ipaddress.IPv6Address], int]]:
parts = self._filter_list('endpoints')
if not parts:
return None
if parts[0] == '(none)':
return None
addr, port = parts[0].rsplit(':', maxsplit=1)
return (ipaddress.ip_address(addr.strip('[]')), int(port))
def get_total_transfer(self) -> O[tuple[int, int]]:
parts = self._filter_list('transfer')
if not parts:
return None
return (int(parts[0]), int(parts[1]))
@dataclass
class WireguardInterface:
conn: 'WireguardConnection'
name: str
def list_peers(self) -> list[WireguardPeer]:
return [WireguardPeer(self.conn, self.name, x) for x in self.conn._run_wg('show', self.name, 'peers').split()]
def get_config(self) -> str:
return self.conn._run_wg('getconf', self.name)
def set_config(self, config: str) -> None:
config = self.conn._run_wg_quick('strip', self.name, config)
self.conn._run_wg('setconf', self.name, '/dev/stdin', stdin=config)
def sync_config(self, config: str) -> None:
config = self.conn._run_wg_quick('strip', self.name, config)
self.conn._run_wg('syncconf', self.name, '/dev/stdin', stdin=config)
def delete(self) -> None:
self.conn._run_wg_quick('down', self.name)
@dataclass
class WireguardConnection:
host: O[str]
user: O[str]
elevate_user: O[str]
def _run(self, *args, shell=False, stdin=None, **kwargs) -> str:
cmd = []
if self.host:
cmd += ['ssh', f'{self.user}@{self.host}' if self.user else str(self.host)]
if self.elevate_user:
cmd += ['sudo', '-i', '-u', self.elevate_user]
if shell:
cmd += ['sh', '-c', shlex.quote('set -e; ' + args[0])]
else:
cmd += list(args)
logger.debug(f'executing: {shlex.join(cmd)}')
p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='ascii', **kwargs)
stdout, _ = p.communicate(stdin)
return stdout.strip()
def _run_wg(self, *args, **kwargs) -> str:
return self._run('wg', *args, **kwargs)
def _run_wg_quick(self, command, interface, config=None, **kwargs) -> str:
cmd = ''
stdin = None
if config:
cmd += f'tmpdir="$(mktemp -d)"; tmpfile="$tmpdir/{shlex.quote(str(interface) + ".conf")}"; trap \'rm -rf "$tmpdir"\' EXIT; cat > "$tmpfile"; intf="$tmpfile"; '
stdin = config
else:
cmd += f'intf={shlex.quote(interface)}; '
cmd += f'exec wg-quick {shlex.quote(command)} "$intf"'
return self._run(cmd, shell=True, stdin=stdin, **kwargs)
def list_interfaces(self) -> list[WireguardInterface]:
return [WireguardInterface(self, x) for x in self._run_wg('show', 'interfaces').split()]
def get_interface(self, name: str) -> O[WireguardInterface]:
for interface in self.list_interfaces():
if interface.name == name:
return interface
return None
def create_interface(self, name: str, config: str) -> WireguardInterface:
self._run_wg_quick('up', name, config)
return self.get_interface(name)
def gen_preshared_key(self) -> str:
return self._run_wg('genpsk')
def gen_private_key(self) -> str:
return self._run_wg('genkey')
def get_public_key(self, privkey: str) -> str:
return self._run_wg('pubkey', stdin=privkey)