Skip to content

Commit e757685

Browse files
Avoid per-job status requests on list views and limit concurrency (#9)
Co-authored-by: Slava Skvortsov <29122694+SlavaSkvortsov@users.noreply.github.com>
1 parent c3249db commit e757685

4 files changed

Lines changed: 87 additions & 22 deletions

File tree

arq_admin/queue.py

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
import asyncio
22
from contextlib import suppress
33
from dataclasses import dataclass, field
4-
from typing import Any, List, Optional, Set
4+
from typing import Any, Dict, List, Optional
55

66
from arq import ArqRedis
77
from arq.connections import RedisSettings, create_pool
8-
from arq.constants import result_key_prefix
98
from arq.jobs import DeserializationError, Job as ArqJob, JobDef, JobStatus
9+
from arq.utils import timestamp_ms
1010
from django.utils import timezone
1111

1212
from arq_admin import settings
1313
from arq_admin.job import JobInfo
1414

15+
ARQ_PREFIX = 'arq:'
16+
PREFIX_PRIORITY = {prefix: i for i, prefix in enumerate(['job', 'in-progress', 'result'])}
17+
1518

1619
@dataclass
1720
class QueueStats:
@@ -31,6 +34,10 @@ class QueueStats:
3134
class Queue:
3235
redis_settings: RedisSettings
3336
name: str
37+
concurrent_redis_access_sem: asyncio.Semaphore = field(
38+
default_factory=lambda: asyncio.Semaphore(settings.ARQ_MAX_CONNECTIONS),
39+
)
40+
_cached_job_id_to_status_map: Optional[Dict[str, JobStatus]] = None
3441
_redis: ArqRedis = field(init=False, default=None) # type: ignore
3542

3643
async def __aenter__(self) -> 'Queue':
@@ -48,12 +55,12 @@ def from_name(cls, name: str) -> 'Queue':
4855
)
4956

5057
async def get_jobs(self, status: Optional[JobStatus] = None) -> List[JobInfo]:
51-
job_ids = await self._get_job_ids()
58+
job_id_to_status_map = await self._get_job_id_to_status_map()
5259

5360
if status:
54-
job_ids_tuple = tuple(job_ids)
55-
statuses = await asyncio.gather(*[self._get_job_status(job_id) for job_id in job_ids_tuple])
56-
job_ids = {job_id for (job_id, job_status) in zip(job_ids_tuple, statuses) if job_status == status}
61+
job_ids = {job_id for (job_id, job_status) in job_id_to_status_map.items() if job_status == status}
62+
else:
63+
job_ids = set(job_id_to_status_map.keys())
5764

5865
jobs: List[JobInfo] = await asyncio.gather(*[self.get_job_by_id(job_id) for job_id in job_ids])
5966

@@ -66,12 +73,14 @@ async def get_stats(self) -> QueueStats:
6673
port=self.redis_settings.port,
6774
database=self.redis_settings.database,
6875
)
76+
6977
try:
70-
job_ids = await self._get_job_ids()
71-
statuses = await asyncio.gather(*[self._get_job_status(job_id) for job_id in job_ids])
78+
job_id_to_status_map = await self._get_job_id_to_status_map()
7279
except Exception as ex: # noqa: B902
7380
result.error = str(ex)
7481
else:
82+
statuses = job_id_to_status_map.values()
83+
7584
result.queued_jobs = len([status for status in statuses if status == JobStatus.queued])
7685
result.running_jobs = len([status for status in statuses if status == JobStatus.in_progress])
7786
result.deferred_jobs = len([status for status in statuses if status == JobStatus.deferred])
@@ -88,10 +97,11 @@ async def get_job_by_id(self, job_id: str) -> JobInfo:
8897

8998
unknown_function_msg = "Can't find job"
9099
base_info = None
91-
try:
92-
base_info = await arq_job.info()
93-
except DeserializationError:
94-
unknown_function_msg = "Unknown, can't deserialize"
100+
async with self.concurrent_redis_access_sem:
101+
try:
102+
base_info = await arq_job.info()
103+
except DeserializationError:
104+
unknown_function_msg = "Unknown, can't deserialize"
95105

96106
if not base_info:
97107
base_info = JobDef(
@@ -104,7 +114,7 @@ async def get_job_by_id(self, job_id: str) -> JobInfo:
104114
)
105115

106116
job_info = JobInfo.from_base(base_info, job_id)
107-
job_info.status = await arq_job.status()
117+
job_info.status = await self._get_job_status(job_id)
108118

109119
return job_info
110120

@@ -122,17 +132,53 @@ async def abort_job(self, job_id: str) -> Optional[bool]:
122132
return None
123133

124134
async def _get_job_status(self, job_id: str) -> JobStatus:
135+
if self._cached_job_id_to_status_map is not None:
136+
return self._cached_job_id_to_status_map.get(job_id, JobStatus.not_found)
137+
125138
arq_job = ArqJob(
126139
job_id=job_id,
127140
redis=self._redis,
128141
_queue_name=self.name,
129142
_deserializer=settings.ARQ_DESERIALIZER_BY_QUEUE.get(self.name),
130143
)
131-
return await arq_job.status()
132-
133-
async def _get_job_ids(self) -> Set[str]:
134-
raw_job_ids = set(await self._redis.zrangebyscore(self.name, '-inf', 'inf'))
135-
result_keys = await self._redis.keys(f'{result_key_prefix}*')
136-
raw_job_ids |= {key[len(result_key_prefix):] for key in result_keys}
144+
async with self.concurrent_redis_access_sem:
145+
return await arq_job.status()
146+
147+
async def _get_job_id_to_status_map(self) -> Dict[str, JobStatus]:
    """Return a mapping of job id to status, built with one Redis round trip.

    The mapping is computed on first use and stored in
    ``self._cached_job_id_to_status_map``; later calls return the stored
    mapping without touching Redis again.

    Returns:
        A dict keyed by decoded job id whose values come from
        ``self._get_job_status_from_raw_data``.
    """
    if self._cached_job_id_to_status_map is not None:
        return self._cached_job_id_to_status_map

    # Fetch the key listing and the queue's sorted set in a single transaction
    # so both results describe the same moment.
    async with self._redis.pipeline(transaction=True) as pipe:
        await pipe.keys(f'{ARQ_PREFIX}*:*')
        await pipe.zrange(self.name, withscores=True, start=0, end=-1)
        all_arq_keys, job_ids_with_scores = await pipe.execute()

    job_ids_to_scores = {key[0].decode('utf-8'): key[1] for key in job_ids_with_scores}

    # Collect (job_id, prefix) pairs. A list is used rather than a dict because
    # one job can own several keys (job / in-progress / result) and the most
    # specific one has to win below.
    job_ids_with_prefixes = []
    for raw_key in all_arq_keys:
        # Split on the FIRST colon only: everything after the prefix is the job
        # id, which may itself contain colons. Splitting on every colon would
        # produce more than two parts and break the pair unpacking.
        prefix, _, job_id = raw_key.decode('utf-8')[len(ARQ_PREFIX):].partition(':')
        # Skip keys that do not describe a client job; this also guarantees the
        # PREFIX_PRIORITY lookup in the sort key cannot raise KeyError.
        if prefix in PREFIX_PRIORITY:
            job_ids_with_prefixes.append((job_id, prefix))

    # Stable sort by priority so that, when the pairs are folded into a dict,
    # a more specific prefix overwrites a less specific one for the same job.
    job_ids_with_prefixes.sort(key=lambda job_id_with_prefix: PREFIX_PRIORITY[job_id_with_prefix[1]])
    job_ids_to_prefixes = dict(job_ids_with_prefixes)

    self._cached_job_id_to_status_map = {
        job_id: self._get_job_status_from_raw_data(prefix, job_ids_to_scores.get(job_id))
        for job_id, prefix in job_ids_to_prefixes.items()
    }

    return self._cached_job_id_to_status_map
176+
177+
def _get_job_status_from_raw_data(self, prefix: str, zscore: Optional[int]) -> JobStatus:
    """Derive a job's status from its key prefix and its score in the queue.

    Args:
        prefix: The key prefix found for the job ('job', 'in-progress' or 'result').
        zscore: The job's score in the queue's sorted set, or ``None`` when
            the job has no entry there.

    Returns:
        ``complete`` for a 'result' prefix regardless of score; otherwise
        ``not_found`` when the score is falsy; otherwise ``in_progress`` for an
        'in-progress' prefix; otherwise ``deferred`` when the score is greater
        than ``timestamp_ms()`` and ``queued`` when it is not.
    """
    if prefix == 'result':
        status = JobStatus.complete
    elif not zscore:
        # Without a score, even an 'in-progress' key is reported as not found.
        status = JobStatus.not_found
    elif prefix == 'in-progress':
        status = JobStatus.in_progress
    elif zscore > timestamp_ms():
        status = JobStatus.deferred
    else:
        status = JobStatus.queued
    return status

arq_admin/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,5 @@
2626
ARQ_DESERIALIZER_BY_QUEUE = defaultdict(lambda: ARQ_DESERIALIZER)
2727

2828
ARQ_JOB_ABORT_TIMEOUT = getattr(settings, 'ARQ_JOB_ABORT_TIMEOUT', 5)
29+
30+
ARQ_MAX_CONNECTIONS = getattr(settings, 'ARQ_MAX_CONNECTIONS', 100)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
arq==0.24.0
1+
arq==0.25.0
22
Django==4.1.2

tests/test_queue.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
import pytest_asyncio
7+
from arq import ArqRedis
78
from arq.constants import default_queue_name
89
from arq.jobs import DeserializationError, Job, JobStatus
910
from django.conf import settings
@@ -49,7 +50,23 @@ async def test_stats(queue: Queue) -> None:
4950

5051

5152
@pytest.mark.asyncio()
52-
@patch.object(Queue, '_get_job_ids')
53+
@pytest.mark.usefixtures('all_jobs')
54+
async def test_stats_with_running_job_wo_zscore(redis: ArqRedis, queue: Queue) -> None:
55+
await redis.zrem(queue.name, 'running_task')
56+
57+
assert await queue.get_stats() == QueueStats(
58+
name=default_queue_name,
59+
host=settings.REDIS_SETTINGS.host,
60+
port=settings.REDIS_SETTINGS.port,
61+
database=settings.REDIS_SETTINGS.database,
62+
queued_jobs=1,
63+
running_jobs=0,
64+
deferred_jobs=1,
65+
)
66+
67+
68+
@pytest.mark.asyncio()
69+
@patch.object(Queue, '_get_job_id_to_status_map')
5370
async def test_stats_with_error(mocked_get_job_ids: AsyncMock, queue: Queue) -> None:
5471
error_text = 'test error'
5572
mocked_get_job_ids.side_effect = Exception(error_text)

0 commit comments

Comments
 (0)