From fb8f887259b4c18a9341e1c6d73c2cca753dd974 Mon Sep 17 00:00:00 2001 From: laisee Date: Fri, 6 Jun 2025 16:58:50 +0800 Subject: [PATCH] docs: add license badge --- .pre-commit-config.yaml | 10 ++ README.md | 16 ++- client.py | 302 ++++++---------------------------------- messages.py | 27 ++-- pyproject.toml | 23 +++ tests/test_client.py | 13 +- utils.py | 48 +++++-- 7 files changed, 157 insertions(+), 282 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..3db033d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.2.2 + hooks: + - id: ruff + - repo: https://github.com/PyCQA/bandit + rev: 1.7.6 + hooks: + - id: bandit + args: ["-ll"] diff --git a/README.md b/README.md index 3b2e926..d00bc48 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ [![Build](https://github.com/laisee/client-python-fix/actions/workflows/python-package.yml/badge.svg)](https://github.com/laisee/client-python-fix/actions/workflows/python-package.yml) [![Ruff](https://github.com/laisee/client-python-fix/actions/workflows/rufflint.yml/badge.svg)](https://github.com/laisee/client-python-fix/actions/workflows/rufflint.yml) [![Security Check](https://github.com/laisee/client-python-fix/actions/workflows/security-check.yml/badge.svg)](https://github.com/laisee/client-python-fix/actions/workflows/security-check.yml) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) Python 3.11 @@ -29,7 +30,8 @@ The client is capable of: - **FIX Message Handling:** Constructs and sends FIX messages to the server. - **Secure Communication:** Establishes a secure WebSocket connection to send and receive messages. - **Message Validation:** Validates the presence of necessary fields in messages. -- **Environment Configuration:** Uses environment variables loaded from a .env file for configuration. +- **Environment Configuration:** Uses environment variables loaded from a .env file for configuration and fails fast when required values are missing. +- **Async Operation:** Uses `websockets` with `asyncio` for non-blocking communication and heartbeat handling. ## Prerequisites @@ -61,7 +63,10 @@ The client is capable of: 2. **Install required python libraries:** ```sh - pip install -r requirements.txt + pip install -r requirements.txt + # or install in editable mode using pyproject + pip install -e . + ``` 3. **Generate API Keys** This is done at Power Trade UI under URL 'https://app.power.trade/api-keys' @@ -97,3 +102,10 @@ The client is capable of: Review client actions as it executes logon to server, adds a new order, cancels the order while awaiting response(s). A sleep action allows time to review the new order on system via API or UI before it's cancelled. + +## Development + +Run pre-commit locally to lint and scan before committing: +```sh +pre-commit run --all-files +``` diff --git a/client.py b/client.py index 684f179..f9e9482 100644 --- a/client.py +++ b/client.py @@ -1,24 +1,20 @@ import asyncio import logging import os -import socket import ssl import sys -import threading -import time +from websockets import connect from dotenv import load_dotenv from messages import ( - checkMsg, checkLogonMsg, getMsgCancel, getMsgHeartbeat, getMsgLogon, getMsgNewOrder, - translateFix, ) -from utils import get_attr, get_log_filename +from utils import get_log_filename # Common settings SEPARATOR = "\x01" @@ -47,277 +43,61 @@ # Add file handler to the logger logger.addHandler(file_handler) -# first seqnum taken by LOGON message, this var is incremented for new orders, heartbeat +# first seqnum taken by LOGON message, incremented for new orders and heartbeats seqnum = 2 - -# Event object to signal the heartbeat thread to stop -stop_event = threading.Event() - - -def send_heartbeat(apikey, conn): +async def send_heartbeat(ws, apikey: str) -> None: + """Periodically send FIX heartbeat messages over the websocket.""" global seqnum - seqnum += 1 - try: + init_sleep = int(os.getenv("INIT_SLEEP", 60)) + await asyncio.sleep(init_sleep) + heartbeat_sleep = int(os.getenv("HEARTBEAT_SLEEP", 90)) + while True: + await asyncio.sleep(heartbeat_sleep) + seqnum += 1 msg = getMsgHeartbeat(apikey, seqnum) - conn.sendall(msg) - logger.info(f"Sending Heartbeat Msg: {msg}") - logger.info(f"Sent heartbeat message with Success @ seqnum {seqnum}") - except Exception as e: - logger.error(f"Failed to send Heartbeat message: error was '{e}'") - - -def heartbeat_thread(apikey, conn, stop_event): - try: - INIT_SLEEP = os.getenv( - "INIT_SLEEP", 60 - ) # SLEEP for X seconds while client is starting up, default to 60 seconds - time.sleep(INIT_SLEEP) - HEARTBEAT_SLEEP = int(os.getenv("HEARTBEAT_SLEEP", 90)) # defaults to 90 secs - while not stop_event.is_set(): - # delay start of thread by 20 secs - send_heartbeat(apikey, conn) - time.sleep( - HEARTBEAT_SLEEP - ) # Send heartbeat every `HEARTBEAT_SLEEP` seconds - except Exception as e: - print(f"Heartbeat thread exception: {e}") + await ws.send(msg) + logger.info("Sent heartbeat message") -async def main(server: str, port: int, apikey: str): +async def main(server: str, port: int, apikey: str) -> None: + """Connect to the FIX endpoint and submit a sample order.""" global seqnum - # - # Trade test values - # N.B. Not designed for PRODUCTION trading - # - RESP_SENDER = "PT-OE" - SYMBOL: str = "ETH-USD" - PRICE: float = 2508.08 #3090.00 + randint(1, 8) - QUANTITY: float = .1 - - # Define server address w/ port - server_addr = f"{server}:{port}" - logger.info(f"server: {server_addr}") - - # Create context for the TLS connection - context = ssl.create_default_context() - - # Wrap the socket with SSL - context.load_verify_locations(cafile=os.getenv("CERTFILE_LOCATION", "cert.crt")) - logger.info("Context created") + SYMBOL = "ETH-USD" + PRICE = 2508.08 + QUANTITY = 0.1 - context.check_hostname = True - context.verify_mode = ssl.CERT_REQUIRED + uri = f"wss://{server}:{port}" + ssl_context = ssl.create_default_context() + ssl_context.load_verify_locations(cafile=os.getenv("CERTFILE_LOCATION", "cert.crt")) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - - # wait up to X secs for receiving responses - logger.info( f"Assigning WAIT for Fix response messages of {os.getenv('MSG_RESPONSE_WAIT', 5)} seconds") - sock.settimeout(int(os.getenv("MSG_RESPONSE_WAIT", 5))) - - print(f"connecting to {server} on port {port} ...") - sock.connect((server, port)) - - conn = context.wrap_socket(sock, server_hostname=server) - - try: - print("Handshaking Fix SSL/TLS connection ...") - conn.do_handshake() - - # Check Fix API connection with Logon message + async with connect(uri, ssl=ssl_context) as ws: msg = getMsgLogon(apikey) - error_msg = "" - try: - print(f"Sending Logon request {msg} to server {server} ...") - logger.debug(f"Sending Logon msg {msg} to server {server} ...") - # send Fix Logon message - conn.sendall(msg) - - print(f"Reading Logon response from server {server} ...") - logger.debug(f"Reading Logon response from server {server} ...") - response = conn.recv(1024) + await ws.send(msg) + response = await ws.recv() + valid, error_msg = checkLogonMsg(response) + if not valid: + logger.error(f"Invalid Logon response: {error_msg}") + return - print(f"Checking Logon response from server {response} ...") - valid, error_msg = checkLogonMsg(response) - if valid: - logger.info("Received valid Logon response") - # Start heartbeat thread - threading.Thread( - target=heartbeat_thread, - args=( - apikey, - conn, - stop_event, - ), - ).start() - else: - logger.error(f"Invalid Logon response: error was '{error_msg}'") - sys.exit(1) + heartbeat_task = asyncio.create_task(send_heartbeat(ws, apikey)) - clOrdID, msg = getMsgNewOrder(SYMBOL, PRICE, QUANTITY, apikey, seqnum) - decoded_msg = msg.decode("utf-8") - print( - "Sending new order [%s] with order details: {%s}" - % (clOrdID, decoded_msg) - ) - logger.debug( - "Sending new order [%s] with order details: {%s}" - % (clOrdID, decoded_msg) - ) - conn.sendall(msg) + clOrdID, order_msg = getMsgNewOrder(SYMBOL, PRICE, QUANTITY, apikey, seqnum) + await ws.send(order_msg) + logger.info("Sent new order") - print("Reading New Order response from server ...") - response = conn.recv(1024) + resp = await ws.recv() + logger.info(f"Order response: {resp}") - logger.debug(f"Received(decoded): {response.decode('utf-8')}") - valid = checkMsg(response, RESP_SENDER, apikey) - ( - print("Received valid New Order response") - if valid - else print(f"Received invalid New Order response -> {response}") - ) + cancel_id = 11111 + seqnum += 1 + _, cancel_msg = getMsgCancel(clOrdID, cancel_id, SYMBOL, apikey, seqnum) + await ws.send(cancel_msg) + logger.info("Sent cancel request") - # - # iterate few times with sleep to allow trading messages from Limit Order to arrive - # - count = 0 - POLL_SLEEP = int( - os.getenv("POLL_SLEEP", 5) - ) # seconds to sleep between iterations - POLL_LIMIT = int(os.getenv("POLL_LIMIT", 10)) # iteration count + await asyncio.sleep(int(os.getenv("FINAL_SLEEP", 20))) + heartbeat_task.cancel() - logger.info( - f"Waiting for New Order [{clOrdID}] confirmation response from server [{count}] ..." - ) - - while count < POLL_LIMIT: - time.sleep(POLL_SLEEP) - try: - logger.info("Waiting for new message ...") - response = conn.recv(1024) - # response = await asyncio.get_event_loop().sock_recv(conn, 1024) - msg_str = response.decode("utf-8").replace(SEPARATOR, VERTLINE) - if msg_str is not None: - logger.info(f"Received(decoded):\n {msg_str}") - msg_list = msg_str.split("8=FIX.4.4") - for i, msg in enumerate(msg_list): - logger.debug( - "Recd msg: Ord '%s' Type [%s] Sts [%s]" - % ( - get_attr(msg_str, "11"), - translateFix("35", get_attr(msg_str, "35")), - translateFix("39", get_attr(msg_str, "39")), - ) - ) - if ( - get_attr(msg, "35") == "8" - and translateFix("39", get_attr(msg, "39")) == "New" - ): - logger.info( - "Exit Wait loop for order confirmation as received order status == 'New'" - ) - count = POLL_LIMIT - break - except Exception as e: - logger.error("Error while waiting for new message -> %s" % e) - count += 1 - - # setup cancel order to remove new order added above - cancelOrderID = 11111 # clOrdID - - print(f"Sleep {POLL_SLEEP*5} secs before starting to Cancel orders") - logger.info("Sleep before starting to Cancel orders") - time.sleep(POLL_SLEEP * 5) - # - # Cancel Order can be done if the New Limit Order above is not filled - # - logger.debug("Building Cancel Order Message for order %s" % cancelOrderID) - seqnum += 1 - now, msg = getMsgCancel(clOrdID, cancelOrderID, SYMBOL, apikey, seqnum) - logger.debug( - "Sending Cancel Order Message %s for order %s with Seqnum {seqnum}" - % (msg, cancelOrderID) - ) - conn.sendall(msg) - - # - # Await response from order cancel message - # - count = 0 - POLL = True - while POLL and count < POLL_LIMIT: - time.sleep(POLL_SLEEP) - logger.debug("Awaiting Cancel order response from server ...") - response = conn.recv(1024) - msg = response.decode("utf-8").replace(SEPARATOR, VERTLINE) - logger.debug( - "Received msg from server with type [%s] status [%s]" - % ( - translateFix("35", get_attr(msg, "35")), - translateFix("39", get_attr(msg, "39")), - ) - ) - - # - # was received message a 'heartbeat' [Msg Type = '0'] - # - if get_attr(msg, "35") == "0": - logger.info("Heartbeat msg received ...") - # - # received message an 'execution report' [Msg Type = '8'] - # - elif ( - get_attr(msg, "35") == "8" - and translateFix("39", get_attr(msg, "39")) == "Cancelled" - ): - logger.info( - "Received Order Cancel response with order status == 'Cancelled'" - ) - POLL = False - # - # Check status of the order i.e. '2' for Filled, '8' for Rejected - elif ( - get_attr(msg, "35") == "9" - and translateFix("39", get_attr(msg, "39")) == "Rejected" - ): - logger.info( - "Received Order Cancel response with order status == 'Rejected'" - ) - POLL = False - else: - logger.debug( - f"Received(decoded): {response.decode('utf-8').replace(SEPARATOR,VERTLINE)}" - ) - logger.debug( - "Recd msg with type [%s] status [%s] for order %s" - % ( - translateFix("35", get_attr(msg, "35")), - translateFix("39", get_attr(msg, "39")), - cancelOrderID, - ) - ) - count += 1 - except socket.timeout: - wait_time = os.getenv("MSG_RESPONSE_WAIT", 5) - logger.info(f"Receive operation timed out after {wait_time} seconds.") - except Exception as e: - logger.error(f"Error while processing send/receive Fix messages: {e}") - - except Exception as e: - logger.error(f"Failed to make Fix connection and send Order message: {e}") - finally: - # - # Allow 'FINAL_SLEEP' seconds to pass so we can check account balance / possition changes / open orders before closing connection which will remove open orders - # - FINAL_SLEEP = int(os.getenv("FINAL_SLEEP", 20)) - logger.info(f"\nWaiting {FINAL_SLEEP} secs to close connection") - stop_event.set() # Signal the heartbeat thread to stop - time.sleep(FINAL_SLEEP) - sock.close() - conn.close() - logger.info( - "\n**************************************************************************\n" - ) if __name__ == "__main__": diff --git a/messages.py b/messages.py index a8a1c5b..ba282d2 100644 --- a/messages.py +++ b/messages.py @@ -9,8 +9,9 @@ VERTLINE = "|" -def checkLogonMsg(msg: bytes): - status = None +def checkLogonMsg(msg: bytes) -> tuple[bool | None, str]: + """Validate a FIX logon response message.""" + status: bool | None = None error_msg = "" assert msg is not None, "error - message cannot be None or empty string" fields = msg.decode("utf-8").replace(SEPARATOR, VERTLINE).split(VERTLINE) @@ -34,7 +35,8 @@ def checkLogonMsg(msg: bytes): return status, error_msg -def checkMsg(msg: bytes, sender: str, target: str): +def checkMsg(msg: bytes, sender: str, target: str) -> bool: + """Check a generic FIX message contains sender and target IDs.""" count = 0 assert msg is not None, "error - message cannot be None or empty string" @@ -54,7 +56,8 @@ def checkMsg(msg: bytes, sender: str, target: str): return True if count == 2 else False -def getMsgHeartbeat(apikey: str, seqnum: int): +def getMsgHeartbeat(apikey: str, seqnum: int) -> bytes: + """Return a FIX heartbeat message.""" msg = sfx.FixMessage() now = int(time.time()) @@ -73,7 +76,8 @@ def getMsgHeartbeat(apikey: str, seqnum: int): return msg.encode() -def getMsgLogon(apikey: str): +def getMsgLogon(apikey: str) -> bytes: + """Return a FIX logon message.""" msg = sfx.FixMessage() now = int(time.time()) @@ -98,7 +102,8 @@ def getMsgLogon(apikey: str): def getMsgNewOrder( symbol: str, price: float, quantity: float, apikey: str, seqnum: int = 2 -): +) -> tuple[int, bytes]: + """Return FIX message bytes for a new single order.""" msg = sfx.FixMessage() now = int(time.time()) @@ -128,7 +133,10 @@ def getMsgNewOrder( return now, msg.encode() -def getMsgRFQ(symbol: str, price: float, quantity: float, apikey: str, seqnum: int = 2): +def getMsgRFQ( + symbol: str, price: float, quantity: float, apikey: str, seqnum: int = 2 +) -> tuple[int, bytes]: + """Return a FIX request-for-quote message.""" msg = sfx.FixMessage() now = int(time.time()) @@ -169,7 +177,8 @@ def getMsgCancel( apikey: str, seqnum: int, side: int = 1, -): +) -> tuple[int, bytes]: + """Return a FIX order cancel request message.""" assert orderID is not None, "Error - orderId must not be empty or None" @@ -196,7 +205,7 @@ def getMsgCancel( return now, msg.encode() -def translateFix(key, value): +def translateFix(key: str, value: str) -> str: trans = value if key == "35": if value == "0": diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..476a9a9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,23 @@ +[project] +name = "client-python-fix" +version = "0.1.0" +description = "Power.Trade FIX client example" +readme = "README.md" +requires-python = ">=3.11" +authors = [ + { name = "Power.Trade" } +] +license = { file = "LICENSE" } + +[project.dependencies] +python-dotenv = "1.0.1" +PyJWT = "2.9.0" +simplefix = "1.0.17" +websockets = "11.0.3" + +[tool.pytest.ini_options] +addopts = "-q" + +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/tests/test_client.py b/tests/test_client.py index 51dfff7..c3efbb1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,4 +1,5 @@ from unittest import TestCase +import os import pytest from messages import ( @@ -9,6 +10,7 @@ getMsgNewOrder, getMsgRFQ, ) +from utils import generateJWT APIKEY = "DUMMMY" MSG_LOGON = b"8=FIX.4.2|9=112|35=A|34=1|49=CLIENT|52=20240729-14:35:00.000|56=SERVER|98=0|108=30|10=072|" @@ -33,7 +35,7 @@ def test_checkLogonMsg(self): def test_getMsgLOGON(self): assert getMsgLogon is not None - with pytest.raises(ValueError, match="Error"): + with pytest.raises(ValueError, match="API_SECRET"): msg = getMsgLogon(APIKEY) assert msg is not None @@ -44,6 +46,14 @@ def test_getMsgNewOrder(self): ) assert msg is not None assert len(msg) > 0 + msg_str = msg[1].decode("utf-8") if isinstance(msg, tuple) else msg.decode("utf-8") + assert "35=D" in msg_str + + def test_generateJWT_requires_env(self): + os.environ.pop("API_SECRET", None) + os.environ.pop("API_URI", None) + with pytest.raises(ValueError): + generateJWT(APIKEY, 0) def test_getMsgRFQ(self): assert getMsgRFQ is not None @@ -61,3 +71,4 @@ def test_getMsgCancel(self): seqnum=SEQNUM, ) assert msg is not None + diff --git a/utils.py b/utils.py index d8d1fb8..dcc0ae0 100644 --- a/utils.py +++ b/utils.py @@ -4,15 +4,32 @@ import jwt -def generateJWT(apikey: str, now): +def generateJWT(apikey: str, now: int) -> str: + """Generate a signed JWT token using environment configuration. + + Parameters + ---------- + apikey: + API key used as token subject. + now: + Current epoch time. + + Returns + ------- + str + Encoded JWT token string. + """ - # # retrieve values from env variables - # - API_SECRET = os.getenv("API_SECRET", "DUMMY") - URI = os.getenv("API_URI", "DUMMY") + API_SECRET = os.getenv("API_SECRET") + URI = os.getenv("API_URI") DURATION = int(os.getenv("JWT_DURATION", 86400000)) + if not API_SECRET: + raise ValueError("API_SECRET environment variable is required") + if not URI: + raise ValueError("API_URI environment variable is required") + payload = { "client": "api", "uri": URI, @@ -30,7 +47,19 @@ def generateJWT(apikey: str, now): ) -def get_log_filename(prefix): +def get_log_filename(prefix: str) -> str: + """Create a timestamped log filename. + + Parameters + ---------- + prefix: + Prefix for the log file name. + + Returns + ------- + str + Filename with UTC timestamp. + """ # Get the current UTC datetime now_utc = datetime.utcnow() # Format the datetime into a string @@ -40,7 +69,7 @@ def get_log_filename(prefix): return filename -def get_attr(fix_message, key): +def get_attr(fix_message: str, key: str) -> str | None: """ Extracts the value for a given key from a FIX message. @@ -51,7 +80,7 @@ def get_attr(fix_message, key): return get_attrs(fix_message).get(key) -def get_attrs(fix_message): +def get_attrs(fix_message: str) -> dict[str, str]: """ Parses a FIX message with '|' as the separator into a dictionary of attributes. @@ -71,7 +100,8 @@ def get_attrs(fix_message): return attributes -def format_epoch_time(epoch_time): +def format_epoch_time(epoch_time: int) -> str: + """Return FIX timestamp string for a given epoch time.""" # Convert epoch time to datetime object dt = datetime.fromtimestamp(epoch_time, UTC)