diff --git a/src/meshctrl/files.py b/src/meshctrl/files.py index 19a704e..3c386df 100644 --- a/src/meshctrl/files.py +++ b/src/meshctrl/files.py @@ -27,7 +27,7 @@ class Files(tunnel.Tunnel): "https_proxy": self._session._proxy } self._proxy_handler = urllib.request.ProxyHandler(proxies=proxies) - self._http_opener = urllib.request.build_opener(self._proxy_handler, urllib.request.HTTPSHandler(context=self._ssl_context)) + self._http_opener = urllib.request.build_opener(self._proxy_handler, urllib.request.HTTPSHandler(context=self._session._ssl_context)) def _get_request_id(self): diff --git a/src/meshctrl/session.py b/src/meshctrl/session.py index 3beb0f1..ffb79c0 100644 --- a/src/meshctrl/session.py +++ b/src/meshctrl/session.py @@ -124,15 +124,17 @@ class Session(object): self._message_queue = asyncio.Queue() self._send_task = None self._listen_task = None + self._ssl_context = None + if self._ignore_ssl: + self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self._ssl_context.check_hostname = False + self._ssl_context.verify_mode = ssl.CERT_NONE async def _main_loop(self): try: options = {} - if self._ignore_ssl: - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - options = { "ssl": ssl_context } + if self._ssl_context is not None: + options["ssl"] = self._ssl_context headers = websockets.datastructures.Headers() @@ -215,11 +217,14 @@ class Session(object): return self._command_id async def close(self): - self._main_loop_task.cancel() try: - await self._main_loop_task - except asyncio.CancelledError: - pass + await asyncio.gather(*[tunnel.close() for name, tunnel in self._file_tunnels.items()]) + finally: + self._main_loop_task.cancel() + try: + await self._main_loop_task + except asyncio.CancelledError: + pass @util._check_socket async def __aenter__(self): diff --git a/src/meshctrl/tunnel.py b/src/meshctrl/tunnel.py index 2f5278b..65c2987 100644 --- a/src/meshctrl/tunnel.py +++ b/src/meshctrl/tunnel.py @@ -27,11 +27,6 @@ class Tunnel(object): self._message_queue = asyncio.Queue() self._send_task = None self._listen_task = None - self._ssl_context = None - if self._session._ignore_ssl: - self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - self._ssl_context.check_hostname = False - self._ssl_context.verify_mode = ssl.CERT_NONE async def close(self): self._main_loop_task.cancel() @@ -53,8 +48,8 @@ class Tunnel(object): self._authcookie = await self._session._send_command_no_response_id({ "action":"authcookie" }) options = {} - if self._ssl_context is not None: - options = { "ssl": self._ssl_context } + if self._session._ssl_context is not None: + options["ssl"] = self._session._ssl_context if (len(self.node_id.split('/')) != 3): self.node_id = f"node/{self._session._currentDomain or ""}/{self.node_id}" @@ -82,7 +77,6 @@ class Tunnel(object): except* websockets.ConnectionClosed as e: self._socket_open.clear() if not self.auto_reconnect: - self.alive = False raise except* Exception as eg: self.alive = False diff --git a/tests/test_files.py b/tests/test_files.py index e69152a..ed888f9 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -70,7 +70,7 @@ async def test_upload_download(env): else: break - randdata = random.randbytes(2000000) + randdata = random.randbytes(20000000) upfilestream = io.BytesIO(randdata) downfilestream = io.BytesIO() @@ -99,7 +99,7 @@ async def test_upload_download(env): start = time.perf_counter() r = await files.download(f"{pwd}/test", downfilestream, skip_ws_attempt=True, timeout=5) print("\ninfo files_download: {}\n".format(r)) - assert r["result"] == True, "Domnload failed" + assert r["result"] == True, "Download failed" assert r["size"] == len(randdata), "Downloaded wrong number of bytes" print(f"http download time: {time.perf_counter()-start}") @@ -110,7 +110,7 @@ async def test_upload_download(env): start = time.perf_counter() r = await files.download(f"{pwd}/test", downfilestream, skip_http_attempt=True, timeout=5) print("\ninfo files_download: {}\n".format(r)) - assert r["result"] == True, "Domnload failed" + assert r["result"] == True, "Download failed" assert r["size"] == len(randdata), "Downloaded wrong number of bytes" print(f"ws download time: {time.perf_counter()-start}") diff --git a/tests/test_session.py b/tests/test_session.py index f88db24..74d6e88 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -278,7 +278,7 @@ async def test_mesh_device(env): assert await admin_session.move_to_device_group([agent.nodeid], mesh.name, isname=True, timeout=5), "Failed to move mesh to new device group by name" - # For now, this expects no response. If we ever figure out why the server isn't sending console information te us when it should, fix this. + # For now, this expe namects no response. If we ever figure out why the server isn't sending console information te us when it should, fix this. # assert "meshagent" in (await unprivileged_session.run_command(agent.nodeid, "ls", timeout=10))[agent.nodeid]["result"], "ls gave incorrect data" try: await unprivileged_session.run_command(agent.nodeid, "ls", timeout=10) @@ -408,7 +408,7 @@ async def test_session_files(env): break pwd = (await admin_session.run_command(agent.nodeid, "pwd", timeout=10))[agent.nodeid]["result"].strip() - randdata = random.randbytes(2000000) + randdata = random.randbytes(20000000) upfilestream = io.BytesIO(randdata) downfilestream = io.BytesIO() os.makedirs(os.path.join(thisdir, "data"), exist_ok=True)