11import asyncio
22from contextlib import suppress
33from dataclasses import dataclass , field
4- from typing import Any , List , Optional , Set
4+ from typing import Any , Dict , List , Optional
55
66from arq import ArqRedis
77from arq .connections import RedisSettings , create_pool
8- from arq .constants import result_key_prefix
98from arq .jobs import DeserializationError , Job as ArqJob , JobDef , JobStatus
9+ from arq .utils import timestamp_ms
1010from django .utils import timezone
1111
1212from arq_admin import settings
1313from arq_admin .job import JobInfo
1414
15+ ARQ_PREFIX = 'arq:'
16+ PREFIX_PRIORITY = {prefix : i for i , prefix in enumerate (['job' , 'in-progress' , 'result' ])}
17+
1518
1619@dataclass
1720class QueueStats :
@@ -31,6 +34,10 @@ class QueueStats:
3134class Queue :
3235 redis_settings : RedisSettings
3336 name : str
37+ concurrent_redis_access_sem : asyncio .Semaphore = field (
38+ default_factory = lambda : asyncio .Semaphore (settings .ARQ_MAX_CONNECTIONS ),
39+ )
40+ _cached_job_id_to_status_map : Optional [Dict [str , JobStatus ]] = None
3441 _redis : ArqRedis = field (init = False , default = None ) # type: ignore
3542
3643 async def __aenter__ (self ) -> 'Queue' :
@@ -48,12 +55,12 @@ def from_name(cls, name: str) -> 'Queue':
4855 )
4956
5057 async def get_jobs (self , status : Optional [JobStatus ] = None ) -> List [JobInfo ]:
51- job_ids = await self ._get_job_ids ()
58+ job_id_to_status_map = await self ._get_job_id_to_status_map ()
5259
5360 if status :
54- job_ids_tuple = tuple ( job_ids )
55- statuses = await asyncio . gather ( * [ self . _get_job_status ( job_id ) for job_id in job_ids_tuple ])
56- job_ids = { job_id for ( job_id , job_status ) in zip ( job_ids_tuple , statuses ) if job_status == status }
61+ job_ids = { job_id for ( job_id , job_status ) in job_id_to_status_map . items () if job_status == status }
62+ else :
63+ job_ids = set ( job_id_to_status_map . keys ())
5764
5865 jobs : List [JobInfo ] = await asyncio .gather (* [self .get_job_by_id (job_id ) for job_id in job_ids ])
5966
@@ -66,12 +73,14 @@ async def get_stats(self) -> QueueStats:
6673 port = self .redis_settings .port ,
6774 database = self .redis_settings .database ,
6875 )
76+
6977 try :
70- job_ids = await self ._get_job_ids ()
71- statuses = await asyncio .gather (* [self ._get_job_status (job_id ) for job_id in job_ids ])
78+ job_id_to_status_map = await self ._get_job_id_to_status_map ()
7279 except Exception as ex : # noqa: B902
7380 result .error = str (ex )
7481 else :
82+ statuses = job_id_to_status_map .values ()
83+
7584 result .queued_jobs = len ([status for status in statuses if status == JobStatus .queued ])
7685 result .running_jobs = len ([status for status in statuses if status == JobStatus .in_progress ])
7786 result .deferred_jobs = len ([status for status in statuses if status == JobStatus .deferred ])
@@ -88,10 +97,11 @@ async def get_job_by_id(self, job_id: str) -> JobInfo:
8897
8998 unknown_function_msg = "Can't find job"
9099 base_info = None
91- try :
92- base_info = await arq_job .info ()
93- except DeserializationError :
94- unknown_function_msg = "Unknown, can't deserialize"
100+ async with self .concurrent_redis_access_sem :
101+ try :
102+ base_info = await arq_job .info ()
103+ except DeserializationError :
104+ unknown_function_msg = "Unknown, can't deserialize"
95105
96106 if not base_info :
97107 base_info = JobDef (
@@ -104,7 +114,7 @@ async def get_job_by_id(self, job_id: str) -> JobInfo:
104114 )
105115
106116 job_info = JobInfo .from_base (base_info , job_id )
107- job_info .status = await arq_job . status ( )
117+ job_info .status = await self . _get_job_status ( job_id )
108118
109119 return job_info
110120
@@ -122,17 +132,53 @@ async def abort_job(self, job_id: str) -> Optional[bool]:
122132 return None
123133
124134 async def _get_job_status (self , job_id : str ) -> JobStatus :
135+ if self ._cached_job_id_to_status_map is not None :
136+ return self ._cached_job_id_to_status_map .get (job_id , JobStatus .not_found )
137+
125138 arq_job = ArqJob (
126139 job_id = job_id ,
127140 redis = self ._redis ,
128141 _queue_name = self .name ,
129142 _deserializer = settings .ARQ_DESERIALIZER_BY_QUEUE .get (self .name ),
130143 )
131- return await arq_job .status ()
132-
133- async def _get_job_ids (self ) -> Set [str ]:
134- raw_job_ids = set (await self ._redis .zrangebyscore (self .name , '-inf' , 'inf' ))
135- result_keys = await self ._redis .keys (f'{ result_key_prefix } *' )
136- raw_job_ids |= {key [len (result_key_prefix ):] for key in result_keys }
144+ async with self .concurrent_redis_access_sem :
145+ return await arq_job .status ()
146+
147+ async def _get_job_id_to_status_map (self ) -> Dict [str , JobStatus ]:
148+ if self ._cached_job_id_to_status_map is not None :
149+ return self ._cached_job_id_to_status_map
150+
151+ async with self ._redis .pipeline (transaction = True ) as pipe :
152+ await pipe .keys (f'{ ARQ_PREFIX } *:*' )
153+ await pipe .zrange (self .name , withscores = True , start = 0 , end = - 1 )
154+ all_arq_keys , job_ids_with_scores = await pipe .execute ()
155+
156+ # iter over lists of type [job_id, prefix];
157+ # can't use dict here because we can have multiple keys for one job and need to use the more specific one
158+ job_ids_with_prefixes = (
159+ key .decode ('utf-8' )[len (ARQ_PREFIX ):].split (':' )[::- 1 ] for key in all_arq_keys
160+ )
137161
138- return {job_id .decode ('utf-8' ) if isinstance (job_id , bytes ) else job_id for job_id in raw_job_ids }
162+ job_ids_to_scores = {key [0 ].decode ('utf-8' ): key [1 ] for key in job_ids_with_scores }
163+ job_ids_to_prefixes = dict (sorted (
164+ # not only ensure that we don't get key error but also filter out stuff that's not a client job
165+ ([job_id , prefix ] for job_id , prefix in job_ids_with_prefixes if prefix in PREFIX_PRIORITY ),
166+ # make sure that more specific indices go after less specific ones
167+ key = lambda job_id_with_prefix : PREFIX_PRIORITY [job_id_with_prefix [- 1 ]],
168+ ))
169+
170+ self ._cached_job_id_to_status_map = {
171+ job_id : self ._get_job_status_from_raw_data (prefix , job_ids_to_scores .get (job_id ))
172+ for job_id , prefix in job_ids_to_prefixes .items ()
173+ }
174+
175+ return self ._cached_job_id_to_status_map
176+
177+ def _get_job_status_from_raw_data (self , prefix : str , zscore : Optional [int ]) -> JobStatus : # noqa: CFQ004
178+ if prefix == 'result' :
179+ return JobStatus .complete
180+ if prefix == 'in-progress' and zscore :
181+ return JobStatus .in_progress
182+ if zscore :
183+ return JobStatus .deferred if zscore > timestamp_ms () else JobStatus .queued
184+ return JobStatus .not_found
0 commit comments