Compare commits


1 Commit

Author | SHA1 | Message | Date
Josiah Baldwin | ab2a4c40bc | Fixed auto-reconnect for proxy and created tests for auto-reconnect | 2024-12-10 13:05:22 -08:00
5 changed files with 77 additions and 26 deletions
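
The commit moves the proxy handshake into each websocket connection attempt, so auto_reconnect keeps working when a Session is routed through a proxy. A minimal usage sketch, with placeholder server URL, credentials, and proxy URL (real values come from your MeshCentral deployment):

import asyncio
import meshctrl

async def main():
    # Placeholders: substitute a real control URL, credentials, and proxy.
    async with meshctrl.Session("wss://mesh.example.com/control.ashx",
                                user="admin", password="hunter2",
                                ignore_ssl=True, auto_reconnect=True,
                                proxy="http://proxy.example.com:3128") as session:
        # With auto_reconnect=True the session re-dials (through the proxy)
        # after the server or proxy drops the connection.
        await session.ping(timeout=10)

asyncio.run(main())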

View File

@@ -140,13 +140,9 @@ class Session(object):
         token = self._token if self._token else b""
         headers['x-meshauth'] = (base64.b64encode(self._user.encode()) + b',' + base64.b64encode(self._password.encode()) + token).decode()
-        if self._proxy:
-            proxy = Proxy.from_url(self._proxy)
-            parsed = urllib.parse.urlparse(self.url)
-            options["sock"] = await proxy.connect(dest_host=parsed.hostname, dest_port=parsed.port)
         options["additional_headers"] = headers
-        async for websocket in websockets.asyncio.client.connect(self.url, process_exception=util._process_websocket_exception, **options):
+        async for websocket in util.proxy_connect(self.url, proxy_url=self._proxy, process_exception=util._process_websocket_exception, **options):
             self.alive = True
             self._socket_open.set()
             try:
@@ -154,10 +150,10 @@ class Session(object):
                     tg.create_task(self._listen_data_task(websocket))
                     tg.create_task(self._send_data_task(websocket))
             except* websockets.ConnectionClosed as e:
                 self._socket_open.clear()
                 if not self.auto_reconnect:
                     raise
             except* Exception as eg:
                 self.alive = False
                 self._socket_open.clear()
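
The reconnect behaviour in Session comes from iterating over the websockets connect object: each pass of the async for yields a fresh connection, and the process_exception callback decides which failures are retried. A standalone sketch of that pattern (the URL is a placeholder and retry_only_network_errors is an illustrative policy, not this library's):

import websockets
import websockets.asyncio.client

def retry_only_network_errors(exc):
    # Returning None tells connect() to retry; returning the exception aborts.
    if isinstance(exc, OSError):
        return None
    return exc

async def run(url="wss://example.com/control.ashx"):
    async for websocket in websockets.asyncio.client.connect(
            url, process_exception=retry_only_network_errors):
        try:
            async for message in websocket:
                print(message)
        except websockets.ConnectionClosed:
            continue  # let the outer async for open a new connection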

View File

@@ -70,12 +70,7 @@ class Tunnel(object):
         self.url = self._session.url.replace('/control.ashx', '/meshrelay.ashx?browser=1&p=' + str(self._protocol) + '&nodeid=' + self.node_id + '&id=' + self._tunnel_id + '&auth=' + authcookie["cookie"])
-        if self._session._proxy:
-            proxy = Proxy.from_url(self._session._proxy)
-            parsed = urllib.parse.urlparse(self.url)
-            options["sock"] = await proxy.connect(dest_host=parsed.hostname, dest_port=parsed.port)
-        async for websocket in websockets.asyncio.client.connect(self.url, process_exception=util._process_websocket_exception, **options):
+        async for websocket in util.proxy_connect(self.url, proxy_url=self._session._proxy, process_exception=util._process_websocket_exception, **options):
             self.alive = True
             self._socket_open.set()
             try:

View File

@@ -9,6 +9,9 @@ import re
 import websockets
 import ssl
 import functools
+import urllib
+import python_socks
+from python_socks.async_.asyncio import Proxy
 from . import exceptions

 def _encode_cookie(o, key):
@@ -139,7 +142,11 @@ def compare_dict(dict1, dict2):
 def _check_socket(f):
     @functools.wraps(f)
     async def wrapper(self, *args, **kwargs):
-        await asyncio.wait_for(self.initialized.wait(), 10)
+        try:
+            async with asyncio.TaskGroup() as tg:
+                tg.create_task(asyncio.wait_for(self.initialized.wait(), 10))
+                tg.create_task(asyncio.wait_for(self._socket_open.wait(), 10))
+        finally:
             if not self.alive and self._main_loop_error is not None:
                 raise self._main_loop_error
             elif not self.alive:
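
The reworked decorator now waits for both the "initialized" and "socket open" events before running the wrapped coroutine, with a 10 second cap on each wait. A minimal sketch of that waiting pattern with plain asyncio primitives (the event names here are illustrative):

import asyncio

async def wait_until_ready(initialized: asyncio.Event, socket_open: asyncio.Event, timeout: float = 10):
    # Both waits run concurrently; if either times out, the TaskGroup
    # surfaces it as an ExceptionGroup containing TimeoutError.
    async with asyncio.TaskGroup() as tg:
        tg.create_task(asyncio.wait_for(initialized.wait(), timeout))
        tg.create_task(asyncio.wait_for(socket_open.wait(), timeout))
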
@@ -149,10 +156,22 @@ def _check_socket(f):
 def _process_websocket_exception(exc):
     tmp = websockets.asyncio.client.process_exception(exc)
-    # SSLVerification error is a subclass of OSError, but doesn't make sense no retry, so we need to handle it separately.
+    # SSLVerification error is a subclass of OSError, but doesn't make sense to retry, so we need to handle it separately.
     if isinstance(exc, (ssl.SSLCertVerificationError, TimeoutError)):
         return exc
+    if isinstance(exc, python_socks._errors.ProxyError):
+        return None
     return tmp

-class Proxy(object):
-    pass
+class proxy_connect(websockets.asyncio.client.connect):
+    def __init__(self, *args, proxy_url=None, **kwargs):
+        self.proxy = None
+        if proxy_url is not None:
+            self.proxy = Proxy.from_url(proxy_url)
+        super().__init__(*args, **kwargs)
+
+    async def create_connection(self, *args, **kwargs):
+        if self.proxy is not None:
+            parsed = urllib.parse.urlparse(self.uri)
+            self.connection_kwargs["sock"] = await self.proxy.connect(dest_host=parsed.hostname, dest_port=parsed.port)
+        return await super().create_connection(*args, **kwargs)
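
The new util.proxy_connect subclasses websockets' connect class: when a proxy URL is supplied, python-socks opens the tunnel and the resulting socket is handed to websockets through the sock connection argument, once per connection attempt, which is what lets reconnection work through the proxy. The same idea without the subclass looks roughly like this (URLs are placeholders, and the fallback to port 443 is an assumption):

import urllib.parse
import websockets.asyncio.client
from python_socks.async_.asyncio import Proxy

async def connect_via_proxy(ws_url="wss://mesh.example.com/control.ashx",
                            proxy_url="http://proxy.example.com:3128"):
    proxy = Proxy.from_url(proxy_url)
    parsed = urllib.parse.urlparse(ws_url)
    # python-socks performs the proxy handshake and returns a connected socket...
    sock = await proxy.connect(dest_host=parsed.hostname, dest_port=parsed.port or 443)
    # ...which websockets then uses instead of dialling the host itself.
    return await websockets.asyncio.client.connect(ws_url, sock=sock)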

View File

@@ -65,7 +65,12 @@ class TestEnvironment(object):
         # Destroy the env in case it wasn't killed correctly last time.
         subprocess.check_call(["docker", "compose", "down"], stdout=subprocess.DEVNULL, cwd=thisdir)
         self._subp = _docker_process = subprocess.Popen(["docker", "compose", "up", "--build", "--force-recreate", "--no-deps"], stdout=subprocess.DEVNULL, cwd=thisdir)
-        timeout = 30
+        if not self._wait_for_meshcentral():
+            self.__exit__(None, None, None)
+            raise Exception("Failed to create docker instance")
+        return self
+
+    def _wait_for_meshcentral(self, timeout=30):
         start = time.time()
         while time.time() - start < timeout:
             try:
@@ -82,10 +87,8 @@ class TestEnvironment(object):
                 pass
             time.sleep(1)
         else:
-            self.__exit__(None, None, None)
-            raise Exception("Failed to create docker instance")
-        return self
+            return False
+        return True

     def __exit__(self, exc_t, exc_v, exc_tb):
         pass
@@ -93,6 +96,15 @@ class TestEnvironment(object):
     def create_agent(self, meshid):
         return Agent(meshid, self.mcurl, self.clienturl, self.dockerurl)

+    # Restart our docker instances, to test reconnect code.
+    def restart_mesh(self):
+        subprocess.check_call(["docker", "container", "restart", "meshctrl-meshcentral"], stdout=subprocess.DEVNULL, cwd=thisdir)
+        assert self._wait_for_meshcentral(), "Failed to restart docker instance"
+
+    def restart_proxy(self):
+        subprocess.check_call(["docker", "container", "restart", "meshctrl-squid"], stdout=subprocess.DEVNULL, cwd=thisdir)
+
 def _kill_docker_process():
     if _docker_process is not None:
         _docker_process.kill()
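
The extracted _wait_for_meshcentral helper is a poll-until-ready loop; the probe inside the try block is not visible in these hunks, so the sketch below fills it with an assumed HTTP request against a placeholder URL purely to show the while/else shape the diff relies on:

import time
import urllib.request

def wait_for_service(url="https://localhost:8086", timeout=30):
    start = time.time()
    while time.time() - start < timeout:
        try:
            urllib.request.urlopen(url, timeout=5)  # assumed probe; any response means "up"
            break
        except OSError:
            pass
        time.sleep(1)
    else:
        return False  # while/else: loop ran out of time without a successful probe
    return True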

View File

@@ -31,6 +31,35 @@ async def test_admin(env):
     assert len(admin_users) == len(env.users.keys()), "Admin cannot see correct number of users"
     assert len(admin_sessions) == 2, "Admin cannot see correct number of user sessions"

+async def test_auto_reconnect(env):
+    async with meshctrl.Session(env.mcurl, user="admin", password=env.users["admin"], ignore_ssl=True, auto_reconnect=True) as admin_session:
+        env.restart_mesh()
+        await asyncio.sleep(10)
+        await admin_session.ping(timeout=10)
+
+    # As above, but with proxy
+    async with meshctrl.Session("wss://" + env.dockerurl, user="admin", password=env.users["admin"], ignore_ssl=True, auto_reconnect=True, proxy=env.proxyurl) as admin_session:
+        env.restart_mesh()
+        for i in range(3):
+            try:
+                await admin_session.ping(timeout=10)
+            except:
+                continue
+            break
+        else:
+            raise Exception("Failed to reconnect")
+
+        env.restart_proxy()
+        for i in range(3):
+            try:
+                await admin_session.ping(timeout=10)
+            except* Exception as e:
+                pass
+            else:
+                break
+        else:
+            raise Exception("Failed to reconnect")
+
 async def test_users(env):
     try:
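
The proxy half of test_auto_reconnect leans on Python's for/else retry idiom: each attempt pings the session, a failure retries, a success breaks out, and only a fully exhausted loop reaches the else and raises. A stripped-down sketch of that shape (retry count and sleep are arbitrary here, not taken from the test):

import asyncio

async def ping_with_retries(session, attempts=3):
    for _ in range(attempts):
        try:
            await session.ping(timeout=10)
        except Exception:
            await asyncio.sleep(1)  # give the server/proxy a moment to come back
            continue
        break
    else:
        raise Exception("Failed to reconnect")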