44import logging
55import threading
66import time
7- from concurrent .futures import ThreadPoolExecutor
87from http .server import ThreadingHTTPServer , BaseHTTPRequestHandler
98from typing import Callable
109
@@ -24,7 +23,7 @@ def __init__(self, name: str, endpoint: str = "http://localhost:9000", max_worke
2423 self ._handlers : dict [str , Callable ] = {}
2524 self ._worker_id : str | None = None
2625 self ._client : MagiCClient | None = None
27- self ._executor = ThreadPoolExecutor ( max_workers = max_workers )
26+ self ._semaphore = threading . Semaphore ( max_workers )
2827
2928 def capability (self , name : str , description : str = "" , est_cost : float = 0.0 ):
3029 """Decorator to register a function as a worker capability."""
@@ -80,8 +79,7 @@ def handle_task(self, task_type: str, input_data: dict) -> dict:
8079
8180 def serve (self , host : str = "0.0.0.0" , port : int = 9000 ):
8281 """Start the worker HTTP server with concurrent task handling."""
83- worker = self
84- executor = self ._executor
82+ worker_ref = self
8583
8684 class Handler (BaseHTTPRequestHandler ):
8785 def do_POST (self ):
@@ -91,7 +89,7 @@ def do_POST(self):
9189 return
9290
9391 length = int (content_length )
94- if length > 10 * 1024 * 1024 : # 10MB limit
92+ if length > 10 * 1024 * 1024 :
9593 self .send_error (413 , "Request too large" )
9694 return
9795
@@ -109,21 +107,29 @@ def do_POST(self):
109107 task_type = payload .get ("task_type" , "" )
110108 logger .info ("Task %s received (type: %s)" , task_id , task_type )
111109
112- # Process task in thread pool
113- future = executor .submit (worker .handle_task , task_type , payload .get ("input" , {}))
114- try :
115- result = future .result (timeout = 300 ) # 5 min timeout
116- response = {
117- "type" : "task.complete" ,
118- "payload" : {"task_id" : task_id , "output" : result , "cost" : 0.0 },
119- }
120- logger .info ("Task %s completed" , task_id )
121- except Exception as e :
110+ acquired = worker_ref ._semaphore .acquire (timeout = 5 )
111+ if not acquired :
122112 response = {
123113 "type" : "task.fail" ,
124- "payload" : {"task_id" : task_id , "error" : {"code" : "handler_error " , "message" : str ( e ) }},
114+ "payload" : {"task_id" : task_id , "error" : {"code" : "overloaded " , "message" : "worker at max capacity" }},
125115 }
126- logger .error ("Task %s failed: %s" , task_id , e )
116+ logger .warning ("Task %s rejected: at max capacity" , task_id )
117+ else :
118+ try :
119+ result = worker_ref .handle_task (task_type , payload .get ("input" , {}))
120+ response = {
121+ "type" : "task.complete" ,
122+ "payload" : {"task_id" : task_id , "output" : result , "cost" : 0.0 },
123+ }
124+ logger .info ("Task %s completed" , task_id )
125+ except Exception as e :
126+ response = {
127+ "type" : "task.fail" ,
128+ "payload" : {"task_id" : task_id , "error" : {"code" : "handler_error" , "message" : str (e )}},
129+ }
130+ logger .error ("Task %s failed: %s" , task_id , e )
131+ finally :
132+ worker_ref ._semaphore .release ()
127133
128134 self .send_response (200 )
129135 self .send_header ("Content-Type" , "application/json" )
@@ -137,7 +143,6 @@ def log_message(self, format, *args):
137143
138144 self ._start_heartbeat ()
139145
140- # Parse port from endpoint URL
141146 parsed = self .endpoint .split (":" )
142147 if len (parsed ) > 2 :
143148 port = int (parsed [- 1 ].split ("/" )[0 ])
@@ -149,4 +154,3 @@ def log_message(self, format, *args):
149154 except KeyboardInterrupt :
150155 logger .info ("Shutting down %s" , self .name )
151156 server .shutdown ()
152- self ._executor .shutdown (wait = True )
0 commit comments