diff --git a/src/meshctrl/session.py b/src/meshctrl/session.py index cec6bd8..63d40a5 100644 --- a/src/meshctrl/session.py +++ b/src/meshctrl/session.py @@ -240,23 +240,28 @@ class Session(object): async def __aexit__(self, exc_t, exc_v, exc_tb): await self.close() - @util._check_socket - async def _send_command(self, data, name, timeout=None): - id = f"meshctrl_{name}_{self._get_command_id()}" + def _generate_response_id(self, name): + responseid = f"meshctrl_{name}_{self._get_command_id()}" # This fixes a very theoretical bug with hash colisions in the case of an infinite int of requests. Now the bug will only happen if there are currently 2**32-1 of the same type of request going out at the same time - while id in self._inflight: - id = f"meshctrl_{name}_{self._get_command_id()}" + while responseid in self._inflight: + responseid = f"meshctrl_{name}_{self._get_command_id()}" + return responseid - self._inflight.add(id) + @util._check_socket + async def _send_command(self, data, name, timeout=None, responseid=None): + if responseid is None: + responseid = self._generate_response_id(name) + + self._inflight.add(responseid) responded = asyncio.Event() response = None async def _(data): - self._inflight.remove(id) + self._inflight.remove(responseid) nonlocal response response = data responded.set() - self._eventer.once(id, _) - await self._message_queue.put(json.dumps(data | {"tag": id, "responseid": id})) + self._eventer.once(responseid, _) + await self._message_queue.put(json.dumps(data | {"tag": responseid, "responseid": responseid})) await asyncio.wait_for(responded.wait(), timeout=timeout) if isinstance(response, Exception): raise response @@ -1459,7 +1464,7 @@ class Session(object): async def run_command(self, nodeids, command, powershell=False, runasuser=False, runasuseronly=False, ignore_output=False, timeout=None): ''' - Run a command on any number of nodes. WARNING: Non namespaced call. Calling this function again before it returns may cause unintended consequences. + Run a command on any number of nodes. WARNING: Non namespaced call on older versions of meshcentral (<1.0.22). Calling this function on those versions again before it returns may cause unintended consequences. Args: nodeids (str|list[str]): Unique ids of nodes on which to run the command @@ -1496,29 +1501,52 @@ class Session(object): if (f"node//{nid}" == id): return nid - result = {n: {"complete": False, "result": [], "command": command} for n in nodeids} + result = None + console_result = {n: {"complete": False, "result": [], "command": command} for n in nodeids} + reply_result = {n: {"complete": False, "result": [], "command": command} for n in nodeids} async def _console(): async for event in self.events({"action": "msg", "type": "console"}): node = match_nodeid(event["nodeid"], nodeids) if node: if event["value"] == "Run commands completed.": - result.setdefault(node, {})["complete"] = True - if all(_["complete"] for key, _ in result.items()): + console_result.setdefault(node, {})["complete"] = True + if all(_["complete"] for key, _ in console_result.items()): break continue elif (event["value"].startswith("Run commands")): continue - result[node]["result"].append(event["value"]) - + console_result[node]["result"].append(event["value"]) + + async def _reply(responseid, data=None): + # Returns True when all results are in, Falsey otherwise + def _parse_event(event): + node = match_nodeid(event["nodeid"], nodeids) + if node: + reply_result.setdefault(node, {})["complete"] = True + reply_result[node]["result"].append(event["result"]) + if all(_["complete"] for key, _ in reply_result.items()): + return True + if data is not None: + if _parse_event(data): + return + async for event in self.events({"action": "msg", "type": "runcommands", "responseid":responseid}): + if _parse_event(event): + break async def __(command, tg, tasks): - data = await self._send_command(command, "run_command", timeout=timeout) + nonlocal result + responseid = self._generate_response_id("run_command") + if not ignore_output: + reply_task = tg.create_task(asyncio.wait_for(_reply(responseid), timeout=timeout)) + console_task = tg.create_task(asyncio.wait_for(_console(), timeout=timeout)) + data = await self._send_command(command, "run_command", timeout=timeout, responseid=responseid) if data.get("type", None) != "runcommands" and data.get("result", "ok").lower() != "ok": raise exceptions.ServerError(data["result"]) elif data.get("type", None) != "runcommands" and data.get("result", "ok").lower() == "ok": + reply_task.cancel() + result = console_result expect_response = False - console_task = tg.create_task(asyncio.wait_for(_console(), timeout=timeout)) if not ignore_output: userid = (await self.user_info())["_id"] for n in nodeids: @@ -1539,20 +1567,9 @@ class Session(object): else: console_task.cancel() elif data.get("type", None) == "runcommands" and not ignore_output: - # Returns True when all results are in, Falsey otherwise - def _parse_event(event): - node = match_nodeid(event["nodeid"], nodeids) - if node: - result.setdefault(node, {})["complete"] = True - result[node]["result"].append(event["result"]) - if all(_["complete"] for key, _ in result.items()): - return True - if data is not None: - if _parse_event(data): - return - async for event in self.events({"action": "msg", "type": "runcommands", "responseid": data["responseid"]}): - if _parse_event(event): - break + result = reply_result + console_task.cancel() + tasks.append(reply_task) tasks = [] async with asyncio.TaskGroup() as tg: @@ -1603,13 +1620,14 @@ class Session(object): if all(_["complete"] for key, _ in result.items()): break async def __(command, tg, tasks): + console_task = tg.create_task(asyncio.wait_for(_console(), timeout=timeout)) data = await self._send_command(command, "run_console_command", timeout=timeout) if data.get("type", None) != "runcommands" and data.get("result", "ok").lower() != "ok": raise exceptions.ServerError(data["result"]) elif data.get("type", None) != "runcommands" and data.get("result", "ok").lower() == "ok": expect_response = False - console_task = tg.create_task(asyncio.wait_for(_console(), timeout=timeout)) + if not ignore_output: userid = (await self.user_info())["_id"] for n in nodeids: