359 lines
12 KiB
Python
359 lines
12 KiB
Python
from __future__ import annotations
|
|
import sys
|
|
import os
|
|
from typing import (
|
|
Any, Callable, Iterator, Annotated,
|
|
Union as U, Optional as O, Generic as G, TypeVar, Type as Ty,
|
|
Tuple, List, Mapping, Sequence,
|
|
)
|
|
from contextlib import contextmanager
|
|
|
|
import sx
|
|
from ..core import to_type
|
|
from ..core.base import Context, Type, PathElement
|
|
from ..core.io import Stream, Pos, PosInfo, add_sizes, max_sizes
|
|
from ..core.util import indent, format_value, get_annot_locations
|
|
from ..core.meta import Generic, TypeSource
|
|
from ..core.expr import ProxyExpr
|
|
|
|
|
|
# Type variable for the Python class that a StructType builds in parse() and accepts in dump().
T = TypeVar('T')
|
|
|
|
class ProxyStruct:
    """Placeholder object that records attribute accesses as proxy expressions.

    ``_sx_fields_`` maps a field name to the list of :class:`ProxyExpr`
    objects that were handed out for it.  Only names that are already keys
    of ``_sx_fields_`` can be accessed; everything else raises
    :class:`AttributeError`.
    """

    def __init__(self):
        # field name -> list of ProxyExpr objects created for that field
        self._sx_fields_ = {}

    def __getattr__(self, name: str) -> ProxyExpr:
        """Return a new ``ProxyExpr`` for field *name* and record it.

        Raises:
            AttributeError: if *name* is not a registered field.
        """
        # __getattr__ only runs when normal lookup fails.  If the bookkeeping
        # dict itself is missing (instance created without __init__, e.g. via
        # __new__ or while being copied), looking it up through ``self`` again
        # would recurse forever, so fail cleanly instead.
        if name == '_sx_fields_':
            raise AttributeError(name)

        # Validate the name first so no expression is built for unknown fields.
        try:
            recorded = self._sx_fields_[name]
        except KeyError:
            raise AttributeError(name) from None

        e = ProxyExpr(name, 'self')
        recorded.append(e)
        return e
|
|
|
|
|
|
class StructType(G[T], Type[T]):
    """Type descriptor that parses and dumps instances of ``cls`` field by field.

    Attributes:
        fields: mapping of field name to field type, iterated in order.
        cls: the Python class instantiated by :meth:`parse` / :meth:`default`.
        union: when true, every field is read/written at the same start
            position and the consumed size is that of the largest field.
        partial: when true, an ``EOFError`` while parsing a field is
            tolerated and that field and all following ones are set to ``None``.
        generics: generic parameters declared for the struct.
        bound: types bound so far to the leading entries of ``generics``.
    """

    __slots__ = ('fields', 'cls', 'partial', 'union', 'generics', 'bound')

    def __init__(
        self,
        fields,
        cls: Ty[T],
        generics: Sequence[Generic] = (),
        union: bool = False,
        partial: bool = False,
        bound: Sequence[Type] = (),
    ) -> None:
        self.fields = fields
        self.cls = cls
        self.union = union
        self.partial = partial
        self.generics = generics
        self.bound = bound

    def __getitem__(self, item: U[Type, Tuple[Type, ...]]) -> StructType[T]:
        """Return a copy of this type with *item* bound to the next generics.

        Raises:
            TypeError: if more types are bound than generics are declared.
        """
        if not isinstance(item, tuple):
            item = (item,)

        # Coerce so that a non-tuple ``bound`` sequence can still be extended.
        bound = tuple(self.bound) + item
        if len(bound) > len(self.generics):
            raise TypeError('too many generics arguments for {}: {}'.format(
                self.__class__.__name__, len(bound)
            ))

        return self.__class__(
            self.fields, self.cls, self.generics, self.union, self.partial, bound=bound,
        )

    @contextmanager
    def enter(self):
        """Push every bound type onto its generic for the duration of the block."""
        pushed = []
        try:
            for g, child in zip(self.generics, self.bound):
                g.push(child)
                pushed.append(g)
            yield
        finally:
            # Pop only what was actually pushed, and do so even when the body
            # raised, so a failed parse/dump does not leave generics bound.
            for g in pushed:
                g.pop()

    def parse(self, context: Context, stream: Stream) -> T:
        """Read all fields from *stream* into a new, uninitialised ``cls`` instance.

        After a field is parsed, an ``on_parse_<field>`` method on the
        instance (if present) is called with ``(self.fields, context)``.
        The stream is left positioned just past the consumed bytes.

        Raises:
            EOFError: propagated from a field parser unless ``partial`` is set.
        """
        n: Pos = 0
        pos = stream.tell()

        # Bypass __init__: attributes are filled in directly from the stream.
        c = self.cls.__new__(self.cls)
        did_eof = False
        with self.enter():
            for name, ftype in self.fields.items():
                if did_eof:
                    # A previous field hit EOF in partial mode.
                    setattr(c, name, None)
                    continue

                with context.enter(name, ftype):
                    if ftype is None:
                        continue
                    if self.union:
                        # Union members all start at the same offset.
                        stream.seek(pos, os.SEEK_SET)

                    try:
                        val = context.parse(to_type(ftype), stream)
                    except EOFError:
                        if self.partial:
                            did_eof = True
                            setattr(c, name, None)
                            continue
                        raise

                    nbytes = stream.tell() - pos
                    n = max(n, nbytes) if self.union else nbytes

                    setattr(c, name, val)
                    hook = 'on_parse_' + name
                    if hasattr(c, hook):
                        getattr(c, hook)(self.fields, context)

        stream.seek(pos + n, os.SEEK_SET)
        return c

    def dump(self, context: Context, stream: Stream, value: T) -> None:
        """Write the fields of *value* to *stream*.

        For a union only one field is written: the one named by
        ``value._sx_lastset_`` or, when that is empty, the first declared
        field.  Before a field is written, an ``on_dump_<field>`` method on
        *value* (if present) is called with ``(self.fields, context)``.
        """
        n: Pos = 0
        pos = stream.tell()

        with self.enter():
            if self.union:
                if value._sx_lastset_:
                    names = [value._sx_lastset_]
                else:
                    # First declared field; an empty list if there are none.
                    # (``next()`` on the mapping itself is a TypeError.)
                    names = list(self.fields)[:1]
            else:
                names = list(self.fields)

            for name in names:
                ftype = self.fields[name]
                with context.enter(name, ftype):
                    if self.union:
                        stream.seek(pos, os.SEEK_SET)

                    hook = 'on_dump_' + name
                    if hasattr(value, hook):
                        getattr(value, hook)(self.fields, context)

                    field = getattr(value, name)
                    context.dump(to_type(ftype), stream, field)

                    nbytes = stream.tell() - pos
                    n = max(n, nbytes) if self.union else nbytes

        stream.seek(pos + n, os.SEEK_SET)

    def default(self, context: Context) -> T:
        """Construct ``cls`` with default field values computed through *context*."""
        return self.cls(_sx_context_=context)

    def get_sizes(self, context: Context, start: PosInfo, value: O[Any], n: O[str]) -> Tuple[PosInfo, List[PosInfo]]:
        """Collect the sizes of the fields that precede field *n*.

        Pass ``n=None`` to cover every field.  Returns ``(start, sizes)``;
        for non-union types ``start`` has been advanced by each collected size.
        """
        sizes = []
        for field, child in self.fields.items():
            if field == n:
                break
            elem = getattr(value, field) if value is not None else None
            c = to_type(child, field)
            with context.enter(field, c):
                size = context.sizeof(c, start, elem)
            if not self.union:
                start = add_sizes(start, size)
            sizes.append(size)
        return start, sizes

    def sizeof(self, context: Context, start: PosInfo, value: O[T]) -> PosInfo:
        """Return the combined field size: the maximum for unions, the sum otherwise."""
        with self.enter():
            _, sizes = self.get_sizes(context, start, value, None)
        if not sizes:
            return 0
        if self.union:
            return max_sizes(*sizes)
        return add_sizes(*sizes)

    def offsetof(self, context: Context, start: PosInfo, path: Sequence[PathElement], value: O[T]) -> O[int]:
        """Return the offset of the element addressed by *path*.

        The first path element names a field of this struct; any remaining
        elements are resolved by the field's own type through *context*.

        Raises:
            ValueError: if the first path element is not a string or does
                not name a field.
        """
        if not path:
            return 0

        field = path[0]
        path = path[1:]
        if not isinstance(field, str):
            raise ValueError('path element for struct must be string')
        if field not in self.fields:
            raise ValueError(f'field {field!r} invalid for {self.cls.__name__}')

        child = self.fields[field]
        with self.enter():
            if self.union:
                # Union members all start at offset zero.
                sizes = []
            else:
                start, sizes = self.get_sizes(context, start, value, field)
            if path:
                with context.enter(field, child):
                    sizes.append(context.offsetof(child, start, path, getattr(value, field) if value is not None else None))
        return add_sizes(*sizes) if sizes else 0

    def _format_fields(self, fieldfunc: Callable[[Any], str]) -> str:
        """Render the field mapping as a brace-delimited block using *fieldfunc*."""
        if not self.fields:
            return '{}'
        with self.enter():
            body = ''.join(
                ' ' + f + ': ' + indent(format_value(to_type(v), fieldfunc), 2) + ',\n'
                for f, v in self.fields.items()
            )
        return '{\n' + body + '}'

    def __str__(self) -> str:
        return f'{self.cls.__name__} {self._format_fields(str)}'

    def __repr__(self) -> str:
        kind = 'Union' if self.union else 'Struct'
        return f'<{__name__}.{kind}({self.cls.__name__}) {self._format_fields(repr)}>'
|
|
|
|
class Struct:
    """Base class for declaratively defined structures.

    Subclasses declare fields as class annotations; ``__init_subclass__``
    evaluates them and stores the resulting :class:`StructType` in
    ``_sx_type_``.  Iterating an instance yields its field names.
    """

    __slots__ = ()
    _sx_type_ = None

    def __init__(self, **kwargs) -> None:
        """Set every declared field from *kwargs* or from its type's default.

        The optional keyword ``_sx_context_`` supplies a context used to
        compute defaults; without it ``sx.default`` is used.

        Raises:
            AttributeError: if keywords remain that do not name a field.
        """
        super().__init__()

        ctx = kwargs.pop('_sx_context_', None)

        st = self._sx_type_
        with st.enter():
            for k, t in st.fields.items():
                if k in kwargs:
                    v = kwargs.pop(k)
                elif ctx:
                    with ctx.enter(k, t):
                        v = ctx.default(to_type(t))
                else:
                    v = sx.default(t)
                setattr(self, k, v)
        if kwargs:
            raise AttributeError(', '.join(kwargs))

    def __init_subclass__(cls, *, inject: bool = True, generics: Sequence[Generic] = (), **kwargs: Any) -> None:
        """Build ``cls._sx_type_`` from the annotations found along the MRO.

        Args:
            inject: make every name in ``sx.__all__`` available while the
                annotations are evaluated.
            generics: additional generic parameters declared by this class.
            **kwargs: forwarded to :class:`StructType` (``union`` defaults
                to the parent's setting).
        """
        # Local import: the module-level typing import does not bring this in.
        from typing import get_origin

        super().__init_subclass__()

        # Get all generics definition
        parent: O[StructType] = getattr(cls, '_sx_type_', None)
        bound: Sequence[Type] = ()
        generics = tuple(generics)
        if parent:
            generics = parent.generics + generics
            bound = parent.bound + bound
            kwargs.setdefault('union', parent.union)

        # Get all annotations
        annots = {}
        localns = {'Self': cls}
        for c in reversed(cls.__mro__):
            try:
                fn, lines = get_annot_locations(c)
            except Exception:
                # Source locations are best-effort (used only for the
                # filename/line passed to compile()); fall back to placeholders.
                fn = None
                lines = {}
            globalns = sys.modules[c.__module__].__dict__
            annots.update({
                k: (fn or f'<annotation:{k}>', lines.get(k, 0), globalns, v)
                for k, v in getattr(c, '__annotations__', {}).items()
            })
            localns[c.__name__] = c
        if inject:
            localns.update({x: getattr(sx, x) for x in sx.__all__})
        localns.update({g.name: g for g in generics})

        # Evaluate annotations into fields
        proxy = ProxyStruct()
        localns['self'] = proxy
        fields = {}
        for name, (fn, line, globalns, value) in annots.items():
            if isinstance(value, str):
                # Pad with newlines so errors report the annotation's own line.
                code = compile('\n' * line + value, fn, 'eval')
                val = eval(code, globalns, localns)
            else:
                # Annotation was already evaluated by the defining module.
                val = value
            # ``Annotated[...]`` aliases are not instances of ``Annotated``,
            # so they have to be recognised through their origin.
            if get_origin(val) is Annotated:
                # Prefer the first Type found in the metadata; otherwise keep
                # the annotation untouched.
                val = next((m for m in val.__metadata__ if isinstance(m, Type)), val)
            fields[name] = val
            # Register the field so later annotations may refer to self.<name>.
            proxy._sx_fields_[name] = []

        del localns
        for name, exprs in proxy._sx_fields_.items():
            count = len(exprs)
            if count:
                # Field is referenced by other annotations: wrap it and hand
                # the wrapper to every expression that refers to it.
                fields[name] = TypeSource(fields[name], count)
                for e in exprs:
                    e._sx_push_(fields[name])

        cls._sx_type_ = StructType(fields, cls, generics=generics, bound=bound, **kwargs)

    def __class_getitem__(cls, item) -> Type:
        """Return a subclass of *cls* whose struct type has *item* bound to its generics."""
        if not isinstance(item, tuple):
            item = (item,)
        subtype = cls._sx_type_[item]
        new_name = '{}[{}]'.format(cls.__name__, ', '.join(str(g) for g in subtype.bound))
        new = type(new_name, (cls,), {
            '__module__': cls.__module__,
            '__slots__': cls.__slots__,
        })
        # Replace the type computed by __init_subclass__ with the bound one.
        new._sx_type_ = subtype
        subtype.cls = new
        return new

    def __iter__(self) -> Iterator[Any]:
        """Iterate over the declared field names."""
        return iter(self._sx_type_.fields)

    def __hash__(self) -> int:
        return hash(tuple((k, getattr(self, k)) for k in self))

    def __eq__(self, other) -> bool:
        """Two structs are equal when they share a class and all field values match."""
        if type(self) != type(other):
            return False
        if self.__slots__ != other.__slots__:
            return False
        for k in self:
            mine = getattr(self, k)
            theirs = getattr(other, k)
            if mine != theirs:
                return False
        return True

    def _format_(self, fieldfunc: Callable[[Any], str]) -> str:
        """Render the instance as ``Name {...}``, skipping fields starting with ``_``."""
        args = []
        for k in self:
            if k.startswith('_'):
                continue
            val = getattr(self, k)
            val = format_value(val, fieldfunc, 2)
            args.append(' {}: {}'.format(k, val))
        args = ',\n'.join(args)
        # Format final value.
        if args:
            return f'{self.__class__.__name__} {{\n{args}\n}}'
        else:
            return f'{self.__class__.__name__} {{}}'

    def __str__(self) -> str:
        return self._format_(str)

    def __repr__(self) -> str:
        return self._format_(repr)
|
|
|
|
|
|
class Union(Struct, union=True, inject=False):
    """Struct variant whose type is created with ``union=True``.

    ``_sx_lastset_`` holds the name of the attribute assigned most
    recently on the instance.
    """

    _sx_lastset_ = ''

    def __init__(self, **kwargs) -> None:
        """Initialise the fields, then mark the last explicitly passed one as last set."""
        super().__init__(**kwargs)
        # The base initialiser assigns every field in declaration order, so
        # point the marker back at the last field the caller actually passed.
        # ``_sx_context_`` is an initialiser option, not a field; it is still
        # present here because the base class only popped it from its own
        # copy of the keywords, so it must be ignored.
        explicit = [k for k in kwargs if k != '_sx_context_']
        if explicit:
            super().__setattr__('_sx_lastset_', explicit[-1])

    def __setattr__(self, name: str, value: Any) -> None:
        """Assign the attribute and remember *name* as the last one set."""
        super().__setattr__(name, value)
        super().__setattr__('_sx_lastset_', name)
|