diff --git a/src/meshctrl/exceptions.py b/src/meshctrl/exceptions.py index 43ae378..9ac3af2 100644 --- a/src/meshctrl/exceptions.py +++ b/src/meshctrl/exceptions.py @@ -2,7 +2,9 @@ class MeshCtrlError(Exception): """ Base class for Meshctrl errors """ - pass + def __init__(self, message, *args, **kwargs): + self.message = message + super().__init__(message, *args, **kwargs) class ServerError(MeshCtrlError): """ @@ -25,6 +27,7 @@ class FileTransferError(MeshCtrlError): """ def __init__(self, message, stats): self.stats = stats + super().__init__(message) class FileTransferCancelled(FileTransferError): """ diff --git a/src/meshctrl/util.py b/src/meshctrl/util.py index fe36fc5..b245d59 100644 --- a/src/meshctrl/util.py +++ b/src/meshctrl/util.py @@ -140,17 +140,20 @@ def compare_dict(dict1, dict2): return False def _check_socket(f): + async def _check_errs(self): + if not self.alive and self._main_loop_error is not None: + raise self._main_loop_error + elif not self.alive and self.initialized.is_set(): + raise exceptions.SocketError("Socket Closed") + @functools.wraps(f) async def wrapper(self, *args, **kwargs): try: - async with asyncio.TaskGroup() as tg: - tg.create_task(asyncio.wait_for(self.initialized.wait(), 10)) - tg.create_task(asyncio.wait_for(self._socket_open.wait(), 10)) + await asyncio.wait_for(self.initialized.wait(), 10) + await _check_errs(self) + await asyncio.wait_for(self._socket_open.wait(), 10) finally: - if not self.alive and self._main_loop_error is not None: - raise self._main_loop_error - elif not self.alive and self.initialized.is_set(): - raise exceptions.SocketError("Socket Closed") + await _check_errs(self) return await f(self, *args, **kwargs) return wrapper diff --git a/tests/test_session.py b/tests/test_session.py index 74d6e88..e7cec28 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -5,6 +5,8 @@ import meshctrl import requests import random import io +import traceback +import time thisdir = os.path.dirname(os.path.realpath(__file__)) async def test_admin(env): @@ -77,6 +79,17 @@ async def test_users(env): pass else: raise Exception("Connected with no password") + + start = time.time() + try: + async with meshctrl.Session(env.mcurl, user="admin", password="The wrong password", ignore_ssl=True) as admin_session: + pass + except* meshctrl.exceptions.ServerError as eg: + assert str(eg.exceptions[0]) == "Invalid Auth" or eg.exceptions[0].message == "Invalid Auth", "Didn't get invalid auth message" + assert time.time() - start < 10, "Invalid auth wasn't raised until after timeout" + pass + else: + raise Exception("Connected with bad password") async with meshctrl.Session(env.mcurl+"/", user="admin", password=env.users["admin"], ignore_ssl=True) as admin_session,\ meshctrl.Session(env.mcurl, user="privileged", password=env.users["privileged"], ignore_ssl=True) as privileged_session,\ meshctrl.Session(env.mcurl, user="unprivileged", password=env.users["unprivileged"], ignore_ssl=True) as unprivileged_session: