A few optimizations to common cases, so that writeback is faster and so is reading data from unmodified files.

Jennifer Taylor 2021-10-21 01:32:51 +00:00
parent 9f07dbb55b
commit e02c307004
2 changed files with 87 additions and 45 deletions

View File

@@ -1,12 +1,18 @@
+import shutil
 from typing import BinaryIO, Dict, List, Optional, Set, Tuple, Union, overload
 class FileBytes:
+    IO_SIZE: int = 0x8000
     def __init__(self, handle: BinaryIO) -> None:
         self.__handle: BinaryIO = handle
         self.__patches: Dict[int, int] = {}
         self.__copies: List["FileBytes"] = []
         self.__unsafe: bool = False
+        self.__lowest_patch: Optional[int] = None
+        self.__highest_patch: Optional[int] = None
         handle.seek(0, 2)
         self.__filelength: int = handle.tell()
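The two new fields give each instance a dirty window: __lowest_patch and __highest_patch bound the half-open range [lowest, highest) that can contain patched bytes (highest is stored as one past the last patched offset). A read that falls entirely outside this window can skip scanning __patches byte by byte. A minimal standalone sketch of the idea, with illustrative names and a textbook overlap test (the committed comparisons are slightly more conservative):

from typing import Optional

class DirtyWindow:
    """Tracks a half-open [lowest, highest) range covering all patched offsets."""

    def __init__(self) -> None:
        self.lowest: Optional[int] = None
        self.highest: Optional[int] = None

    def mark(self, offset: int) -> None:
        # Grow the window to include a newly patched byte.
        self.lowest = offset if self.lowest is None else min(self.lowest, offset)
        self.highest = (offset + 1) if self.highest is None else max(self.highest, offset + 1)

    def may_touch(self, start: int, stop: int) -> bool:
        # A read of [start, stop) can only hit patches if it overlaps the window.
        if self.lowest is None or self.highest is None:
            return False
        return start < self.highest and stop > self.lowest

With no patches recorded, may_touch is False for every range, which is exactly the common case this commit speeds up: reading an unmodified file never consults the patch dictionary at all.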
@@ -45,7 +51,7 @@ class FileBytes:
             # Never going to find it anyway.
             return None
-        chunksize = max(searchlen * 2, 0x8000)
+        chunksize = max(searchlen * 2, self.IO_SIZE)
         startoffset = searchstart
         data: bytes = self[searchstart:(searchstart + (chunksize * 3))]
         endoffset = searchstart + len(data)
@@ -117,6 +123,8 @@ class FileBytes:
         # Make a safe copy so that in-memory patches can be changed.
         myclone = FileBytes(self.__handle)
         myclone.__patches = {k: v for k, v in self.__patches.items()}
+        myclone.__lowest_patch = self.__lowest_patch
+        myclone.__highest_patch = self.__highest_patch
         myclone.__filelength = self.__filelength
         myclone.__patchlength = self.__patchlength
         myclone.__origfilelength = self.__origfilelength
@@ -133,7 +141,10 @@ class FileBytes:
         # Add data to the end of our representation.
         for off, change in enumerate(data[:]):
-            self.__patches[self.__patchlength + off] = change
+            loc = self.__patchlength + off
+            self.__patches[loc] = change
+            self.__lowest_patch = min(self.__lowest_patch, loc) if self.__lowest_patch is not None else loc
+            self.__highest_patch = max(self.__highest_patch, loc + 1) if self.__highest_patch is not None else (loc + 1)
         self.__patchlength += len(data)
@@ -167,20 +178,52 @@ class FileBytes:
                 already.add(inst)
                 self.__gather(already, inst)
+    def __write_changes(self, handle: BinaryIO) -> None:
+        locations = sorted(self.__patches.keys())
+        keys: Set[int] = set(locations)
+        handled: Set[int] = set()
+        for location in locations:
+            if location in handled:
+                # Already wrote this in a chunk.
+                continue
+            # Figure out the maximum range for this chunk.
+            start = location
+            end = location + 1
+            while end in keys:
+                end += 1
+            # Sum it up
+            data = bytes(self.__patches[loc] for loc in range(start, end))
+            # Write it
+            handle.seek(start)
+            handle.write(data)
+            # Mark it complete
+            handled.update(range(start, end))
+        if keys != handled:
+            raise Exception("Logic error, failed to write some data!")
     def write_changes(self, new_file: Optional[BinaryIO] = None) -> None:
         if self.__unsafe:
             raise Exception("Another FileBytes instance representing the same file was written back!")
         if new_file is not None:
             # We want to serialize this out to a new file altogether.
-            for offset in range(0, self.__patchlength, 0x8000):
-                new_file.write(self[offset:(offset + 0x8000)])
-        else:
-            # We want to update the underlying file to contain this data.
-            locations = sorted(self.__patches.keys())
-            keys: Set[int] = set(locations)
-            handled: Set[int] = set()
+            self.__handle.seek(0)
+            new_file.seek(0)
+            shutil.copyfileobj(self.__handle, new_file)
+            # Now, truncate the new file to the right length.
+            if self.__filelength < self.__origfilelength:
+                new_file.truncate(self.__filelength)
+            # Now, gather up any changes to the file and write them back.
+            self.__write_changes(new_file)
+            new_file.flush()
+        else:
             # First off, see if we need to truncate the file.
             if self.__filelength < self.__origfilelength:
                 self.__handle.truncate(self.__filelength)
@@ -189,33 +232,13 @@ class FileBytes:
                 raise Exception("Logic error, somehow resized file bigger than it started?")
             # Now, gather up any changes to the file and write them back.
-            for location in locations:
-                if location in handled:
-                    # Already wrote this in a chunk.
-                    continue
-                # Figure out the maximum range for this chunk.
-                start = location
-                end = location + 1
-                while end in keys:
-                    end += 1
-                # Sum it up
-                data = bytes(self.__patches[loc] for loc in range(start, end))
-                # Write it
-                self.__handle.seek(start)
-                self.__handle.write(data)
-                # Mark it complete
-                handled.update(range(start, end))
-            if keys != handled:
-                raise Exception("Logic error, failed to write some data!")
+            self.__write_changes(self.__handle)
         # Now that we've serialized out the data, clean up our own representation.
         self.__handle.flush()
         self.__patches.clear()
+        self.__lowest_patch = None
+        self.__highest_patch = None
         self.__filelength = self.__patchlength
         # Finally, find all other clones of this class and notify them that they're
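Writeback now happens in two steps shared by both branches. When targeting a new file, the unchanged bytes are bulk-copied with shutil.copyfileobj instead of being re-sliced 0x8000 bytes at a time through __getitem__; then __write_changes coalesces adjacent patch offsets into contiguous runs, so each run costs one seek and one write rather than one per byte. A rough standalone sketch of the coalescing step (hypothetical helper, not the module's API):

from typing import Dict, Iterator, Tuple

def coalesce_runs(patches: Dict[int, int]) -> Iterator[Tuple[int, bytes]]:
    """Yield (offset, data) for each maximal run of contiguous patched bytes."""
    keys = set(patches)
    for start in sorted(keys):
        if (start - 1) in keys:
            continue  # Not the head of a run; already covered by an earlier yield.
        end = start + 1
        while end in keys:
            end += 1
        yield start, bytes(patches[off] for off in range(start, end))

# Example: patches at offsets 4, 5 and 9 become two writes instead of three.
assert list(coalesce_runs({4: 0xAA, 5: 0xBB, 9: 0xCC})) == [(4, b"\xaa\xbb"), (9, b"\xcc")]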
@@ -238,6 +261,8 @@ class FileBytes:
                 inst.__patchlength = self.__patchlength
                 inst.__origfilelength = self.__origfilelength
                 inst.__patches.clear()
+                self.__lowest_patch = None
+                self.__highest_patch = None
     def __slice(self, key: slice) -> Tuple[int, int, int]:
         # Determine step of slice
@@ -312,7 +337,14 @@ class FileBytes:
             return b""
         # Do we have any modifications to the file in this area?
-        modifications = any(index in self.__patches for index in range(start, stop, step))
+        if start >= self.__filelength and stop >= self.__filelength:
+            modifications = True
+        elif self.__lowest_patch is None or (start < self.__lowest_patch and stop < self.__lowest_patch):
+            modifications = False
+        elif self.__highest_patch is None or (start > self.__highest_patch and stop > self.__highest_patch):
+            modifications = False
+        else:
+            modifications = any(index in self.__patches for index in range(start, stop, step))
         # Now see if we can do any fast loading
         if start < stop and step == 1:
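The rewritten check decides in O(1) whether a slice can possibly intersect a patch before falling back to the O(n) per-index scan: a range entirely past the end of the underlying file must be served from the overlay, while a range entirely below __lowest_patch or entirely above __highest_patch cannot contain patches. A tidied restatement using half-open bounds (illustrative names; the committed comparisons are slightly more conservative at the boundaries):

from typing import Optional

def slice_may_be_patched(
    start: int,
    stop: int,
    filelength: int,
    lowest: Optional[int],
    highest: Optional[int],
) -> bool:
    """Cheap pre-check before scanning individual offsets."""
    if start >= filelength and stop >= filelength:
        return True   # Entirely past EOF: data only exists in the patch overlay.
    if lowest is None or highest is None:
        return False  # No patches recorded anywhere.
    if stop <= lowest or start >= highest:
        return False  # Slice lies wholly outside the dirty window.
    return True       # Overlap is possible; fall back to the per-index scan.

# An untouched 1 MiB file: any in-bounds read skips the patch scan entirely.
assert not slice_may_be_patched(0, 0x8000, 1 << 20, None, None)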
@@ -321,13 +353,16 @@ class FileBytes:
                 self.__handle.seek(start)
                 return self.__handle.read(stop - start)
             else:
-                # We need to modify at least one of the bytes in this read.
-                self.__handle.seek(start)
-                data = [x for x in self.__handle.read(stop - start)]
+                if start < self.__filelength:
+                    # We need to modify at least one of the bytes in this read.
+                    self.__handle.seek(start)
+                    data = [x for x in self.__handle.read(stop - start)]
-                # Append any amount of data we need to read past the end of the file.
-                if len(data) < stop - start:
-                    data = data + ([0] * (stop - len(data)))
+                    # Append any amount of data we need to read past the end of the file.
+                    if len(data) < stop - start:
+                        data.extend([0] * (stop - len(data)))
+                else:
+                    data = [0] * (stop - start)
                 # Now we have to modify the data with our own overlay.
                 for off in range(start, stop):
@@ -341,12 +376,15 @@ class FileBytes:
                 self.__handle.seek(stop + 1)
                 return self.__handle.read(start - stop)[::-1]
             else:
-                self.__handle.seek(stop + 1)
-                data = [x for x in self.__handle.read(start - stop)]
+                if (stop + 1) < self.__filelength:
+                    self.__handle.seek(stop + 1)
+                    data = [x for x in self.__handle.read(start - stop)]
-                # Append any amount of data we need to read past the end of the file.
-                if len(data) < stop - start:
-                    data = data + ([0] * (stop - len(data)))
+                    # Append any amount of data we need to read past the end of the file.
+                    if len(data) < start - stop:
+                        data.extend([0] * (start - len(data)))
+                else:
+                    data = [0] * (start - stop)
                 # Now we have to modify the data with our own overlay.
                 for index, off in enumerate(range(stop + 1, start + 1)):
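Both read branches now avoid touching the file handle at all when the requested range lies wholly past the end of the underlying file, and zero-pad whatever a short read leaves missing; the patch overlay then supplies any bytes that were actually appended. A hedged sketch of that pattern in isolation (hypothetical helper, not the module's code):

from typing import BinaryIO, List

def read_with_padding(handle: BinaryIO, filelength: int, start: int, length: int) -> List[int]:
    """Read length bytes at start, zero-padding anything past the end of the file."""
    if start >= filelength:
        return [0] * length  # Nothing on disk here; the overlay supplies the real data.
    handle.seek(start)
    data = list(handle.read(length))
    if len(data) < length:
        data.extend([0] * (length - len(data)))
    return data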
@@ -393,6 +431,8 @@ class FileBytes:
                 raise IndexError("FileBytes index out of range")
             self.__patches[key] = val
+            self.__lowest_patch = min(self.__lowest_patch, key) if self.__lowest_patch is not None else key
+            self.__highest_patch = max(self.__highest_patch, key + 1) if self.__highest_patch is not None else (key + 1)
         elif isinstance(key, slice):
             if not isinstance(val, bytes):
@@ -425,6 +465,8 @@ class FileBytes:
             # Finally, perform the modification.
             for index, off in enumerate(range(start, stop, step)):
                 self.__patches[off] = val[index]
+                self.__lowest_patch = min(self.__lowest_patch, off) if self.__lowest_patch is not None else off
+                self.__highest_patch = max(self.__highest_patch, off + 1) if self.__highest_patch is not None else (off + 1)
         else:
             raise NotImplementedError("Not implemented!")
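Taken together, the fast paths this commit adds cover the common read-patch-write cycle. A usage sketch against the FileBytes API shown above (the import path and file name are assumptions):

from arcadeutils import FileBytes  # Assumed export; adjust to the actual package layout.

with open("rom.bin", "rb+") as handle:  # Placeholder file name.
    fb = FileBytes(handle)

    # Unmodified reads: no patches exist yet, so the new bounds check means
    # slicing never scans the patch dictionary.
    header = fb[0:16]

    # Adjacent edits are recorded in the overlay and widen the dirty window.
    fb[0x10:0x14] = b"\x01\x02\x03\x04"

    # Writeback: the contiguous patch run is written with a single seek/write.
    fb.write_changes()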

View File

@@ -8,7 +8,7 @@ with open(os.path.join("arcadeutils", "README.md"), "r", encoding="utf-8") as fh
 setup(
     name='arcadeutils',
-    version='0.1.6',
+    version='0.1.7',
     description='Collection of utilities written in Python for working with various arcade binaries.',
     long_description=long_description,
     long_description_content_type="text/markdown",