260 lines
7.5 KiB
Python
260 lines
7.5 KiB
Python
import os
|
|
import enum
|
|
import math
|
|
from io import BytesIO
|
|
from contextlib import contextmanager
|
|
from typing import BinaryIO, Sequence, Optional as O, Union, Mapping, Callable, Dict, Tuple
|
|
|
|
from .util import bits
|
|
|
|
|
|
class BitAlignment(enum.Enum):
    """Policy for byte-level Stream reads/writes issued while the stream is mid-byte.

    No   -- refuse: Stream.read/Stream.write raise IOError.
    Fill -- first advance to the next byte boundary, then do the byte operation.
    Yes  -- perform the byte operation as 8-bit transfers at the current bit offset.
    """

    # Explicit values match what enum.auto() assigned (1, 2, 3 in declaration order).
    No = 1
    Fill = 2
    Yes = 3
|
|
|
|
class Endian(enum.Enum):
    """Ordering selector used for bit and multi-byte ordering, with stdlib spellings."""

    Little = enum.auto()
    Big = enum.auto()

    def to_python(self) -> str:
        """Return the ``byteorder`` argument for int.to_bytes/int.from_bytes: 'little' or 'big'."""
        # Member names are exactly the capitalised byteorder strings.
        return self.name.lower()

    def to_struct(self) -> str:
        """Return the struct-module byte-order prefix: '<' for Little, '>' for Big."""
        return '<' if self is Endian.Little else '>'
|
|
|
|
# Stream position in bytes. A float carries a bit offset in its fractional part,
# in eighths of a byte (Stream.tell adds bit_pos / 8; Stream.seek reads it back),
# e.g. 2.5 means byte 2, bit 4.
Pos = Union[int, float]
|
|
|
|
|
|
class Stream:
    """Wrapper around a binary handle that adds bit-granular reads and writes.

    Bit-level state:

    * ``bit_pos`` -- bits already consumed in the current byte, or ``None``
      while the stream is byte-aligned.
    * ``bit_val`` -- value of the current byte while ``bit_pos`` is set.
    * ``bit_dirty`` -- ``True`` when ``bit_val`` was modified by a bit write
      and still has to be written back to ``handle`` (see flush_bits).

    ``bit_align`` decides what byte-oriented read()/write() do while the
    stream is mid-byte: ``No`` raises IOError, ``Fill`` first advances to the
    next byte boundary, ``Yes`` transfers each byte as 8 bits at the current
    offset. ``bit_endian`` selects how bits within a byte and multi-byte bit
    values are ordered.
    """

    __slots__ = ('handle', 'bit_pos', 'bit_val', 'bit_align', 'bit_endian', 'bit_dirty')

    def __init__(self, handle: BinaryIO, bit_align: BitAlignment = BitAlignment.No,
                 bit_endian: Endian = Endian.Big) -> None:
        self.handle = handle
        self.bit_pos: O[int] = None
        self.bit_val: O[int] = None
        self.bit_align = bit_align
        self.bit_endian = bit_endian
        self.bit_dirty = False

    @property
    def root(self):
        """Return the innermost underlying handle, unwrapping nested Streams."""
        h = self.handle
        while isinstance(h, Stream):
            # Step down one level per iteration; re-reading self.handle would never advance.
            h = h.handle
        return h

    @contextmanager
    def wrapped(self, handle):
        """Temporarily use *handle* as the underlying handle.

        The previous handle is restored on exit, including when the body of
        the ``with`` raises. Pending bit state is not flushed before the swap.
        """
        old = self.handle
        self.handle = handle
        try:
            yield self
        finally:
            self.handle = old

    def flush_bits(self) -> None:
        """Write a modified partial byte back to the handle.

        The handle sits just past the byte being built (write_bits reserves
        it, read_bits consumed it), so step back one byte and rewrite it. This
        leaves the handle where it was. ``bit_pos``/``bit_val`` are kept so
        more bits can still be added to the same byte.
        """
        if self.bit_dirty:
            self.handle.seek(-1, os.SEEK_CUR)
            self.handle.write(bytes([self.bit_val]))
            self.bit_dirty = False

    def flush(self) -> None:
        """Flush any pending partial byte, then flush the underlying handle."""
        self.flush_bits()
        self.handle.flush()

    def read_bits(self, n: int) -> Tuple[int, int]:
        """Read up to *n* bits without crossing the current byte boundary.

        Returns ``(value, remaining)``, where ``remaining`` is how many of the
        *n* requested bits were not read because the current byte ran out.
        Loads a fresh byte from the handle when byte-aligned.

        Raises:
            EOFError: a new byte was needed but the handle is exhausted.
        """
        if n <= 0:
            return (0, 0)

        if self.bit_pos is None or self.bit_val is None:
            try:
                self.bit_val = self.read(1)[0]
            except IndexError:
                raise EOFError
            self.bit_pos = 0

        nb = min(8 - self.bit_pos, n)
        # NOTE(review): relies on `bits(value, start, count)` from .util extracting
        # `count` bits at offset `start` -- confirm its bit numbering.
        if self.bit_endian == Endian.Big:
            val = bits(self.bit_val, self.bit_pos, nb)
        else:
            val = bits(self.bit_val, 8 - nb - self.bit_pos, nb)

        self.bit_pos += nb
        if self.bit_pos == 8:
            # Byte fully consumed: persist any pending modification and realign.
            self.flush_bits()
            self.bit_pos = self.bit_val = None

        return val, n - nb

    def read(self, n: int = -1, bits=False) -> Union[bytes, int]:
        """Read *n* bytes, or an *n*-bit integer when *bits* is true.

        Bit reads take any partial leading byte through read_bits, then whole
        middle bytes in one handle read, then any partial trailing byte
        through read_bits. Byte reads return ``bytes``; a negative *n* reads
        to the end. Unaligned byte reads follow ``bit_align``.

        Raises:
            EOFError: fewer bytes/bits are available than requested. In
                ``BitAlignment.Yes`` mode with ``n < 0``, EOF ends the read
                instead.
            IOError: unaligned byte read with ``BitAlignment.No``.
        """
        if bits:
            val, nl = self.read_bits(n)
            if self.bit_endian == Endian.Big:
                # The bits read so far are the most significant ones.
                val <<= nl
            if nl >= 8:
                rounds = nl // 8
                v = int.from_bytes(self.read(rounds), byteorder=self.bit_endian.to_python())
                if self.bit_endian == Endian.Big:
                    nl -= rounds * 8
                    v <<= nl
                else:
                    # Little: place above the n - nl bits already collected.
                    v <<= (n - nl)
                    nl -= rounds * 8
                val |= v
            if nl > 0:
                v, _ = self.read_bits(nl)
                if self.bit_endian != Endian.Big:
                    v <<= (n - nl)
                val |= v
            return val
        elif self.bit_pos is not None:
            if self.bit_align == BitAlignment.No:
                raise IOError('unaligned read')
            elif self.bit_align == BitAlignment.Fill:
                self.read_bits(8 - self.bit_pos)
                return self.read(n)
            elif self.bit_align == BitAlignment.Yes:
                bs = bytearray()
                while n < 0 or len(bs) < n:
                    try:
                        bs.append(self.read(8, bits=True))
                    except EOFError:
                        # "Read to the end" must terminate normally at EOF.
                        if n >= 0:
                            raise
                        break
                return bytes(bs)
        else:
            bs = self.handle.read(n)
            if n > 0 and len(bs) != n:
                raise EOFError
            return bs

    def write_bits(self, value: int, n: int) -> Tuple[int, int]:
        """Write up to *n* bits of *value* without crossing the current byte boundary.

        When byte-aligned, a zero placeholder byte is written first and bits
        are OR-ed into ``bit_val``. The byte is written back by flush_bits
        once it is complete, or on flush()/seek(). With ``Endian.Big`` the
        chunk is taken from the top of the *n* bits, with ``Endian.Little``
        from the bottom (assuming ``bits`` indexes from the LSB -- confirm).

        Returns ``(value, remaining)``: *value* unchanged, plus the number of
        bits that did not fit in the current byte.
        """
        if n <= 0:
            return (value, n)

        if self.bit_pos is None or self.bit_val is None:
            # Reserve the byte on the handle; flush_bits later seeks back over it.
            self.bit_val = 0
            self.write(b'\x00')
            self.bit_pos = 0

        nb = min(8 - self.bit_pos, n)
        if self.bit_endian == Endian.Big:
            val = bits(value, n - nb, nb)
            self.bit_val |= val << self.bit_pos
        else:
            val = bits(value, 0, nb)
            self.bit_val |= val << (8 - nb - self.bit_pos)

        self.bit_dirty = True
        self.bit_pos += nb
        if self.bit_pos == 8:
            self.flush_bits()
            self.bit_pos = self.bit_val = None

        return value, n - nb

    def write(self, value: Union[bytes, int], *, bits: O[int] = None) -> None:
        """Write *value* as bytes, or as a *bits*-wide integer when *bits* is given.

        Bit writes send any partial leading/trailing byte through write_bits
        and write the whole middle bytes with int.to_bytes in ``bit_endian``
        order. Unaligned byte writes follow ``bit_align``.

        Raises:
            IOError: unaligned byte write with ``BitAlignment.No``.
        """
        if bits is not None:
            value, nl = self.write_bits(value, bits)
            if self.bit_endian == Endian.Little:
                # Discard the low-order bits write_bits just emitted.
                value >>= (bits - nl)
            if nl >= 8:
                rounds = nl // 8
                width = rounds * 8
                if self.bit_endian == Endian.Big:
                    # Whole bytes come from the top of the nl bits still pending.
                    chunk = value >> (nl - width)
                else:
                    chunk = value
                # Keep exactly `width` bits. Bits above them (already written for
                # Big, still pending for Little) would make to_bytes overflow.
                chunk &= (1 << width) - 1
                self.write(chunk.to_bytes(rounds, byteorder=self.bit_endian.to_python()))
                nl -= width
                if self.bit_endian == Endian.Little:
                    value >>= width
            if nl > 0:
                self.write_bits(value, nl)
        elif self.bit_pos is not None:
            if self.bit_align == BitAlignment.No:
                raise IOError('unaligned write')
            elif self.bit_align == BitAlignment.Fill:
                self.write_bits(0, 8 - self.bit_pos)
                return self.write(value)
            elif self.bit_align == BitAlignment.Yes:
                for b in value:
                    self.write(b, bits=8)
        else:
            self.handle.write(value)

    def tell(self) -> Pos:
        """Return the logical position; the fractional part is bit_pos / 8."""
        pos: Pos = self.handle.tell()
        if self.bit_pos is not None:
            # The handle is already past the partially consumed byte.
            pos -= 1
            pos += self.bit_pos / 8
        return pos

    def seek(self, n: Pos, whence: int = os.SEEK_SET) -> None:
        """Move to position *n*; a fractional part selects a bit offset in eighths.

        Pending bits are flushed and cached bit state is discarded first.
        ``SEEK_CUR`` is resolved against tell(), so it is relative to the
        logical bit position, not the raw handle position.
        """
        self.flush_bits()
        if whence == os.SEEK_CUR:
            n = self.tell() + n
            whence = os.SEEK_SET
        # The cached partial byte belongs to the old position.
        self.bit_pos = self.bit_val = None
        # floor (not int) so negative fractions, e.g. -0.5, land on the previous byte.
        byte_pos = math.floor(n)
        bp = int((n - byte_pos) * 8)
        self.handle.seek(byte_pos, whence)
        if bp:
            self.read_bits(bp)
|
|
|
|
|
|
class Segment:
    """Named record with optional ``offset``/``pos`` values and a list of dependent segments.

    Instances use default identity hashing, so they can serve as mapping keys
    (e.g. in PosInfo tables).
    """

    __slots__ = ('name', 'offset', 'dependents', 'pos')

    def __init__(self, name: str, dependents: O[Sequence['Segment']] = None) -> None:
        self.name = name
        # Unset (None) until assigned; reset() clears it again.
        self.offset: O[Pos] = None
        # The caller's sequence is stored as-is; `or []` gives each instance its
        # own empty list when none (or an empty one) is passed.
        self.dependents = dependents or []
        # Unset (None) until assigned; reset() clears it again.
        self.pos: O[Pos] = None

    def reset(self):
        """Clear ``offset`` and ``pos`` back to ``None``; ``dependents`` is kept."""
        self.offset = self.pos = None

    def __repr__(self) -> str:
        return f'<{__name__}.{self.__class__.__name__}: {self.name}>'
|
|
|
|
|
|
# Table of per-Segment Pos values (called "sizes" by the helpers below) or None;
# None entries are propagated by process_sizes and passed through by ceil_sizes.
PosInfo = Mapping[Segment, O[Pos]]
|
|
|
|
def process_sizes(s: Sequence[PosInfo], cb: Callable[[Pos, Pos], Pos]) -> PosInfo:
    """Merge several per-segment tables into one, combining values with *cb*.

    The first value seen for a segment is taken as-is. Each later value for
    the same segment is folded in as ``cb(previous, new)``. Once any table
    gives ``None`` for a segment, the merged value for it stays ``None``.
    A segment absent from a table is simply not folded for that table.

    Args:
        s: tables to merge, in order.
        cb: binary combiner, e.g. ``min``, ``max`` or addition.

    Returns:
        A new dict with one entry per segment seen in any table.
    """
    sizes: Dict[Segment, O[Pos]] = {}
    for prev in s:
        for k, n in prev.items():
            if k not in sizes:
                # Seed with the first value rather than 0: a 0 seed makes min()
                # collapse every non-negative entry to 0 and max() clamp negatives.
                sizes[k] = n
                continue
            p = sizes[k]
            sizes[k] = None if p is None or n is None else cb(p, n)
    return sizes
|
|
|
|
def min_sizes(*s: PosInfo) -> PosInfo:
    """Merge the given tables via process_sizes, using ``min`` as the per-segment combiner."""
    return process_sizes(cb=min, s=s)
|
|
|
|
def max_sizes(*s: PosInfo) -> PosInfo:
    """Merge the given tables via process_sizes, using ``max`` as the per-segment combiner."""
    return process_sizes(cb=max, s=s)
|
|
|
|
def add_sizes(*s: PosInfo) -> PosInfo:
    """Merge the given tables via process_sizes, summing values per segment."""
    def _plus(a, b):
        return a + b

    return process_sizes(s, cb=_plus)
|
|
|
|
def ceil_sizes(s: PosInfo) -> PosInfo:
    """Return a new dict with every non-None value rounded up via ``math.ceil``.

    ``None`` entries are copied through unchanged.
    """
    return {key: (val if val is None else math.ceil(val)) for key, val in s.items()}
|
|
|
|
|
|
# Inputs accepted by to_stream(): an existing Stream, a binary file-like object,
# raw bytes/bytearray (wrapped in a BytesIO), or None (a new empty BytesIO).
PossibleStream = Union[BinaryIO, Stream, None, bytes, bytearray]
|
|
|
|
def to_stream(value: PossibleStream) -> Stream:
    """Coerce *value* into a Stream.

    - A Stream is returned unchanged.
    - ``None`` becomes a Stream over a new, empty BytesIO.
    - ``bytes``/``bytearray`` become a Stream over a BytesIO of their contents.
    - Anything else is treated as a binary handle and wrapped directly.
    """
    if isinstance(value, Stream):
        return value
    if value is None:
        handle = BytesIO()
    elif isinstance(value, (bytes, bytearray)):
        handle = BytesIO(value)
    else:
        handle = value
    return Stream(handle)
|