From f9fcc7eebe461bb4d769db22e0574964061a92a7 Mon Sep 17 00:00:00 2001 From: blee Date: Fri, 1 Dec 2023 15:06:44 -0500 Subject: [PATCH 1/4] replace deprecated low-level socket usage w/ R/W streams --- app/contacts/contact_tcp.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/app/contacts/contact_tcp.py b/app/contacts/contact_tcp.py index d6ff0b754..27ae4ebaa 100644 --- a/app/contacts/contact_tcp.py +++ b/app/contacts/contact_tcp.py @@ -7,6 +7,7 @@ from app.utility.base_world import BaseWorld from plugins.manx.app.c_session import Session +from plugins.manx.app.c_connection import Connection class Contact(BaseWorld): @@ -60,7 +61,8 @@ async def refresh(self): session = self.sessions[index] try: - session.connection.send(str.encode(' ')) + session.connection.writer.write(str.encode(' ')) + await session.connection.writer.drain() except socket.error: self.log.debug('Error occurred when refreshing session %s. Removing from session pool.', session.id) del self.sessions[index] @@ -73,20 +75,20 @@ async def accept(self, reader, writer): except Exception as e: self.log.debug('Handshake failed: %s' % e) return - connection = writer.get_extra_info('socket') profile['executors'] = [e for e in profile['executors'].split(',') if e] profile['contact'] = 'tcp' agent, _ = await self.services.get('contact_svc').handle_heartbeat(**profile) - new_session = Session(id=self.generate_number(size=6), paw=agent.paw, connection=connection) + new_session = Session(id=self.generate_number(size=6), paw=agent.paw, connection=Connection(reader, writer)) self.sessions.append(new_session) await self.send(new_session.id, agent.paw, timeout=5) async def send(self, session_id: int, cmd: str, timeout: int = 60) -> Tuple[int, str, str, str]: try: conn = next(i.connection for i in self.sessions if i.id == int(session_id)) - conn.send(str.encode(' ')) + conn.writer.write(str.encode(' ')) time.sleep(0.01) - conn.send(str.encode('%s\n' % cmd)) + conn.writer.write(str.encode('%s\n' % cmd)) + await conn.writer.drain() response = await self._attempt_connection(session_id, conn, timeout=timeout) response = json.loads(response) return response['status'], response['pwd'], response['response'], response.get('agent_reported_time', '') @@ -106,7 +108,7 @@ async def _attempt_connection(self, session_id, connection, timeout): time.sleep(0.1) # initial wait for fast operations. while True: try: - part = connection.recv(buffer) + part = await connection.reader.read(buffer) data += part if len(part) < buffer: break From 599419dc91a1784d78be36fdb1647243f904d83c Mon Sep 17 00:00:00 2001 From: blee Date: Fri, 1 Dec 2023 16:11:26 -0500 Subject: [PATCH 2/4] move lower-level r/w functionality to c_connection.py --- app/contacts/contact_tcp.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/app/contacts/contact_tcp.py b/app/contacts/contact_tcp.py index 27ae4ebaa..65cf66a15 100644 --- a/app/contacts/contact_tcp.py +++ b/app/contacts/contact_tcp.py @@ -61,8 +61,7 @@ async def refresh(self): session = self.sessions[index] try: - session.connection.writer.write(str.encode(' ')) - await session.connection.writer.drain() + await session.connection.send(str.encode(' ')) except socket.error: self.log.debug('Error occurred when refreshing session %s. Removing from session pool.', session.id) del self.sessions[index] @@ -85,10 +84,9 @@ async def accept(self, reader, writer): async def send(self, session_id: int, cmd: str, timeout: int = 60) -> Tuple[int, str, str, str]: try: conn = next(i.connection for i in self.sessions if i.id == int(session_id)) - conn.writer.write(str.encode(' ')) + await conn.send(str.encode(' ')) time.sleep(0.01) - conn.writer.write(str.encode('%s\n' % cmd)) - await conn.writer.drain() + await conn.send(str.encode('%s\n' % cmd)) response = await self._attempt_connection(session_id, conn, timeout=timeout) response = json.loads(response) return response['status'], response['pwd'], response['response'], response.get('agent_reported_time', '') @@ -108,7 +106,7 @@ async def _attempt_connection(self, session_id, connection, timeout): time.sleep(0.1) # initial wait for fast operations. while True: try: - part = await connection.reader.read(buffer) + part = await connection.recv(buffer) data += part if len(part) < buffer: break From 60412a7815126a07e78d01dce481a457d7995b49 Mon Sep 17 00:00:00 2001 From: blee Date: Fri, 8 Dec 2023 19:56:15 -0500 Subject: [PATCH 3/4] fix contact tcp tests --- tests/contacts/test_contact_tcp.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/contacts/test_contact_tcp.py b/tests/contacts/test_contact_tcp.py index cb420a64f..37314ac7f 100644 --- a/tests/contacts/test_contact_tcp.py +++ b/tests/contacts/test_contact_tcp.py @@ -9,28 +9,34 @@ class TestTcpSessionHandler: - def test_refresh_with_socket_errors(self, event_loop): + def test_refresh_with_socket_errors(self, event_loop, async_return): handler = TcpSessionHandler(services=None, log=logger) session_with_socket_error = mock.Mock() session_with_socket_error.connection.send.side_effect = socket.error() + standard_session = mock.Mock() + standard_session.connection.send.return_value = async_return(True) + handler.sessions = [ session_with_socket_error, session_with_socket_error, - mock.Mock() + standard_session ] event_loop.run_until_complete(handler.refresh()) assert len(handler.sessions) == 1 assert all(x is not session_with_socket_error for x in handler.sessions) - def test_refresh_without_socket_errors(self, event_loop): + def test_refresh_without_socket_errors(self, event_loop, async_return): + standard_session = mock.Mock() + standard_session.connection.send.return_value = async_return(True) + handler = TcpSessionHandler(services=None, log=logger) handler.sessions = [ - mock.Mock(), - mock.Mock(), - mock.Mock() + standard_session, + standard_session, + standard_session ] event_loop.run_until_complete(handler.refresh()) From 3fe85eba76cde6e23db5d09b7c7fb663918eda51 Mon Sep 17 00:00:00 2001 From: blee Date: Thu, 21 Dec 2023 10:18:20 -0500 Subject: [PATCH 4/4] add additional unit tests for send method --- tests/contacts/test_contact_tcp.py | 56 ++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/contacts/test_contact_tcp.py b/tests/contacts/test_contact_tcp.py index 37314ac7f..c96b407e6 100644 --- a/tests/contacts/test_contact_tcp.py +++ b/tests/contacts/test_contact_tcp.py @@ -1,8 +1,11 @@ +import json import logging import socket from unittest import mock +from tests.conftest import async_return from app.contacts.contact_tcp import TcpSessionHandler +from plugins.manx.app.c_session import Session logger = logging.getLogger(__name__) @@ -41,3 +44,56 @@ def test_refresh_without_socket_errors(self, event_loop, async_return): event_loop.run_until_complete(handler.refresh()) assert len(handler.sessions) == 3 + + async def test_send_with_connection_errors(self, async_return): + test_session_id = 123 + test_paw = 'paw123' + test_cmd = 'whoami' + test_exception = Exception('Exception Raised') + + mock_connection = mock.Mock() + mock_connection.send.return_value = async_return(True) + standard_session = Session(id=test_session_id, paw=test_paw, connection=mock_connection) + + handler = TcpSessionHandler(services=None, log=logger) + handler.sessions = [ + standard_session, + standard_session + ] + + handler._attempt_connection = mock.Mock() + handler._attempt_connection.side_effect = test_exception + response = await handler.send(test_session_id, test_cmd) + expected_response = (1, '~$ ', str(test_exception), '') + + assert len(handler.sessions) == 2 + assert response == expected_response + + async def test_send_without_connection_error(self, async_return): + test_session_id = 123 + test_paw = 'paw123' + test_cmd = 'whoami' + json_response = { + 'status': 0, + 'pwd': '/test', + 'response': '' + } + expected_response = (json_response['status'], json_response['pwd'], json_response['response'], + json_response.get('agent_reported_time', '')) + + mock_connection = mock.Mock() + mock_connection.send.return_value = async_return(True) + standard_session = Session(id=test_session_id, paw=test_paw, connection=mock_connection) + + handler = TcpSessionHandler(services=None, log=logger) + handler.sessions = [ + standard_session, + standard_session + ] + + handler._attempt_connection = mock.Mock() + handler._attempt_connection.return_value = async_return(json.dumps(json_response)) + received_response = await handler.send(test_session_id, test_cmd) + + assert len(handler.sessions) == 2 + assert received_response == expected_response