sx/sx/core/io.py

167 lines
4.6 KiB
Python

import os
import enum
import math
from contextlib import contextmanager
from typing import BinaryIO, Sequence, Optional as O, Union, Mapping, Callable, Dict
from .util import bits
class BitAlignment(enum.Enum):
    """Bit-alignment mode; the value kept in ``Stream.bit_align``.

    Members are numbered explicitly (1, 2, 3), matching what ``enum.auto()``
    assigns in declaration order.
    """

    No = 1
    Fill = 2
    Yes = 3
class Endian(enum.Enum):
    """Ordering used by ``Stream`` when assembling multi-bit and multi-byte values."""

    Little = 1
    Big = 2

    def to_python(self):
        """Return the ``byteorder`` string understood by ``int.from_bytes``/``int.to_bytes``."""
        if self is Endian.Little:
            return 'little'
        # The enum has exactly two members, so anything else is Big.
        return 'big'
# A stream position: an int byte offset, or a float whose fractional part
# encodes a bit offset within the byte (bit_index / 8), as produced by
# Stream.tell() and accepted by Stream.seek().
Pos = Union[int, float]
class Stream:
    """Wrapper around a binary handle that adds bit-level reads.

    A bit read pulls one byte from the handle and then hands out its bits.
    The partially consumed byte lives in ``bit_val``, and the index of the
    next bit inside it lives in ``bit_pos``. While a byte is partially
    consumed, ``tell()`` reports a fractional position (byte + bits / 8).
    ``seek()`` accepts the same form.

    NOTE(review): ``bit_align`` is stored but not read by any method of this class.
    """

    def __init__(
        self,
        handle: BinaryIO,
        bit_align: BitAlignment = BitAlignment.No,
        bit_endian: Endian = Endian.Big,
    ) -> None:
        self.handle = handle
        # Index (1-7 between calls) of the next unread bit in bit_val, or None when byte-aligned.
        self.bit_pos: O[int] = None
        # The partially consumed byte, or None when byte-aligned.
        self.bit_val: O[int] = None
        self.bit_align = bit_align
        self.bit_endian = bit_endian

    @property
    def root(self):
        """Return the innermost handle, unwrapping any nested Streams."""
        h = self.handle
        while isinstance(h, self.__class__):
            # Descend one level; re-reading self.handle here would never terminate.
            h = h.handle
        return h

    @contextmanager
    def wrapped(self, handle):
        """Temporarily route all I/O through *handle* inside a ``with`` block.

        Yields this stream. The previous handle is restored on exit, including
        when the ``with`` body raises.
        """
        old = self.handle
        self.handle = handle
        try:
            yield self
        finally:
            self.handle = old

    def read_bits(self, n: int):
        """Read up to *n* bits, never crossing the current byte boundary.

        A new byte is fetched from the handle when none is partially consumed.

        Returns:
            ``(value, remaining)``, where *remaining* is how many of the *n*
            requested bits still need to be read. ``n <= 0`` returns ``(0, 0)``.

        Raises:
            EOFError: if a new byte is needed and the handle is exhausted.
        """
        if n <= 0:
            return (0, 0)
        if self.bit_pos is None or self.bit_val is None:
            self.bit_val = self.read(1)[0]
            self.bit_pos = 0
        nb = min(8 - self.bit_pos, n)
        # NOTE(review): which end of the byte each branch takes depends on
        # util.bits(value, start, count) semantics — confirm against util.
        if self.bit_endian == Endian.Big:
            val = bits(self.bit_val, self.bit_pos, nb)
        else:
            val = bits(self.bit_val, 8 - self.bit_pos - nb, nb)
        self.bit_pos += nb
        if self.bit_pos == 8:
            # Byte fully consumed: the stream is byte-aligned again.
            self.bit_pos = self.bit_val = None
        return val, n - nb

    def read(self, n: int = -1, bits=False) -> Union[bytes, int]:
        """Read *n* bytes, or an *n*-bit integer when *bits* is true.

        Byte mode returns what the handle gives for ``read(n)``. A negative
        *n* reads everything; byte mode ignores any partially consumed byte.
        Bit mode combines bits according to ``bit_endian``. Big puts the first
        bits read in the most significant position; Little puts them in the
        least significant position.

        Raises:
            EOFError: if fewer than a positive *n* bytes are available.
        """
        if bits:
            # First, whatever is left of the current byte.
            val, nl = self.read_bits(n)
            if self.bit_endian == Endian.Big:
                # Make room below for the nl bits still to come.
                val <<= nl
            if nl >= 8:
                # read_bits left the stream byte-aligned: read whole bytes at once.
                rounds = nl // 8
                v = int.from_bytes(self.read(rounds), byteorder=self.bit_endian.to_python())
                if self.bit_endian == Endian.Big:
                    nl -= rounds * 8
                    v <<= nl
                else:
                    # Little: place above the n - nl bits already collected.
                    v <<= (n - nl)
                    nl -= rounds * 8
                val |= v
            if nl > 0:
                # Trailing bits from a new, partially consumed byte.
                v, _ = self.read_bits(nl)
                if self.bit_endian != Endian.Big:
                    v <<= (n - nl)
                val |= v
            return val
        bs = self.handle.read(n)
        if n > 0 and len(bs) != n:
            raise EOFError(f'expected {n} bytes, got {len(bs)}')
        return bs

    def write(self, value: Union[bytes, int], *, bits: O[int] = None) -> None:
        """Write *value* to the handle.

        NOTE(review): bit-level writes (``bits`` given) and writes while a byte
        is partially consumed are currently ignored without error.

        Raises:
            TypeError: if *value* is an int and ``bits`` is not given.
        """
        if bits is not None:
            # Bit-level writing is not implemented: ignored.
            return
        if self.bit_pos:
            # Not byte-aligned: the write is dropped.
            return
        if isinstance(value, int):
            raise TypeError('int values can only be written with bits=')
        self.handle.write(value)

    def tell(self) -> Pos:
        """Return the current position, fractional while inside a partially read byte."""
        pos: Pos = self.handle.tell()
        if self.bit_pos:
            # The handle is already past the partial byte; step back, then add the bits used.
            pos -= 1
            pos += self.bit_pos / 8
        return pos

    def seek(self, n: Pos, whence: int = os.SEEK_SET) -> None:
        """Move to position *n*; a float's fractional part selects a bit within the byte.

        Any partially consumed byte from the previous position is discarded.
        ``seek(stream.tell())`` therefore returns to exactly the same bit.

        NOTE(review): with SEEK_CUR, the offset is applied to the handle's byte
        position, which is one byte past tell() while a byte is partially read.
        """
        if isinstance(n, float):
            # Floor (consistent with %) so negative offsets split correctly:
            # -1.5 -> byte -2 plus 4 bits.
            bp = int((n % 1) * 8)
            n = math.floor(n)
        else:
            bp = 0
        # The old partial byte belongs to the previous position.
        self.bit_pos = self.bit_val = None
        self.handle.seek(n, whence)
        if bp:
            self.read_bits(bp)
class Segment:
    """A named segment; used as the key in per-segment size mappings.

    Attributes:
        name: label shown in ``repr``.
        offset: starts as None and is cleared back to None by ``reset()``.
        dependents: the dependent segments given at construction, or a new
            empty list when none (or an empty sequence) were given.
        pos: starts as None and is cleared back to None by ``reset()``.
    """

    __slots__ = ('name', 'offset', 'dependents', 'pos')

    def __init__(self, name: str, dependents: O[Sequence['Segment']] = None) -> None:
        self.name = name
        self.offset: O[Pos] = None
        # Fresh list per instance when nothing is passed (no shared mutable default).
        self.dependents = dependents or []
        self.pos: O[Pos] = None

    def reset(self):
        """Clear ``offset`` and ``pos`` back to None."""
        self.offset = self.pos = None

    def __repr__(self) -> str:
        return f'<{__name__}.{self.__class__.__name__}: {self.name}>'
def process_sizes(s: Sequence[Mapping[Segment, O[Pos]]], cb: Callable[[Pos, Pos], Pos]) -> Dict[Segment, O[Pos]]:
    """Combine several per-segment size mappings into one by folding with *cb*.

    A segment missing from one of the mappings counts as size 0 in that mapping.
    A ``None`` (unknown) size in any mapping makes that segment's result ``None``.

    Args:
        s: the size mappings to combine, in order.
        cb: binary function combining two known sizes (e.g. ``min``, ``max``).

    Returns:
        A new dict mapping every segment seen in *s* (first-seen order) to its
        combined size, or None.
    """
    mappings = list(s)
    # Union of all segments across the mappings, preserving first-seen order.
    segments: Dict[Segment, None] = {}
    for mapping in mappings:
        for seg in mapping:
            segments.setdefault(seg, None)
    sizes: Dict[Segment, O[Pos]] = {}
    for seg in segments:
        # Fold the real per-mapping values instead of seeding with 0: a 0 seed
        # made min() collapse to 0 even when every mapping had a size.
        values = [mapping.get(seg, 0) for mapping in mappings]
        if any(v is None for v in values):
            sizes[seg] = None
            continue
        combined = values[0]
        for v in values[1:]:
            combined = cb(combined, v)
        sizes[seg] = combined
    return sizes
def min_sizes(*s: Mapping[Segment, Pos]) -> Dict[Segment, O[Pos]]:
    """Merge the given size mappings per segment with ``min`` (via ``process_sizes``)."""
    return process_sizes(s, cb=min)
def max_sizes(*s: Mapping[Segment, Pos]) -> Dict[Segment, O[Pos]]:
    """Merge the given size mappings per segment with ``max`` (via ``process_sizes``)."""
    return process_sizes(s, cb=max)
def add_sizes(*s: Mapping[Segment, Pos]) -> Dict[Segment, O[Pos]]:
    """Merge the given size mappings per segment by summing (via ``process_sizes``)."""
    def _plus(a, b):
        return a + b

    return process_sizes(s, cb=_plus)
def ceil_sizes(s: Mapping[Segment, O[Pos]]) -> Dict[Segment, O[int]]:
    """Return a new dict with every known size rounded up to an int.

    ``None`` (unknown) sizes are passed through unchanged; key order is preserved.
    """
    return {seg: (size if size is None else math.ceil(size)) for seg, size in s.items()}