diff --git a/arcadeutils/filebytes.py b/arcadeutils/filebytes.py index 6d673fb..20e8f46 100644 --- a/arcadeutils/filebytes.py +++ b/arcadeutils/filebytes.py @@ -1,12 +1,18 @@ +import shutil from typing import BinaryIO, Dict, List, Optional, Set, Tuple, Union, overload class FileBytes: + + IO_SIZE: int = 0x8000 + def __init__(self, handle: BinaryIO) -> None: self.__handle: BinaryIO = handle self.__patches: Dict[int, int] = {} self.__copies: List["FileBytes"] = [] self.__unsafe: bool = False + self.__lowest_patch: Optional[int] = None + self.__highest_patch: Optional[int] = None handle.seek(0, 2) self.__filelength: int = handle.tell() @@ -45,7 +51,7 @@ class FileBytes: # Never going to find it anyway. return None - chunksize = max(searchlen * 2, 0x8000) + chunksize = max(searchlen * 2, self.IO_SIZE) startoffset = searchstart data: bytes = self[searchstart:(searchstart + (chunksize * 3))] endoffset = searchstart + len(data) @@ -117,6 +123,8 @@ class FileBytes: # Make a safe copy so that in-memory patches can be changed. myclone = FileBytes(self.__handle) myclone.__patches = {k: v for k, v in self.__patches.items()} + myclone.__lowest_patch = self.__lowest_patch + myclone.__highest_patch = self.__highest_patch myclone.__filelength = self.__filelength myclone.__patchlength = self.__patchlength myclone.__origfilelength = self.__origfilelength @@ -133,7 +141,10 @@ class FileBytes: # Add data to the end of our representation. for off, change in enumerate(data[:]): - self.__patches[self.__patchlength + off] = change + loc = self.__patchlength + off + self.__patches[loc] = change + self.__lowest_patch = min(self.__lowest_patch, loc) if self.__lowest_patch is not None else loc + self.__highest_patch = max(self.__highest_patch, loc + 1) if self.__highest_patch is not None else (loc + 1) self.__patchlength += len(data) @@ -167,20 +178,52 @@ class FileBytes: already.add(inst) self.__gather(already, inst) + def __write_changes(self, handle: BinaryIO) -> None: + locations = sorted(self.__patches.keys()) + keys: Set[int] = set(locations) + handled: Set[int] = set() + for location in locations: + if location in handled: + # Already wrote this in a chunk. + continue + + # Figure out the maximum range for this chunk. + start = location + end = location + 1 + while end in keys: + end += 1 + + # Sum it up + data = bytes(self.__patches[loc] for loc in range(start, end)) + + # Write it + handle.seek(start) + handle.write(data) + + # Mark it complete + handled.update(range(start, end)) + + if keys != handled: + raise Exception("Logic error, failed to write some data!") + def write_changes(self, new_file: Optional[BinaryIO] = None) -> None: if self.__unsafe: raise Exception("Another FileBytes instance representing the same file was written back!") if new_file is not None: # We want to serialize this out to a new file altogether. - for offset in range(0, self.__patchlength, 0x8000): - new_file.write(self[offset:(offset + 0x8000)]) - else: - # We want to update the underlying file to contain this data. - locations = sorted(self.__patches.keys()) - keys: Set[int] = set(locations) - handled: Set[int] = set() + self.__handle.seek(0) + new_file.seek(0) + shutil.copyfileobj(self.__handle, new_file) + # Now, truncate the new file to the right length. + if self.__filelength < self.__origfilelength: + new_file.truncate(self.__filelength) + + # Now, gather up any changes to the file and write them back. + self.__write_changes(new_file) + new_file.flush() + else: # First off, see if we need to truncate the file. if self.__filelength < self.__origfilelength: self.__handle.truncate(self.__filelength) @@ -189,33 +232,13 @@ class FileBytes: raise Exception("Logic error, somehow resized file bigger than it started?") # Now, gather up any changes to the file and write them back. - for location in locations: - if location in handled: - # Already wrote this in a chunk. - continue - - # Figure out the maximum range for this chunk. - start = location - end = location + 1 - while end in keys: - end += 1 - - # Sum it up - data = bytes(self.__patches[loc] for loc in range(start, end)) - - # Write it - self.__handle.seek(start) - self.__handle.write(data) - - # Mark it complete - handled.update(range(start, end)) - - if keys != handled: - raise Exception("Logic error, failed to write some data!") + self.__write_changes(self.__handle) # Now that we've serialized out the data, clean up our own representation. self.__handle.flush() self.__patches.clear() + self.__lowest_patch = None + self.__highest_patch = None self.__filelength = self.__patchlength # Finally, find all other clones of this class and notify them that they're @@ -238,6 +261,8 @@ class FileBytes: inst.__patchlength = self.__patchlength inst.__origfilelength = self.__origfilelength inst.__patches.clear() + self.__lowest_patch = None + self.__highest_patch = None def __slice(self, key: slice) -> Tuple[int, int, int]: # Determine step of slice @@ -312,7 +337,14 @@ class FileBytes: return b"" # Do we have any modifications to the file in this area? - modifications = any(index in self.__patches for index in range(start, stop, step)) + if start >= self.__filelength and stop >= self.__filelength: + modifications = True + elif self.__lowest_patch is None or (start < self.__lowest_patch and stop < self.__lowest_patch): + modifications = False + elif self.__highest_patch is None or (start > self.__highest_patch and stop > self.__highest_patch): + modifications = False + else: + modifications = any(index in self.__patches for index in range(start, stop, step)) # Now see if we can do any fast loading if start < stop and step == 1: @@ -321,13 +353,16 @@ class FileBytes: self.__handle.seek(start) return self.__handle.read(stop - start) else: - # We need to modify at least one of the bytes in this read. - self.__handle.seek(start) - data = [x for x in self.__handle.read(stop - start)] + if start < self.__filelength: + # We need to modify at least one of the bytes in this read. + self.__handle.seek(start) + data = [x for x in self.__handle.read(stop - start)] - # Append any amount of data we need to read past the end of the file. - if len(data) < stop - start: - data = data + ([0] * (stop - len(data))) + # Append any amount of data we need to read past the end of the file. + if len(data) < stop - start: + data.extend([0] * (stop - len(data))) + else: + data = [0] * (stop - start) # Now we have to modify the data with our own overlay. for off in range(start, stop): @@ -341,12 +376,15 @@ class FileBytes: self.__handle.seek(stop + 1) return self.__handle.read(start - stop)[::-1] else: - self.__handle.seek(stop + 1) - data = [x for x in self.__handle.read(start - stop)] + if (stop + 1) < self.__filelength: + self.__handle.seek(stop + 1) + data = [x for x in self.__handle.read(start - stop)] - # Append any amount of data we need to read past the end of the file. - if len(data) < stop - start: - data = data + ([0] * (stop - len(data))) + # Append any amount of data we need to read past the end of the file. + if len(data) < start - stop: + data.extend([0] * (start - len(data))) + else: + data = [0] * (start - stop) # Now we have to modify the data with our own overlay. for index, off in enumerate(range(stop + 1, start + 1)): @@ -393,6 +431,8 @@ class FileBytes: raise IndexError("FileBytes index out of range") self.__patches[key] = val + self.__lowest_patch = min(self.__lowest_patch, key) if self.__lowest_patch is not None else key + self.__highest_patch = max(self.__highest_patch, key + 1) if self.__highest_patch is not None else (key + 1) elif isinstance(key, slice): if not isinstance(val, bytes): @@ -425,6 +465,8 @@ class FileBytes: # Finally, perform the modification. for index, off in enumerate(range(start, stop, step)): self.__patches[off] = val[index] + self.__lowest_patch = min(self.__lowest_patch, off) if self.__lowest_patch is not None else off + self.__highest_patch = max(self.__highest_patch, off + 1) if self.__highest_patch is not None else (off + 1) else: raise NotImplementedError("Not implemented!") diff --git a/setup.py b/setup.py index be7942d..87f81c7 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ with open(os.path.join("arcadeutils", "README.md"), "r", encoding="utf-8") as fh setup( name='arcadeutils', - version='0.1.6', + version='0.1.7', description='Collection of utilities written in Python for working with various arcade binaries.', long_description=long_description, long_description_content_type="text/markdown",