From 0313c8d938ec5221fed1c051ae9882d8457e9a21 Mon Sep 17 00:00:00 2001
From: Shiz
Date: Thu, 24 Jun 2021 03:29:19 +0200
Subject: [PATCH] init

---
 .gitignore             |   6 +
 sx/__init__.py         |  21 +++
 sx/core/__init__.py    | 109 +++++++++++++
 sx/core/base.py        | 199 ++++++++++++++++++++++
 sx/core/expr.py        | 294 +++++++++++++++++++++++++++++
 sx/core/io.py          | 166 +++++++++++++++++++
 sx/core/meta.py        |  87 ++++++++++++
 sx/core/util.py        |  67 ++++++++++
 sx/types/__init__.py   |   0
 sx/types/data.py       |  29 ++++
 sx/types/int.py        |  73 ++++++++++
 sx/types/seq.py        | 200 ++++++++++++++++++++++
 sx/types/str.py        |  54 ++++++++
 sx/types/struct.py     | 292 ++++++++++++++++++++++++++++
 sx/types/transforms.py | 196 +++++++++++++++++++++
 15 files changed, 1793 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 sx/__init__.py
 create mode 100644 sx/core/__init__.py
 create mode 100644 sx/core/base.py
 create mode 100644 sx/core/expr.py
 create mode 100644 sx/core/io.py
 create mode 100644 sx/core/meta.py
 create mode 100644 sx/core/util.py
 create mode 100644 sx/types/__init__.py
 create mode 100644 sx/types/data.py
 create mode 100644 sx/types/int.py
 create mode 100644 sx/types/seq.py
 create mode 100644 sx/types/str.py
 create mode 100644 sx/types/struct.py
 create mode 100644 sx/types/transforms.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f313972
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+__pycache__
+*.pyc
+.mypy_cache
+
+.DS_Store
+thumbs.db
diff --git a/sx/__init__.py b/sx/__init__.py
new file mode 100644
index 0000000..26ffaf6
--- /dev/null
+++ b/sx/__init__.py
@@ -0,0 +1,21 @@
+from .core import parse, dump, sizeof, offsetof, default
+from .core.base import Context, Type
+from .core.io import Stream, Segment
+from .core.meta import Wrapper, Generic
+#from .core.expr import Ref, RefSource
+
+from .types.data import Data, data
+from .types.int import *
+from .types.struct import StructType, Struct
+from .types.seq import Arr, Tuple
+from .types.transforms import Default, Sized, Ref, Transform, Mapped
+
+__all__ = [x.__name__ for x in {
+    parse, dump, sizeof, offsetof, default,
+    Context, Type, Stream, Segment,
+    Wrapper, Default, Sized, Ref, Transform, Mapped,
+    Data,
+    Int, Bool,
+    Arr, Tuple,
+    StructType, Struct, Generic,
+}] + ['data', 'bool', 'int8', 'uint8', 'uint16le', 'int32le', 'uint32le']
diff --git a/sx/core/__init__.py b/sx/core/__init__.py
new file mode 100644
index 0000000..60af122
--- /dev/null
+++ b/sx/core/__init__.py
@@ -0,0 +1,109 @@
+from typing import Callable, Union, BinaryIO, Any, Optional as O, Sequence
+from types import FunctionType
+from io import BytesIO
+import math
+
+from .base import PathElement, Context, Params, Type, Error
+from .io import Stream, Segment, ceil_sizes
+
+
+PossibleType = Union[Type, list, tuple, Callable[[O[Any]], Type]]
+
+def to_type(type: PossibleType, ident: O[Any] = None) -> Type:
+    if isinstance(type, Type):
+        return type
+    t = getattr(type, '_sx_type_', None)
+    if t:
+        return t
+    getter = getattr(type, '_get_sx_type_', None)
+    if getter:
+        return getter(ident)
+    if isinstance(type, FunctionType):
+        return type(ident)
+
+    raise ValueError('Could not figure out specification from argument {}.'.format(type))
+
+PossibleStream = Union[BinaryIO, Stream, None, bytes, bytearray]
+
+def to_stream(value: PossibleStream) -> Stream:
+    if isinstance(value, Stream):
+        return value
+    if value is None:
+        value = BytesIO()
+    if isinstance(value, (bytes, bytearray)):
+        value = BytesIO(value)
+    return Stream(value)
+
+
+def parse(type: PossibleType, stream: PossibleStream, params: O[Params] = None) -> Any:
+    type = to_type(type)
+    stream = to_stream(stream)
+    ctx = Context(type, None, params=params)
+    try:
+        return ctx.parse(type, stream)
+    except Error as e:
+        raise
+    except Exception as e:
+        raise Error(ctx, e) from e
+
+def dump(type: PossibleType, value: Any, stream: PossibleStream = None, params: O[Params] = None) -> BinaryIO:
+    type = to_type(type)
+    stream = to_stream(stream)
+    ctx = Context(type, value, params=params)
+    try:
+        ctx.dump(type, stream, value)
+    except Error:
+        raise
+    except Exception as e:
+        raise Error(ctx, e) from e
+    return stream.root
+
+def sizeof(type: PossibleType, value: O[Any] = None, params: O[Params] = None, segment: O[Segment] = None) -> O[int]:
+    type = to_type(type)
+    ctx = Context(type, value, params=params)
+    try:
+        sizes = ceil_sizes(ctx.sizeof(type, value))
+    except Error:
+        raise
+    except Exception as e:
+        raise Error(ctx, e) from e
+
+    if segment:
+        return sizes.get(segment, None)
+    else:
+        size = 0
+        for v in sizes.values():
+            if v is None:
+                return None
+            size += v
+        return size
+
+def offsetof(type: PossibleType, path: Sequence[PathElement], value: O[Any] = None, params: O[Params] = None, segment: O[Segment] = None) -> O[int]:
+    type = to_type(type)
+    ctx = Context(type, value, params=params)
+    try:
+        offsets = ctx.offsetof(type, path, value)
+    except Error:
+        raise
+    except Exception as e:
+        raise Error(ctx, e) from e
+
+    segment = segment or ctx.params.default_segment
+
+    off = offsets.get(segment, None)
+    if off is None:
+        return None
+    segoff = ctx.segment_offset(segment)
+    if segoff is None:
+        return None
+    return math.ceil(segoff + off)
+
+def default(type: PossibleType, params: O[Params] = None) -> O[Any]:
+    type = to_type(type)
+    ctx = Context(type, None, params=params)
+    try:
+        return ctx.default(type)
+    except Error:
+        raise
+    except Exception as e:
+        raise Error(ctx, e) from e
diff --git a/sx/core/base.py b/sx/core/base.py
new file mode 100644
index 0000000..b0adc8b
--- /dev/null
+++ b/sx/core/base.py
@@ -0,0 +1,199 @@
+import os
+from types import SimpleNamespace
+
+from contextlib import contextmanager
+from typing import Any, Generic, Generator, Iterable, List, Mapping, Dict, Sequence, Tuple, TypeVar, Union as U, Optional as O, Generic as G, cast
+
+from .util import seeking
+from .io import Segment, Stream, Pos
+
+
+class Params:
+    __slots__ = ('segments', 'default_segment', 'user')
+
+    def __init__(self, segments: Sequence[Segment] = None):
+        default = segments[0] if segments else Segment('default')
+        self.segments = {s.name: s for s in (segments or [default, Segment('refs', [default])])}
+        self.default_segment = default
+        self.user = SimpleNamespace()
+
+    def reset(self):
+        for s in self.segments.values():
+            s.reset()
+
+
+PathElement = U[str, int]
+PathEntry = Tuple[PathElement, 'Type']
+
+def format_path(path: Iterable[PathElement]) -> str:
+    s = ''
+    first = True
+    for p in path:
+        sep = '.'
+        if isinstance(p, int):
+            p = '[' + str(p) + ']'
+            sep = ''
+        if sep and not first:
+            s += sep
+        s += p
+        first = False
+    return s
+
+
+T = TypeVar('T')
+PT = TypeVar('PT')
+
+
+class PossibleDynamic(Generic[T]):
+    pass
+
+class Context:
+    __slots__ = ('root', 'value', 'params', 'path', 'segment_path')
+
+    def __init__(self, root: 'Type', value: O[Any] = None, params: O[Params] = None) -> None:
+        self.root = root
+        self.value = value
+        self.params = params or Params()
+        self.path: List[PathEntry] = []
+        self.segment_path: List[Segment] = []
+
+    def copy(self) -> 'Context':
+        c = self.__class__(root=self.root, value=self.value, params=self.params)
+        c.path = self.path.copy()
+        c.segment_path = self.segment_path.copy()
+        return c
+
+
+    @property
+    def segment(self) -> Segment:
+        return self.segment_path[-1] if self.segment_path else self.params.default_segment
+
+    @contextmanager
+    def enter(self, entry: PathElement, type: 'Type') -> Generator:
+        self.path.append((entry, type))
+        try:
+            yield
+        except EOFError as e:
+            raise EOF(self, e) from e
+        self.path.pop()
+
+    @contextmanager
+    def enter_segment(self, segment: Segment, stream: O[Stream] = None, pos: O[Pos] = None, reference = os.SEEK_SET) -> Generator[O[Stream], None, None]:
+        if stream:
+            if pos is None:
+                if segment.offset is None:
+                    segment.offset = self.segment_offset(segment)
+                    segment.pos = segment.offset
+                pos = segment.pos
+            if pos is None:
+                raise Error(self, ValueError('could not enter segment {}: could not calculate offset'.format(segment)))
+            with seeking(stream.root, pos, reference) as s, stream.wrapped(s) as f:
+                self.segment_path.append(segment)
+                yield f
+                self.segment_path.pop()
+                segment.pos = f.tell()
+        else:
+            self.segment_path.append(segment)
+            yield stream
+            self.segment_path.pop()
+
+    def segment_offset(self, segment: Segment) -> O[Pos]:
+        size: Pos = 0
+        for s in segment.dependents:
+            sz = self.segment_size(s)
+            if sz is None:
+                return None
+            off = self.segment_offset(s)
+            if off is None:
+                return None
+            size += off + sz
+        return size
+
+    def segment_size(self, segment: Segment) -> O[Pos]:
+        sizes = self.sizeof(self.root, self.value)
+        return sizes.get(segment, None)
+
+    def format_path(self) -> str:
+        return format_path(name for name, _ in self.path)
+
+    def to_size(self, value: Any) -> Dict[Segment, Pos]:
+        if not isinstance(value, dict):
+            stream = self.segment_path[-1] if self.segment_path else self.params.default_segment
+            value = {stream: value}
+        return value
+
+
+    def get(self, value: U[T, PossibleDynamic[T]]) -> T:
+        from .expr import Expr, get
+        if isinstance(value, Expr):
+            value = get(value)
+        return cast(T, value)
+
+    def peek(self, value: U[T, PossibleDynamic[T]]) -> O[T]:
+        from .expr import Expr, peek
+        if isinstance(value, Expr):
+            value = peek(value)
+        return cast(T, value)
+
+    def put(self, value: U[T, PossibleDynamic[T]], new: T) -> None:
+        from .expr import Expr, put
+        if isinstance(value, Expr):
+            put(value, new)
+
+
+    def parse(self, type: 'Type[PT]', stream: Stream) -> PT:
+        return type.parse(self, stream)
+
+    def dump(self, type: 'Type[PT]', stream: Stream, value: PT) -> None:
+        return type.dump(self, stream, value)
+
+    def sizeof(self, type: 'Type[PT]', value: O[PT] = None) -> Dict[Segment, Pos]:
+        return self.to_size(type.sizeof(self, value))
+
+    def offsetof(self, type: 'Type[PT]', path: Sequence[PathElement], value: O[PT] = None) -> Dict[Segment, Pos]:
+        return self.to_size(type.offsetof(self, path, value))
+
+    def default(self, type: 'Type[PT]') -> PT:
+        return type.default(self)
+
+
+class Type(G[PT]):
+    __slots__ = ()
+
+    def parse(self, context: Context, stream: Stream) -> PT:
+        raise NotImplementedError
+
+    def dump(self, context: Context, stream: Stream, value: PT) -> None:
+        raise NotImplementedError
+
+    def sizeof(self, context: Context, value: O[PT]) -> U[Mapping[str, int], O[int]]:
+        return None
+
+    def offsetof(self, context: Context, path: Sequence[PathElement], value: O[PT]) -> O[int]:
+        if path:
+            return None
+        else:
+            return 0
+
+    def default(self, context: Context) -> PT:
+        raise NotImplementedError
+
+
+class Error(Exception):
+    __slots__ = ('context',)
+
+    def __init__(self, context: Context, exception: Exception) -> None:
+        path = context.format_path()
+        if path:
+            path = '[' + path + '] '
+        if not isinstance(exception, Exception):
+            exception = ValueError(exception)
+
+        super().__init__('{}{}: {}'.format(
+            path, exception.__class__.__name__, str(exception),
+        ))
+        self.exception = exception
+        self.context = context.copy()
+
+class EOF(Error):
+    pass
diff --git a/sx/core/expr.py b/sx/core/expr.py
new file mode 100644
index 0000000..42ee3a1
--- /dev/null
+++ b/sx/core/expr.py
@@ -0,0 +1,294 @@
+import os
+import math
+import operator
+import functools
+from typing import Any, Optional as O, Sequence, Mapping, Callable, Generic as G, TypeVar, Tuple
+
+from .base import Type, Context, PathElement
+from .io import Stream, Segment, Pos
+from .meta import Wrapper
+from . import to_type
+
+
+class VarSource(Wrapper):
+    def __init__(self, child: Type, count: int) -> None:
+        super().__init__(child)
+        self.stack: list[Tuple[Context, Segment, Stream, Pos, Any]] = []
+        self.pstack: list[Any] = []
+        self.count = count
+
+    def parse(self, context: Context, stream: Stream) -> Any:
+        pos = stream.tell()
+        value = super().parse(context, stream)
+        for _ in range(self.count):
+            self.stack.append((context, context.segment, stream, pos, value))
+        return value
+
+    def dump(self, context: Context, stream: Stream, value: Any) -> None:
+        pos = stream.tell()
+        for _ in range(self.count):
+            self.stack.append((context, context.segment, stream, pos, value))
+        super().dump(context, stream, value)
+
+    def sizeof(self, context: Context, value: O[Any]) -> None:
+        for _ in range(self.count):
+            self.pstack.append(value)
+        return super().sizeof(context, value)
+
+    def offsetof(self, context: Context, path: Sequence[PathElement], value: O[Any]) -> None:
+        for _ in range(self.count):
+            self.pstack.append(value)
+        return super().offsetof(context, path, value)
+
+    def default(self, context: Context) -> Any:
+        value = super().default(context)
+        for _ in range(self.count):
+            self.pstack.append(value)
+        return value
+
+
+symbols = {
+    operator.lt: '<',
+    operator.le: '<=',
+    operator.eq: '==',
+    operator.ne: '!=',
+    operator.ge: '>=',
+    operator.gt: '>',
+
+    operator.not_: 'not ',
+    operator.truth: 'bool ',
+    operator.abs: 'abs ',
+    operator.index: 'int ',
+    operator.inv: '~',
+    operator.neg: '-',
+    operator.pos: '+',
+
+    operator.add: '+',
+    operator.and_: '&',
+    operator.floordiv: '//',
+    operator.lshift: '<<',
+    operator.mod: '%',
+    operator.mul: '*',
+    operator.matmul: '@',
+    operator.or_: '|',
+    operator.pow: '**',
+    operator.rshift: '>>',
+    operator.sub: '-',
+    operator.truediv: '/',
+    operator.xor: '^',
+}
+reverse = {
+    operator.not_: operator.not_,
+
+    operator.add: operator.sub,
+    operator.sub: operator.add,
+    operator.truediv: operator.mul,
+    operator.mul: operator.truediv,
+    operator.pos: operator.pos,
+    operator.neg: operator.neg,
+    operator.pow: lambda x, y: math.log(x) / math.log(y),
+
+    operator.inv: operator.inv,
+    operator.lshift: operator.rshift,
+    operator.rshift: operator.lshift,
+}
+
+T = TypeVar('T')
+
+
+class Expr(G[T]):
+    def _sx_get_(self) -> T:
+        raise NotImplementedError
+
+    def _sx_peek_(self) -> T:
+        raise NotImplementedError
+
+    def _sx_put_(self, value: T) -> None:
+        raise NotImplementedError
+
+    def __getattr__(self, name: str) -> 'AttrExpr':
+        return AttrExpr(self, name)
+
+    def __getitem__(self, item: Any) -> 'ItemExpr':
+        return ItemExpr(self, item)
+
+    def __call__(self, *args: Any, **kwargs: Any) -> 'CallExpr':
+        return CallExpr(self, args, kwargs)
+
+    for x in ('lt', 'le', 'eq', 'ne', 'ge', 'gt'):
+        locals()['__' + x + '__'] = functools.partialmethod(lambda self, x, other: BinExpr(getattr(operator, x), self, other), x)
+    for x in ('not_', 'truth', 'abs', 'index', 'inv', 'neg', 'pos'):
+        locals()['__' + x + '__'] = functools.partialmethod(lambda self, x: UnaryExpr(getattr(operator, x), self), x)
+    for x in (
+        'add', 'and_', 'floordiv', 'lshift', 'mod', 'mul', 'matmul', 'or_', 'pow', 'rshift', 'sub', 'truediv', 'xor',
+        'concat', 'contains', 'delitem', 'getitem', 'delitem', 'getitem', 'setitem',
+    ):
+        locals()['__' + x + '__'] = functools.partialmethod(lambda self, x, other: BinExpr(getattr(operator, x), self, other), x)
+    del x
+
+class AttrExpr(G[T], Expr[T]):
+    def __init__(self, parent: Expr, attr: str) -> None:
+        self.__parent = parent
+        self.__attr = attr
+
+    def _sx_get_(self) -> T:
+        return getattr(get(self.__parent), get(self.__attr))
+
+    def _sx_peek_(self) -> T:
+        return getattr(peek(self.__parent), peek(self.__attr))
+
+    def _sx_put_(self, value: T) -> None:
+        parent = peek(self.__parent)
+        setattr(parent, peek(self.__attr), value)
+        put(self.__parent, parent)
+
+    def __str__(self) -> str:
+        return f'{self.__parent}.{self.__attr}'
+
+    def __repr__(self) -> str:
+        return f'{self.__parent!r}.{self.__attr}'
+
+class ItemExpr(G[T], Expr[T]):
+    def __init__(self, parent: Expr, item: Any) -> None:
+        self.__parent = parent
+        self.__item = item
+
+    def _sx_get_(self) -> T:
+        return get(self.__parent)[get(self.__item)]
+
+    def _sx_peek_(self) -> T:
+        return peek(self.__parent)[peek(self.__item)]
+
+    def _sx_put_(self, value: T) -> None:
+        parent = peek(self.__parent)
+        parent[peek(self.__item)] = value
+        put(self.__parent, parent)
+
+    def __str__(self) -> str:
+        return f'{self.__parent}[{self.__item}]'
+
+    def __repr__(self) -> str:
+        return f'{self.__parent!r}[{self.__item!r}]'
+
+class VarExpr(G[T], Expr[T]):
+    def __init__(self, name: str) -> None:
+        self.__name = name
+        self.__source: O[VarSource] = None
+
+    def _sx_resolve_(self, value: VarSource) -> None:
+        self.__source = value
+
+    def _sx_get_(self) -> T:
+        _, _, _, _, value = self.__source.stack.pop()
+        return value
+
+    def _sx_peek_(self) -> T:
+        _, _, _, _, value = self.__source.stack[0]
+        return value
+
+    def _sx_put_(self, value: T) -> None:
+        context, segment, stream, pos, _ = self.__source.stack.pop()
+        with context.enter_segment(segment, stream, pos, os.SEEK_SET) as f:
+            context.dump(to_type(self.__source.child), f, value)
+
+    def __str__(self) -> str:
+        return f'${self.__name}:{self.__source}'
+
+    def __repr__(self) -> str:
+        return f'${self.__name}(=> {self.__source!r})'
+
+class CallExpr(G[T], Expr[T]):
+    def __init__(self, parent: Expr, args: Sequence[Any], kwargs: Mapping[str, Any]) -> None:
+        self.__parent = parent
+        self.__args = args
+        self.__kwargs = kwargs
+
+    def _sx_get_(self) -> T:
+        return get(self.__parent)(*(get(a) for a in self.__args), **{k: get(v) for k, v in self.__kwargs.items()})
+
+    def _sx_peek_(self) -> T:
+        return peek(self.__parent)(*(peek(a) for a in self.__args), **{k: peek(v) for k, v in self.__kwargs.items()})
+
+    def _sx_put_(self, value: T) -> None:
+        raise NotImplementedError(f'{self.__class__.__name__} is not invertible')
+
+    def __str__(self) -> str:
+        args = [repr(a) for a in self.__args]
+        args += [f'{k}: {v}' for k, v in self.__kwargs.items()]
+        a = ', '.join(args)
+        return f'{self.__parent}({a})'
+
+    def __repr__(self) -> str:
+        args = [repr(a) for a in self.__args]
+        args += [f'{k} = {v!r}' for k, v in self.__kwargs.items()]
+        a = ', '.join(args)
+        return f'{self.__parent!r}({a})'
+
+
+class UnaryExpr(G[T], Expr[T]):
+    def __init__(self, op: Callable[[Expr], T], value: Expr) -> None:
+        self.__op = op
+        self.__value = value
+
+    def _sx_get_(self) -> T:
+        return self.__op(get(self.__value))
+
+    def _sx_peek_(self) -> T:
+        return self.__op(peek(self.__value))
+
+    def _sx_put_(self, value: T) -> None:
+        if self.__op not in reverse:
+            raise NotImplementedError(f'{self.__class__.__name__} {symbols[self.__op]!r} is not invertible')
+        put(self.__value, reverse[self.__op](value))
+
+    def __str__(self) -> str:
+        return f'({symbols[self.__op]}{self.__value})'
+
+    def __repr__(self) -> str:
+        return f'({symbols[self.__op]}{self.__value!r})'
+
+class BinExpr(G[T], Expr[T]):
+    def __init__(self, op: Callable[[Expr, Expr], T], left: Expr, right: Expr) -> None:
+        self.__op = op
+        self.__left = left
+        self.__right = right
+
+    def _sx_get_(self) -> T:
+        return self.__op(get(self.__left), get(self.__right))
+
+    def _sx_peek_(self) -> T:
+        return self.__op(peek(self.__left), peek(self.__right))
+
+    def _sx_put_(self, value: T) -> None:
+        if not isinstance(self.__left, Expr):
+            operand = self.__left
+            target = self.__right
+        elif not isinstance(self.__right, Expr):
+            operand = self.__right
+            target = self.__left
+        else:
+            raise NotImplementedError(f'{self.__class__.__name__} has two expression operands and is not invertible')
+        if self.__op not in reverse:
+            raise NotImplementedError(f'{self.__class__.__name__} {symbols[self.__op]!r} is not invertible')
+        put(target, reverse[self.__op](value, operand))
+
+    def __str__(self) -> str:
+        return f'({self.__left} {symbols[self.__op]} {self.__right})'
+
+    def __repr__(self) -> str:
+        return f'({self.__left!r} {symbols[self.__op]} {self.__right!r})'
+
+
+def get(expr: Any) -> Any:
+    if isinstance(expr, Expr):
+        return expr._sx_get_()
+    return expr
+
+def peek(expr: Any) -> Any:
+    if isinstance(expr, Expr):
+        return expr._sx_peek_()
+    return expr
+
+def put(expr: Any, value: Any) -> None:
+    if isinstance(expr, Expr):
+        expr._sx_put_(value)
diff --git a/sx/core/io.py b/sx/core/io.py
new file mode 100644
index 0000000..3947dc6
--- /dev/null
+++ b/sx/core/io.py
@@ -0,0 +1,166 @@
+import os
+import enum
+import math
+from contextlib import contextmanager
+from typing import BinaryIO, Sequence, Optional as O, Union, Mapping, Callable, Dict
+
+from .util import bits
+
+
+class BitAlignment(enum.Enum):
+    No = enum.auto()
+    Fill = enum.auto()
+    Yes = enum.auto()
+
+class Endian(enum.Enum):
+    Little = enum.auto()
+    Big = enum.auto()
+
+    def to_python(self):
+        return {self.Little: 'little', self.Big: 'big'}[self]
+
+
+Pos = Union[int, float]
+
+
+class Stream:
+    def __init__(self, handle: BinaryIO, bit_align: BitAlignment = BitAlignment.No, bit_endian: Endian = Endian.Big) -> None:
+        self.handle = handle
+        self.bit_pos: O[int] = None
+        self.bit_val: O[int] = None
+        self.bit_align = bit_align
+        self.bit_endian = bit_endian
+
+    @property
+    def root(self):
+        h = self.handle
+        while isinstance(h, self.__class__):
+            h = h.handle
+        return h
+
+    @contextmanager
+    def wrapped(self, handle):
+        #self.flush()
+        old = self.handle
+        self.handle = handle
+        yield self
+        self.handle = old
+
+
+    def read_bits(self, n: int):
+        if n <= 0:
+            return (0, 0)
+
+        if self.bit_pos is None or self.bit_val is None:
+            self.bit_val = self.read(1)[0]
+            self.bit_pos = 0
+
+        nb = min(8 - self.bit_pos, n)
+        if self.bit_endian == Endian.Big:
+            val = bits(self.bit_val, self.bit_pos, nb)
+        else:
+            val = bits(self.bit_val, 8 - self.bit_pos - nb, nb)
+
+        self.bit_pos += nb
+        if self.bit_pos == 8:
+            self.bit_pos = self.bit_val = None
+
+        return val, n - nb
+
+    def read(self, n: int = -1, bits=False) -> bytes:
+        if bits:
+            val, nl = self.read_bits(n)
+            if self.bit_endian == Endian.Big:
+                val <<= nl
+            if nl >= 8:
+                rounds = nl // 8
+                v = int.from_bytes(self.read(rounds), byteorder=self.bit_endian.to_python())
+                if self.bit_endian == Endian.Big:
+                    nl -= rounds * 8
+                    v <<= nl
+                else:
+                    v <<= (n - nl)
+                    nl -= rounds * 8
+                val |= v
+            if nl > 0:
+                v, _ = self.read_bits(nl)
+                if self.bit_endian != Endian.Big:
+                    v <<= (n - nl)
+                val |= v
+            return val
+        else:
+            bs = self.handle.read(n)
+            if n > 0 and len(bs) != n:
+                raise EOFError
+            return bs
+
+    def write(self, value: Union[bytes, int], *, bits: O[int] = None) -> None:
+        if bits is not None:
+            return
+        if self.bit_pos:
+            return
+        if isinstance(value, int):
+            raise TypeError
+        self.handle.write(value)
+
+    def tell(self) -> Pos:
+        pos: Pos = self.handle.tell()
+        if self.bit_pos:
+            pos -= 1
+            pos += self.bit_pos / 8
+        return pos
+
+    def seek(self, n: Pos, whence: int = os.SEEK_SET) -> None:
+        if isinstance(n, float):
+            bp = int((n % 1) * 8)
+            n = int(n)
+        else:
+            bp = 0
+        self.handle.seek(n, whence)
+        if bp:
+            self.read_bits(bp)
+
+class Segment:
+    __slots__ = ('name', 'offset', 'dependents', 'pos')
+
+    def __init__(self, name: str, dependents: Sequence['Segment'] = None) -> None:
+        self.name = name
+        self.offset: O[Pos] = None
+        self.dependents = dependents or []
+        self.pos: O[Pos] = None
+
+    def reset(self):
+        self.offset = self.pos = None
+
+    def __repr__(self) -> str:
+        return f'<{__name__}.{self.__class__.__name__}: {self.name}>'
+
+
+def process_sizes(s: Sequence[Mapping[Segment, Pos]], cb: Callable[[Pos, Pos], Pos]) -> Dict[Segment, O[Pos]]:
+    sizes: Dict[Segment, O[Pos]] = {}
+    for prev in s:
+        for k, n in prev.items():
+            p = sizes.get(k, 0)
+            if p is None or n is None:
+                sizes[k] = None
+            else:
+                sizes[k] = cb(p, n)
+    return sizes
+
+def min_sizes(*s: Mapping[Segment, Pos]) -> Dict[Segment, O[Pos]]:
+    return process_sizes(s, min)
+
+def max_sizes(*s: Mapping[Segment, Pos]) -> Dict[Segment, O[Pos]]:
+    return process_sizes(s, max)
+
+def add_sizes(*s: Mapping[Segment, Pos]) -> Dict[Segment, O[Pos]]:
+    return process_sizes(s, lambda a, b: a + b)
+
+def ceil_sizes(s: Mapping[Segment, O[Pos]]) -> Dict[Segment, O[int]]:
+    d: Dict[Segment, O[int]] = {}
+    for k, v in s.items():
+        if v is not None:
+            d[k] = math.ceil(v)
+        else:
+            d[k] = v
+    return d
diff --git a/sx/core/meta.py b/sx/core/meta.py
new file mode 100644
index 0000000..81478dc
--- /dev/null
+++ b/sx/core/meta.py
@@ -0,0 +1,87 @@
+from typing import Optional as O, Generic as G, Sequence, TypeVar, Any
+from .base import Type, Context, PathElement, Error
+from .io import Stream
+from . import to_type
+
+
+T = TypeVar('T')
+
+class Wrapper(G[T], Type[T]):
+    def __init__(self, child: Type[T]) -> None:
+        self.child = child
+
+    def parse(self, context: Context, stream: Stream) -> T:
+        return context.parse(to_type(self.child), stream)
+
+    def dump(self, context: Context, stream: Stream, value: O[T]) -> None:
+        context.dump(to_type(self.child), stream, value)
+
+    def sizeof(self, context: Context, value: O[T]) -> O[int]:
+        return context.sizeof(to_type(self.child), value)
+
+    def offsetof(self, context: Context, path: Sequence[PathElement], value: O[T]) -> O[int]:
+        return context.offsetof(to_type(self.child), path, value)
+
+    def default(self, context: Context) -> T:
+        return context.default(to_type(self.child))
+
+    def __str__(self) -> str:
+        return str(self.child)
+
+    def __repr__(self) -> str:
+        return repr(self.child)
+
+class Generic(Type):
+    __slots__ = ('name', 'stack')
+
+    def __init__(self, name: str) -> None:
+        self.name = name
+        self.stack = []
+
+    def push(self, value: Any) -> None:
+        if isinstance(value, Generic):
+            self.stack.append(value.stack[-1])
+        else:
+            self.stack.append(value)
+
+    def pop(self) -> None:
+        self.stack.pop()
+
+    def _get_sx_type_(self, ident: Any) -> Type:
+        return to_type(self.stack[-1])
+
+    def parse(self, context: Context, stream: Stream) -> Any:
+        if not self.stack:
+            raise Error(context, 'unresolved generic')
+        return context.parse(to_type(self.stack[-1]), stream)
+
+    def dump(self, context: Context, stream: Stream, value: O[Any]) -> None:
+        if not self.stack:
+            raise Error(context, 'unresolved generic')
+        context.dump(to_type(self.stack[-1]), stream, value)
+
+    def sizeof(self, context: Context, value: O[Any]) -> O[int]:
+        if not self.stack:
+            return None
+        return context.sizeof(to_type(self.stack[-1]), value)
+
+    def offsetof(self, context: Context, path: Sequence[PathElement], value: O[Any]) -> O[int]:
+        if not self.stack:
+            return None
+        return context.offsetof(to_type(self.stack[-1]), path, value)
+
+    def default(self, context: Context) -> Any:
+        if not self.stack:
+            raise Error(context, 'unresolved generic')
+        return context.default(to_type(self.stack[-1]))
+
+    def __str__(self) -> str:
+        if self.stack:
+            return f'${self.name}:{to_type(self.stack[-1])}'
+        return f'${self.name}:unresolved'
+
+    def __repr__(self) -> str:
+        return f'<{__name__}.{self.__class__.__name__}({self.stack!r})>'
+
+    def __deepcopy__(self, memo: Any) -> Any:
+        return self
diff --git a/sx/core/util.py b/sx/core/util.py
new file mode 100644
index 0000000..4720f98
--- /dev/null
+++ b/sx/core/util.py
@@ -0,0 +1,67 @@
+import os
+import math
+import collections.abc
+from typing import BinaryIO, Generator, Callable, Union as U, Any, cast
+from contextlib import contextmanager
+
+
+def bits(v: int, s: int, l: int) -> int:
+    return (v >> s) & ((1 << l) - 1)
+
+def bit(v: int, s: int) -> int:
+    return bits(v, s, 1)
+
+Pos = U[int, float]
+
+
+@contextmanager
+def seeking(fd: BinaryIO, pos: Pos, whence: int = os.SEEK_SET) -> Generator[BinaryIO, None, None]:
+    oldpos = fd.tell()
+    fd.seek(cast(int, pos), whence)
+    try:
+        yield fd
+    finally:
+        fd.seek(oldpos, os.SEEK_SET)
+
+
+def indent(s: str, count: int, start: bool = False) -> str:
""" + lines = s.splitlines() + for i in range(0 if start else 1, len(lines)): + lines[i] = ' ' * count + lines[i] + return '\n'.join(lines) + +def format_bytes(bs: bytes) -> str: + return '[' + ' '.join(hex(b)[2:].zfill(2) for b in bs) + ']' + +def format_value(value: Any, formatter: Callable[[Any], str], indentation: int = 0) -> str: + """ Format containers to use the given formatter function instead of always repr(). """ + if isinstance(value, (dict, collections.Mapping)): + if value: + fmt = '{{\n{}\n}}' + values = [indent(',\n'.join('{}: {}'.format( + format_value(k, formatter), + format_value(v, formatter) + ) for k, v in value.items()), 2, True)] + else: + fmt = '{{}}' + values = [] + elif isinstance(value, (list, set, frozenset)): + l = len(value) + is_set = isinstance(value, (set, frozenset)) + if l > 3: + fmt = '{{\n{}\n}}' if is_set else '[\n{}\n]' + values = [indent(',\n'.join(format_value(v, formatter) for v in value), 2, True)] + elif l > 0: + fmt = '{{{}}}' if is_set else '[{}]' + values = [', '.join(format_value(v, formatter) for v in value)] + else: + fmt = '{{}}' if is_set else '[]' + values = [] + elif isinstance(value, (bytes, bytearray)): + fmt = '{}' + values = [format_bytes(value)] + else: + fmt = '{}' + values = [formatter(value)] + return indent(fmt.format(*values), indentation) diff --git a/sx/types/__init__.py b/sx/types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sx/types/data.py b/sx/types/data.py new file mode 100644 index 0000000..265936d --- /dev/null +++ b/sx/types/data.py @@ -0,0 +1,29 @@ +from typing import Optional as O, Union as U +from ..core.base import Type, Context, PossibleDynamic as D +from ..core.io import Stream + +class Data(Type[bytes]): + __slots__ = ('size',) + + def __init__(self, size: U[D, O[int]] = None) -> None: + self.size = size + + def parse(self, context: Context, stream: Stream) -> bytes: + size = context.get(self.size) + if size is None: + size = -1 + return stream.read(size) + + def dump(self, context: Context, stream: Stream, value: bytes) -> None: + stream.write(value) + + def default(self, context: Context) -> bytes: + size = context.peek(self.size) + if size is None: + size = 0 + return bytes(size) + + def sizeof(self, context: Context, value: O[bytes]) -> O[int]: + return context.peek(self.size) + +data = Data() diff --git a/sx/types/int.py b/sx/types/int.py new file mode 100644 index 0000000..ad1b930 --- /dev/null +++ b/sx/types/int.py @@ -0,0 +1,73 @@ +from typing import Optional as O, Union as U +from ..core.base import Type, Context, PossibleDynamic +from ..core.io import Stream, Endian +from .transforms import Mapped + + +class Int(Type[int]): + __slots__ = ('bits', 'endian', 'signed') + + def __init__(self, bits: PossibleDynamic[int], endian: PossibleDynamic[Endian], signed: PossibleDynamic[bool]) -> None: + self.bits = bits + self.endian = endian + self.signed = signed + + def parse(self, context: Context, stream: Stream) -> int: + n = context.get(self.bits) + bs = stream.read(n // 8) + return int.from_bytes(bs, byteorder=context.get(self.endian).to_python(), signed=context.get(self.signed)) + + def dump(self, context: Context, stream: Stream, value: U[int, float]) -> None: + if isinstance(value, float): + if value.is_integer(): + value = int(value) + else: + raise ValueError(f'can not encode float {value!r} as integer') + n = context.get(self.bits) + bs = value.to_bytes(n // 8, byteorder=context.get(self.endian).to_python(), signed=context.get(self.signed)) + return stream.write(bs) + + def 
+    def default(self, context: Context) -> int:
+        return 0
+
+    def sizeof(self, context: Context, value: O[int]) -> O[int]:
+        size = context.peek(self.bits)
+        if size is not None:
+            size //= 8
+        return size
+
+    def __str__(self) -> str:
+        endian = {Endian.Big: 'be', Endian.Little: 'le'}.get(self.endian, self.endian) if self.bits != 8 else ''
+        sign = {True: '', False: 'u'}.get(self.signed, self.signed)
+        return f'{sign}int{self.bits}{endian}'
+
+    def __repr__(self) -> str:
+        return f'<{__name__}.{self.__class__.__name__}({self.bits!r}, {self.endian!r}, signed: {self.signed!r})>'
+
+int8 = Int(8, endian=Endian.Little, signed=True)
+uint8 = Int(8, endian=Endian.Little, signed=False)
+
+int16le = Int(16, endian=Endian.Little, signed=True)
+int16be = Int(16, endian=Endian.Big, signed=True)
+uint16le = Int(16, endian=Endian.Little, signed=False)
+uint16be = Int(16, endian=Endian.Big, signed=False)
+
+int32le = Int(32, endian=Endian.Little, signed=True)
+int32be = Int(32, endian=Endian.Big, signed=True)
+uint32le = Int(32, endian=Endian.Little, signed=False)
+uint32be = Int(32, endian=Endian.Big, signed=False)
+
+int64le = Int(64, endian=Endian.Little, signed=True)
+int64be = Int(64, endian=Endian.Big, signed=True)
+uint64le = Int(64, endian=Endian.Little, signed=False)
+uint64be = Int(64, endian=Endian.Big, signed=False)
+
+
+class Bool(Type[bool]):
+    def __new__(self, child: Type, true_value: int = 1, false_value: int = 0) -> Mapped:
+        return Mapped(child, {true_value: True, false_value: False},
+            str='bool',
+            repr=f'<{__name__}.Bool({child!r}, true: {true_value!r}, false: {false_value!r})>',
+        )
+
+bool = Bool(uint8)
diff --git a/sx/types/seq.py b/sx/types/seq.py
new file mode 100644
index 0000000..ecda4f1
--- /dev/null
+++ b/sx/types/seq.py
@@ -0,0 +1,200 @@
+from typing import Optional as O, Union as U, Callable, Any, List, Sequence, Mapping, Generic as G, TypeVar, Tuple as Tu
+from types import FunctionType
+
+from ..core.base import PossibleDynamic as D, Type, Context, PathElement
+from ..core.io import Stream, add_sizes
+from ..core import to_type
+
+
+T = TypeVar('T', bound=Type)
+
+class Arr(G[T], Type[List[T]]):
+    def __init__(self, child: D[T], count: O[D[int]] = None, stop: O[U[D[Any], Callable[[Any], bool]]] = None, include_stop: D[bool] = False) -> None:
+        self.child = child
+        self.count = count
+        self.stop = stop
+        self.include_stop = include_stop
+
+    def parse(self, context: Context, stream: Stream) -> List[T]:
+        child = context.get(self.child)
+        count = context.get(self.count)
+        stop = context.get(self.stop)
+        include_stop = context.get(self.include_stop)
+
+        value = []
+        while True:
+            i = len(value)
+            if count is not None and i >= count:
+                break
+            c = to_type(child)
+            with context.enter(i, c):
+                try:
+                    elem = context.parse(c, stream)
+                except EOFError:
+                    if count is None:
+                        break
+                    raise
+                if stop is not None:
+                    if isinstance(stop, FunctionType):
+                        should_stop = stop(elem)
+                    else:
+                        should_stop = elem == stop
+                    if should_stop:
+                        if include_stop:
+                            value.append(elem)
+                        break
+                value.append(elem)
+
+        return value
+
+    def dump(self, context: Context, stream: Stream, value: List[T]) -> None:
+        child = context.get(self.child)
+        count = context.get(self.count)
+        stop = context.get(self.stop)
+        include_stop = context.get(self.include_stop)
+
+        if stop is not None and not isinstance(stop, FunctionType) and not include_stop:
+            value = value + [stop]
+
+        for i, elem in enumerate(value):
+            c = to_type(child)
+            with context.enter(i, c):
+                context.dump(c, stream, elem)
+
+        context.put(self.count, len(value))
+
+    def get_sizes(self, context: Context, value: O[List[T]], n: int) -> Mapping[str, int]:
+        child = context.peek(self.child)
+        stop = context.peek(self.stop)
+
+        sizes = []
+        for i in range(n):
+            c = to_type(child)
+            if value is not None:
+                elem = value[i]
+            else:
+                elem = None
+            with context.enter(i, c):
+                size = context.sizeof(c, elem)
+            sizes.append(size)
+
+        if stop is not None and not isinstance(stop, FunctionType):
+            sizes.append(context.sizeof(child, stop))
+
+        return sizes
+
+    def sizeof(self, context: Context, value: O[List[T]]) -> O[Mapping[str, int]]:
+        if value is not None:
+            count = len(value)
+        else:
+            count = context.peek(self.count)
+        if count is None:
+            return None
+        sizes = self.get_sizes(context, value, count)
+        return add_sizes(*sizes) if sizes else 0
+
+    def offsetof(self, context: Context, path: Sequence[PathElement], value: O[List[T]]) -> O[int]:
+        if not path:
+            return 0
+
+        i = path[0]
+        path = path[1:]
+        if not isinstance(i, int):
+            raise ValueError('path element for array must be integer')
+
+        child = context.peek(self.child)
+        sizes = self.get_sizes(context, value, i)
+        if path:
+            c = to_type(child)
+            with context.enter(i, c):
+                sizes.append(context.offsetof(c, path, value[i] if value is not None else None))
+        return add_sizes(*sizes) if sizes else 0
+
+    def default(self, context: Context) -> List[T]:
+        child = context.peek(self.child)
+        count = context.peek(self.count)
+        if count is None or child is None:
+            value = []
+        else:
+            value = [context.default(child) for _ in range(count)]
+        return value
+
+    def __str__(self) -> str:
+        if self.count is not None:
+            count = repr(self.count)
+        else:
+            count = ''
+        return f'{self.child}[{count}]'
+
+    def __repr__(self) -> str:
+        return f'<{__name__}.{self.__class__.__name__}({self.child!r}, count: {self.count!r}, stop: {self.stop!r}, include_stop: {self.include_stop!r})>'
+
+
+class Tuple(Type):
+    def __init__(self, *children: Type) -> None:
+        self.children = children
+
+    def parse(self, context: Context, stream: Stream) -> Tu:
+        values = []
+
+        for i, child in enumerate(self.children):
+            c = to_type(child)
+            with context.enter(i, c):
+                elem = context.parse(c, stream)
+            values.append(elem)
+
+        return tuple(values)
+
+    def dump(self, context: Context, stream: Stream, value: Tu) -> None:
+        for i, (child, elem) in enumerate(zip(self.children, value)):
+            c = to_type(child)
+            with context.enter(i, c):
+                context.dump(c, stream, elem)
+
+    def sizeof(self, context: Context, value: O[Tu]) -> O[int]:
+        sizes = []
+
+        if value is None:
+            value = [None] * len(self.children)
+        for i, (child, elem) in enumerate(zip(self.children, value)):
+            c = to_type(child)
+            with context.enter(i, c):
+                sizes.append(context.sizeof(c, elem))
+
+        return add_sizes(*sizes)
+
+    def offsetof(self, context: Context, path: Sequence[PathElement], value: O[Tu]) -> O[int]:
+        if not path:
+            return 0
+
+        n = path[0]
+        path = path[1:]
+        if not isinstance(n, int):
+            raise ValueError('path element for tuple must be integer')
+
+        sizes = []
+
+        if value is None:
+            value = [None] * len(self.children)
+        for i, (child, elem) in enumerate(zip(self.children, value)):
+            if i >= n:
+                break
+            c = to_type(child)
+            with context.enter(i, c):
+                sizes.append(context.sizeof(c, elem))
+
+        if path:
+            c = to_type(child)
+            with context.enter(n, c):
+                sizes.append(context.offsetof(c, path, elem))
+
+        return add_sizes(*sizes)
+
+    def default(self, context: Context) -> Tu:
+        return tuple(context.default(c) for c in self.children)
+
+    def __str__(self) -> str:
+        return '(' + ', '.join(str(c) for c in self.children) + ')'
+
+    def __repr__(self) -> str:
+        return f'<{__name__}.{self.__class__.__name__}(' + ', '.join(repr(c) for c in self.children) + ')>'
diff --git a/sx/types/str.py b/sx/types/str.py
new file mode 100644
index 0000000..046d141
--- /dev/null
+++ b/sx/types/str.py
@@ -0,0 +1,54 @@
+import enum
+from typing import Optional as O, Union as U
+from ..core.base import PossibleDynamic as D, Type, Context
+from ..core.io import Stream
+from .int import uint8
+
+
+class StrType(enum.Enum):
+    Raw = enum.auto()
+    ZeroTerminated = C = enum.auto()
+    LengthPrefixed = Pascal = enum.auto()
+
+class Str(Type[str]):
+    def __init__(self, type: U[D, StrType], length: U[D, O[int]] = None, encoding: U[D, str] = 'utf-8', char_size: U[D, int] = 1, length_type: U[D, Type] = uint8, terminator: U[D, bytes] = b'\x00') -> None:
+        self.type = type
+        self.length = length
+        self.encoding = encoding
+        self.char_size = char_size
+        self.length_type = length_type
+        self.terminator = terminator
+
+    def parse(self, context: Context, stream: Stream) -> str:
+        type = context.get(self.type)
+        length = context.get(self.length)
+        encoding = context.get(self.encoding)
+        char_size = context.get(self.char_size)
+
+        if type == StrType.Raw:
+            if length is None:
+                raise ValueError('tried to parse raw string with no specified length')
+            data = stream.read(length * char_size)
+        elif type == StrType.C:
+            terminator = context.get(self.terminator)
+            data = b''
+            while True:
+                d = stream.read(char_size)
+                if d == terminator:
+                    break
+                data += d
+                if length is not None and len(data) >= length:
+                    break
+        elif type == StrType.Pascal:
+            length_type = context.get(self.length_type)
+            plength = context.parse(length_type, stream)
+        return ''
+
+    def dump(self, context: Context, stream: Stream, value: str) -> None:
+        pass
+
+    def default(self, context: Context) -> str:
+        return ''
+
+    def sizeof(self, context: Context, value: O[str]) -> O[int]:
+        return 0
diff --git a/sx/types/struct.py b/sx/types/struct.py
new file mode 100644
index 0000000..b838001
--- /dev/null
+++ b/sx/types/struct.py
@@ -0,0 +1,292 @@
+from __future__ import annotations
+import sys
+import os
+from typing import (
+    Any, Callable, Iterator, Annotated,
+    Union as U, Optional as O, Generic as G, TypeVar, Type as Ty,
+    Tuple, List, Mapping, Sequence,
+)
+from contextlib import contextmanager
+
+import sx
+from ..core import to_type
+from ..core.base import Context, Type, PathElement
+from ..core.io import Stream, Pos, add_sizes
+from ..core.util import indent, format_value
+from ..core.meta import Generic
+from ..core.expr import VarExpr, VarSource
+
+
+T = TypeVar('T')
+
+class StructType(G[T], Type[T]):
+    __slots__ = ('fields', 'cls', 'partial', 'union', 'generics', 'bound')
+
+    def __init__(self, fields, cls: Ty[T], generics: Sequence[Generic] = (), union: bool = False, partial: bool = False, bound: Sequence[Type] = ()) -> None:
+        self.fields = fields
+        self.cls = cls
+        self.union = union
+        self.partial = partial
+        self.generics = generics
+        self.bound = bound
+
+    def __getitem__(self, item: U[Type, Tuple[Type, ...]]) -> StructType[T]:
+        if not isinstance(item, tuple):
+            item = (item,)
+
+        bound = self.bound + item
+        if len(bound) > len(self.generics):
+            raise TypeError('too many generics arguments for {}: {}'.format(
+                self.__class__.__name__, len(bound)
+            ))
+
+        subtype = self.__class__(self.fields, self.cls, self.generics, self.union, self.partial, bound=bound)
+        return subtype
+
+    @contextmanager
+    def enter(self):
+        for g, child in zip(self.generics, self.bound):
+            g.push(child)
+        yield
+        for g, _ in zip(self.generics, self.bound):
+            g.pop()
+
+    def parse(self, context: Context, stream: Stream) -> T:
+        n: Pos = 0
+        pos = stream.tell()
+
+        c = self.cls.__new__(self.cls)
+        did_eof = False
+        with self.enter():
+            for name, type in self.fields.items():
+                if did_eof:
+                    setattr(c, name, None)
+                    continue
+
+                with context.enter(name, type):
+                    if type is None:
+                        continue
+                    if self.union:
+                        stream.seek(pos, os.SEEK_SET)
+
+                    try:
+                        val = context.parse(to_type(type), stream)
+                    except EOFError:
+                        if self.partial:
+                            did_eof = True
+                            setattr(c, name, None)
+                            continue
+                        raise
+
+                    nbytes = stream.tell() - pos
+                    if self.union:
+                        n = max(n, nbytes)
+                    else:
+                        n = nbytes
+
+                    setattr(c, name, val)
+                    hook = 'on_parse_' + name
+                    if hasattr(c, hook):
+                        getattr(c, hook)(self.fields, context)
+
+        stream.seek(pos + n, os.SEEK_SET)
+        return c
+
+    def dump(self, context: Context, stream: Stream, value: T) -> None:
+        n: Pos = 0
+        pos = stream.tell()
+
+        with self.enter():
+            for name, type in self.fields.items():
+                with context.enter(name, type):
+                    if self.union:
+                        stream.seek(pos, os.SEEK_SET)
+
+                    hook = 'on_dump_' + name
+                    if hasattr(value, hook):
+                        getattr(value, hook)(self.fields, context)
+
+                    field = getattr(value, name)
+                    context.dump(to_type(type), stream, field)
+
+                    nbytes = stream.tell() - pos
+                    if self.union:
+                        n = max(n, nbytes)
+                    else:
+                        n = nbytes
+
+        stream.seek(pos + n, os.SEEK_SET)
+
+    def default(self, context: Context) -> T:
+        return self.cls()
+
+    def get_sizes(self, context: Context, value: O[Any], n: str) -> List[Mapping[str, int]]:
+        sizes = []
+        for field, child in self.fields.items():
+            if field == n:
+                break
+            if value is not None:
+                elem = getattr(value, field)
+            else:
+                elem = None
+            with context.enter(field, child):
+                size = context.sizeof(child, elem)
+            sizes.append(size)
+        return sizes
+
+    def sizeof(self, context: Context, value: O[T]) -> O[Mapping[str, int]]:
+        with self.enter():
+            sizes = self.get_sizes(context, value, None)
+        return add_sizes(*sizes) if sizes else 0
+
+    def offsetof(self, context: Context, path: Sequence[PathElement], value: O[T]) -> O[int]:
+        if not path:
+            return 0
+
+        field = path[0]
+        path = path[1:]
+        if not isinstance(field, str):
+            raise ValueError('path element for struct must be string')
+        if field not in self.fields:
+            raise ValueError(f'field {field!r} invalid for {self.cls.__name__}')
+
+        child = self.fields[field]
+        with self.enter():
+            sizes = self.get_sizes(context, value, field)
+            if path:
+                with context.enter(field, child):
+                    sizes.append(context.offsetof(child, path, getattr(value, field) if value is not None else None))
+        return add_sizes(*sizes) if sizes else 0
+
+    def __str__(self) -> str:
+        if self.fields:
+            with self.enter():
+                fields = '{\n'
+                for f, v in self.fields.items():
+                    fields += '  ' + f + ': ' + indent(format_value(to_type(v), str), 2) + ',\n'
+                fields += '}'
+        else:
+            fields = '{}'
+        return f'{self.cls.__name__} {fields}'
+
+    def __repr__(self) -> str:
+        type = 'Union' if self.union else 'Struct'
+        if self.fields:
+            with self.enter():
+                fields = '{\n'
+                for f, v in self.fields.items():
+                    fields += '  ' + f + ': ' + indent(format_value(to_type(v), repr), 2) + ',\n'
+                fields += '}'
+        else:
+            fields = '{}'
+        return f'<{__name__}.{type}({self.cls.__name__}) {fields}>'
+
+class Struct:
+    __slots__ = ()
+    _sx_type_ = None
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__()
+
+        st = self._sx_type_
+        with st.enter():
+            for k, t in st.fields.items():
+                if k not in kwargs:
+                    v = sx.default(t)
+                else:
+                    v = kwargs.pop(k)
+                setattr(self, k, v)
+        for name, value in kwargs.items():
+            setattr(self, name, value)
+
+    def __init_subclass__(cls, *, inject: bool = True, generics: Sequence[Generic] = (), **kwargs: Any):
+        super().__init_subclass__()
+
+        # Get all generics definition
+        parent: O[StructType] = getattr(cls, '_sx_type_', None)
+        bound: Sequence[Type] = ()
+        generics = tuple(generics)
+        if parent:
+            generics = parent.generics + generics
+            bound = parent.bound + bound
+
+        # Get all annotations
+        annots = {}
+        localns = {'Self': cls}
+        for c in reversed(cls.__mro__):
+            globalns = sys.modules[c.__module__].__dict__
+            annots.update({k: (globalns, v) for k, v in getattr(c, '__annotations__', {}).items()})
+            localns[c.__name__] = c
+        if inject:
+            localns.update({x: getattr(sx, x) for x in sx.__all__})
+        localns.update({g.name: g for g in generics})
+
+        # Evaluate annotations into fields
+        fields = {}
+        refs: Mapping[str, VarExpr] = {}
+        for name, (globalns, value) in annots.items():
+            val = eval(value, globalns, localns)
+            if hasattr(val, '__metadata__'):
+                val = next(v for v in val.__metadata__ if isinstance(v, Type))
+            fields[name] = val
+            localns[name] = refs[name] = VarExpr(name)
+
+        del localns
+        for name, r in refs.items():
+            count = sys.getrefcount(r) - 4  # cursed
+            if count:
+                fields[name] = VarSource(fields[name], count)
+                r._sx_resolve_(fields[name])
+
+        cls._sx_type_ = StructType(fields, cls, generics=generics, bound=bound, **kwargs)
+
+    def __class_getitem__(cls, item) -> Type:
+        if not isinstance(item, tuple):
+            item = (item,)
+        subtype = cls._sx_type_[item]
+        new_name = '{}[{}]'.format(cls.__name__, ', '.join(str(g) for g in subtype.bound))
+        new = type(new_name, (cls,), {})
+        new._sx_type_ = subtype
+        new.__slots__ = cls.__slots__
+        new.__module__ = cls.__module__
+        subtype.cls = new
+        return new
+
+    def __iter__(self) -> Iterator[Any]:
+        return iter(self._sx_type_.fields)
+
+    def __hash__(self) -> int:
+        return hash(tuple((k, getattr(self, k)) for k in self))
+
+    def __eq__(self, other) -> bool:
+        if type(self) != type(other):
+            return False
+        if self.__slots__ != other.__slots__:
+            return False
+        for k in self:
+            ov = getattr(self, k)
+            tv = getattr(other, k)
+            if ov != tv:
+                return False
+        return True
+
+    def _format_(self, fieldfunc: Callable[[Any], str]) -> str:
+        args = []
+        for k in self:
+            if k.startswith('_'):
+                continue
+            val = getattr(self, k)
+            val = format_value(val, fieldfunc, 2)
+            args.append('  {}: {}'.format(k, val))
+        args = ',\n'.join(args)
+        # Format final value.
+        if args:
+            return f'{self.__class__.__name__} {{\n{args}\n}}'
+        else:
+            return f'{self.__class__.__name__} {{}}'
+
+    def __str__(self) -> str:
+        return self._format_(str)
+
+    def __repr__(self) -> str:
+        return self._format_(repr)
diff --git a/sx/types/transforms.py b/sx/types/transforms.py
new file mode 100644
index 0000000..9dc37ce
--- /dev/null
+++ b/sx/types/transforms.py
@@ -0,0 +1,196 @@
+import os
+import errno
+from typing import Any, Optional as O, Generic as G, Union as U, TypeVar, Callable, Sequence, Mapping
+from ..core.base import Type, Context, PathElement, PossibleDynamic
+from ..core.io import Stream, Segment, Pos
+from ..core.meta import Wrapper
+from ..core.util import seeking
+from ..core import to_type
+
+
+T = TypeVar('T')
+V = TypeVar('V')
+
+class Default(G[T], Wrapper[T]):
+    def __init__(self, child: Type[T], default: T) -> None:
+        super().__init__(child)
+        self._default = default
+
+    def default(self, context: Context) -> T:
+        return self._default
+
+
+class SizedStream:
+    def __init__(self, stream: Stream, limit: Pos) -> None:
+        self._stream = stream
+        self._pos: Pos = 0
+        self._limit = limit
+        self._start = stream.tell()
+
+    def read(self, n: int = -1, bits=False) -> U[bytes, int]:
+        remaining = max(0, self._limit - self._pos)
+        if bits:
+            remaining *= 8
+        if n < 0:
+            n = remaining
+
+        if n > remaining:
+            raise EOFError
+        if bits:
+            self._pos += n / 8
+        else:
+            self._pos += n
+        return self._stream.read(n, bits=bits)
+
+    def write(self, data: U[bytes, int], *, bits: O[int] = None) -> None:
+        remaining = self._limit - self._pos
+        if bits is not None:
+            n = bits / 8
+        else:
+            n = len(data)
+        if n > remaining:
+            raise EOFError
+        self._pos += n
+        return self._stream.write(data, bits=bits)
+
+    def seek(self, offset: Pos, whence: int) -> None:
+        if whence == os.SEEK_SET:
+            pos = offset
+        elif whence == os.SEEK_CUR:
+            pos = self._start + self._pos + offset
+        elif whence == os.SEEK_END:
+            pos = self._start + self._limit - offset
+        if pos < self._start:
+            raise OSError(errno.EINVAL, os.strerror(errno.EINVAL), offset)
+        self._pos = pos - self._start
+        return self._stream.seek(pos, os.SEEK_SET)
+
+    def tell(self) -> Pos:
+        return self._start + self._pos
+
+    def __getattr__(self, attr: str) -> Any:
+        return getattr(self._stream, attr)
+
+class Sized(G[T], Wrapper[T]):
+    def __init__(self, child: Type[T], limit: U[Pos, PossibleDynamic]):
+        super().__init__(child)
+        self.limit = limit
+
+    def parse(self, context: Context, stream: Stream) -> T:
+        limit = max(0, context.get(self.limit))
+        start = stream.tell()
+        value = super().parse(context, SizedStream(stream, limit))
+        stream.seek(start + limit)
+        return value
+
+    def dump(self, context: Context, stream: Stream, value: O[T]) -> None:
+        limit = max(0, context.get(self.limit))
+        start = stream.tell()
+        super().dump(context, SizedStream(stream, limit), value)
+        stream.seek(start + limit)
+
+    def sizeof(self, context: Context, value: O[T]) -> O[Pos]:
+        return context.peek(self.limit)
+
+
+class Ref(G[T], Wrapper[T]):
+    def __init__(self, child: Type[T], pos: U[PossibleDynamic, Pos], whence: int = os.SEEK_SET, segment: O[Segment] = None) -> None:
+        super().__init__(child)
+        self.pos = pos
+        self.whence = whence
+        self.segment = segment
+
+    def parse(self, context: Context, stream: Stream) -> T:
+        point = context.get(self.pos)
+        whence = context.get(self.whence)
+        segment = context.get(self.segment) or context.params.segments['refs']
+        with context.enter_segment(segment, stream, point, whence) as f:
+            return super().parse(context, f)
+
+    def dump(self, context: Context, stream: Stream, value: T) -> None:
+        whence = context.get(self.whence)
+        segment = context.get(self.segment) or context.params.segments['refs']
+
+        with context.enter_segment(segment, stream) as f:
+            pos = f.tell()
+            super().dump(context, f, value)
+
+        if whence == os.SEEK_CUR:
+            pos -= stream.tell()
+        elif whence == os.SEEK_END:
+            with seeking(stream, 0, os.SEEK_END) as f:
+                pos -= f.tell()
+
+        context.put(self.pos, pos)
+
+    def sizeof(self, context: Context, value: O[T]) -> O[Pos]:
+        segment = context.peek(self.segment) or context.params.segments['refs']
+        with context.enter_segment(segment):
+            return super().sizeof(context, value)
+
+    def offsetof(self, context: Context, path: Sequence[PathElement], value: O[T]) -> O[Pos]:
+        segment = context.peek(self.segment) or context.params.segments['refs']
+        with context.enter_segment(segment):
+            return super().offsetof(context, path, value)
+
+    def __str__(self) -> str:
+        indicator = {os.SEEK_SET: '', os.SEEK_CUR: '+', os.SEEK_END: '-'}.get(self.whence, self.whence)
+        segment = f'{self.segment}:' if self.segment else ''
+        return f'&({super().__str__()} @ {indicator}{segment}{self.pos})'
+
+    def __repr__(self) -> str:
+        indicator = {os.SEEK_SET: '', os.SEEK_CUR: '+', os.SEEK_END: '-'}.get(self.whence, self.whence)
+        segment = f'{self.segment!r}:' if self.segment else ''
+        return f'<{__name__}.{self.__class__.__name__}({super().__repr__()}, pos: {indicator}{segment}{self.pos!r})>'
+
+class Transform(G[T, V], Type[V]):
+    def __init__(self, child: Type[T], parse: U[Callable[[T], V], Callable[[T, Context], V]], dump: U[Callable[[V], T], Callable[[V, Context], T]], context: bool = False, str: O[str] = None, repr: O[str] = None) -> None:
+        self.child = child
+        self.on_parse = parse
+        self.on_dump = dump
+        self.context = context
+        self.on_str = str or repr
+        self.on_repr = repr or str
+
+    def parse(self, context: Context, stream: Stream) -> V:
+        value = context.parse(to_type(self.child), stream)
+        return self.on_parse(value, context) if self.context else self.on_parse(value)
+
+    def dump(self, context: Context, stream: Stream, value: V) -> None:
+        value = self.on_dump(value, context) if self.context else self.on_dump(value)
+        context.dump(to_type(self.child), stream, value)
+
+    def sizeof(self, context: Context, value: O[V]) -> O[int]:
+        if value is not None:
+            value = self.on_dump(value, context) if self.context else self.on_dump(value)
+        return context.sizeof(to_type(self.child), value)
+
+    def offsetof(self, context: Context, path: Sequence[PathElement], value: O[V]) -> O[int]:
+        if value is not None:
+            value = self.on_dump(value, context) if self.context else self.on_dump(value)
+        return context.offsetof(to_type(self.child), path, value)
+
+    def default(self, context: Context) -> V:
+        value = context.default(to_type(self.child))
+        return self.on_parse(value, context) if self.context else self.on_parse(value)
+
+    def __str__(self) -> str:
+        if self.on_str is not None:
+            return self.on_str
+        return f'λ({self.child})'
+
+    def __repr__(self) -> str:
+        if self.on_repr is not None:
+            return self.on_repr
+        return f'<{__name__}.{self.__class__.__name__}({self.child!r}, parse: {self.on_parse!r}, dump: {self.on_dump!r})>'
+
+
+class Mapped(G[T, V], Type[T]):
+    def __new__(self, child: Type[T], mapping: Mapping[T, V], str: O[str] = None, repr: O[str] = None) -> Transform:
+        reverse = {v: k for k, v in mapping.items()}
+        return Transform(child,
+            parse=mapping.__getitem__,
+            dump=reverse.__getitem__,
+            str=str or f'{mapping}[{child}]',
+            repr=repr or f'<{__name__}.Mapped({child!r}, {mapping!r})>'
+        )