Compare commits
2 Commits
d9b7849d5c
...
dcdb37b0f4
Author | SHA1 | Date |
---|---|---|
Shiz | dcdb37b0f4 | |
Shiz | bfb2a17b71 |
|
@ -18,7 +18,7 @@ from .types.num import (
|
|||
)
|
||||
from .types.str import Str, CStr, cstr, wcstr, utf8cstr, utf16cstr
|
||||
from .types.seq import Arr, Tuple
|
||||
from .types.struct import StructType, Struct
|
||||
from .types.struct import StructType, Struct, Union
|
||||
from .types.transforms import Default, Transform, Mapped, Enum, Check, Fixed
|
||||
from .types.control import Switch, If
|
||||
from .types.io import Sized, Terminated, Ref, AlignTo, AlignedTo
|
||||
|
|
|
@ -11,7 +11,7 @@ from contextlib import contextmanager
|
|||
import sx
|
||||
from ..core import to_type
|
||||
from ..core.base import Context, Type, PathElement
|
||||
from ..core.io import Stream, Pos, add_sizes
|
||||
from ..core.io import Stream, Pos, add_sizes, max_sizes
|
||||
from ..core.util import indent, format_value, get_annot_locations
|
||||
from ..core.meta import Generic, TypeSource
|
||||
from ..core.expr import ProxyExpr
|
||||
|
@ -110,7 +110,16 @@ class StructType(G[T], Type[T]):
|
|||
pos = stream.tell()
|
||||
|
||||
with self.enter():
|
||||
for name, type in self.fields.items():
|
||||
if self.union:
|
||||
if value._sx_lastset_:
|
||||
fields = [value._sx_lastset_]
|
||||
else:
|
||||
fields = next(self.fields)
|
||||
else:
|
||||
fields = list(self.fields)
|
||||
|
||||
for name in fields:
|
||||
type = self.fields[name]
|
||||
with context.enter(name, type):
|
||||
if self.union:
|
||||
stream.seek(pos, os.SEEK_SET)
|
||||
|
@ -151,7 +160,13 @@ class StructType(G[T], Type[T]):
|
|||
def sizeof(self, context: Context, value: O[T]) -> O[Mapping[str, int]]:
|
||||
with self.enter():
|
||||
sizes = self.get_sizes(context, value, None)
|
||||
return add_sizes(*sizes) if sizes else 0
|
||||
if sizes:
|
||||
if self.union:
|
||||
return max_sizes(*sizes)
|
||||
else:
|
||||
return add_sizes(*sizes)
|
||||
else:
|
||||
return 0
|
||||
|
||||
def offsetof(self, context: Context, path: Sequence[PathElement], value: O[T]) -> O[int]:
|
||||
if not path:
|
||||
|
@ -166,7 +181,10 @@ class StructType(G[T], Type[T]):
|
|||
|
||||
child = self.fields[field]
|
||||
with self.enter():
|
||||
sizes = self.get_sizes(context, value, field)
|
||||
if self.union:
|
||||
sizes = []
|
||||
else:
|
||||
sizes = self.get_sizes(context, value, field)
|
||||
if path:
|
||||
with context.enter(field, child):
|
||||
sizes.append(context.offsetof(child, path, getattr(value, field) if value is not None else None))
|
||||
|
@ -223,6 +241,7 @@ class Struct:
|
|||
if parent:
|
||||
generics = parent.generics + generics
|
||||
bound = parent.bound + bound
|
||||
kwargs.setdefault('union', parent.union)
|
||||
|
||||
# Get all annotations
|
||||
annots = {}
|
||||
|
@ -315,3 +334,16 @@ class Struct:
|
|||
|
||||
def __repr__(self) -> str:
|
||||
return self._format_(repr)
|
||||
|
||||
|
||||
class Union(Struct, union=True, inject=False):
|
||||
_sx_lastset_ = ''
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
if kwargs:
|
||||
super().__setattr__('_sx_lastset_', list(kwargs)[-1])
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
super().__setattr__(name, value)
|
||||
super().__setattr__('_sx_lastset_', name)
|
||||
|
|
Loading…
Reference in New Issue