Changed download file APIs so the stream returns at the position where it was passed in

This commit is contained in:
Josiah Baldwin
2024-12-04 13:40:49 -08:00
parent af6c020506
commit 20843dbea7

View File

@@ -1861,17 +1861,20 @@ class Session(object):
:py:class:`~meshctrl.exceptions.FileTransferCancelled`: File transfer cancelled. Info available on the `stats` property :py:class:`~meshctrl.exceptions.FileTransferCancelled`: File transfer cancelled. Info available on the `stats` property
Returns: Returns:
io.IOBase: The stream which has been downloaded into. Cursor will be at the end of the stream. io.IOBase: The stream which has been downloaded into. Cursor will be at the beginning of where the file is downloaded.
''' '''
if target is None: if target is None:
target = io.BytesIO() target = io.BytesIO()
start = target.tell()
if unique_file_tunnel: if unique_file_tunnel:
async with self.file_explorer(nodeid) as files: async with self.file_explorer(nodeid) as files:
await files.download(source, target) await files.download(source, target)
target.seek(start)
return target return target
else: else:
files = await self._cached_file_explorer(nodeid, nodeid) files = await self._cached_file_explorer(nodeid, nodeid)
await files.download(source, target, timeout=timeout) await files.download(source, target, timeout=timeout)
target.seek(start)
return target return target
async def download_file(self, nodeid, source, filepath, unique_file_tunnel=False, timeout=None): async def download_file(self, nodeid, source, filepath, unique_file_tunnel=False, timeout=None):
@@ -1889,10 +1892,10 @@ class Session(object):
:py:class:`~meshctrl.exceptions.FileTransferCancelled`: File transfer cancelled. Info available on the `stats` property :py:class:`~meshctrl.exceptions.FileTransferCancelled`: File transfer cancelled. Info available on the `stats` property
Returns: Returns:
io.IOBase: The stream which has been downloaded into. Cursor will be at the end of the stream. None
''' '''
with open(filepath, "wb") as f: with open(filepath, "wb") as f:
return await self.download(nodeid, source, f, unique_file_tunnel, timeout=timeout) await self.download(nodeid, source, f, unique_file_tunnel, timeout=timeout)
async def _cached_file_explorer(self, nodeid, _id): async def _cached_file_explorer(self, nodeid, _id):
if (_id not in self._file_tunnels or not self._file_tunnels[_id].alive): if (_id not in self._file_tunnels or not self._file_tunnels[_id].alive):