Skip to content

Commit 381c4ea

Browse files
committed
Add disk cache for bytecode to cubin
Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent d5e85c3 commit 381c4ea

6 files changed

Lines changed: 333 additions & 1 deletion

File tree

changelog.d/disk-cache.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Add bytecode-to-cubin disk cache to avoid recompilation of unchanged kernels.
5+
Controlled by ``CUDA_TILE_CACHE_DIR`` and ``CUDA_TILE_CACHE_SIZE``.

docs/source/debugging.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,15 @@ debugging :class:`TileTypeError`.
3939

4040
Set ``CUDA_TILE_TEMP_DIR`` to configure the directory
4141
for storing temporary files.
42+
43+
Set ``CUDA_TILE_CACHE_DIR`` to configure the directory
44+
for the bytecode-to-cubin disk cache. Compiled cubins
45+
are cached here to avoid recompilation of unchanged
46+
kernels. Set to ``0``, ``off``, ``none``, or an empty
47+
string to disable caching. Defaults to
48+
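``$XDG_CACHE_HOME/cutile-python`` (``%LOCALAPPDATA%\cutile-python`` on Windows), or ``~/.cache/cutile-python`` when that variable is not set.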
49+
50+
Set ``CUDA_TILE_CACHE_SIZE`` to configure the maximum
51+
disk cache size in bytes. Oldest entries are evicted
52+
when the cache exceeds this limit. Defaults to
53+
2 GiB (2147483648 bytes).

src/cuda/tile/_cache.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import hashlib
6+
import logging
7+
import os
8+
import sqlite3
9+
import tempfile
10+
import time
11+
from pathlib import Path
12+
from typing import Optional
13+
14+
logger = logging.getLogger(__name__)
15+
16+
_CREATE_TABLE_SQL = """
17+
CREATE TABLE IF NOT EXISTS cache (
18+
key TEXT PRIMARY KEY,
19+
blob BLOB NOT NULL,
20+
blob_size INTEGER NOT NULL,
21+
atime REAL NOT NULL
22+
)
23+
"""
24+
25+
_CREATE_ATIME_INDEX_SQL = """
26+
CREATE INDEX IF NOT EXISTS idx_cache_atime ON cache(atime)
27+
"""
28+
29+
30+
_CACHE_FILENAME = "cache.db"
31+
32+
33+
def _close(conn):
    """Close *conn* if one was opened, ignoring SQLite errors from the close.

    Accepts ``None`` (or any falsy value) so callers can invoke it
    unconditionally from a ``finally`` clause.
    """
    if not conn:
        return
    try:
        conn.close()
    except sqlite3.Error:
        # Best effort: a failed close must not replace the caller's result.
        return
39+
40+
41+
def _open_db(db_path: str) -> sqlite3.Connection:
    """Open the cache database at *db_path* and ensure its schema exists.

    Creates the ``cache`` table and its ``atime`` index when they are missing.

    Args:
        db_path: Filesystem path of the SQLite database file.

    Returns:
        An open connection using a 5 second busy timeout.

    Raises:
        sqlite3.Error: If the file cannot be opened or a schema statement
            fails. The connection is closed before the error propagates.
    """
    conn = sqlite3.connect(db_path, timeout=5.0)
    try:
        conn.execute(_CREATE_TABLE_SQL)
        conn.execute(_CREATE_ATIME_INDEX_SQL)
    except BaseException:
        # Release the file handle before re-raising: the caller may want to
        # delete and recreate the file, which fails on some platforms while
        # a handle to it is still open.
        _close(conn)
        raise
    return conn
46+
47+
48+
def _connect(cache_dir: str) -> sqlite3.Connection:
    """Return a connection to the cache database inside *cache_dir*.

    The directory is created when missing. If opening the database raises
    ``sqlite3.Error``, the database file is deleted (best effort) and the
    open is attempted once more; an error from that second attempt
    propagates to the caller.

    Raises:
        OSError: If the cache directory cannot be created.
        sqlite3.Error: If the database cannot be opened even after the
            file was removed.
    """
    os.makedirs(cache_dir, exist_ok=True)
    db_file = os.path.join(cache_dir, _CACHE_FILENAME)
    try:
        connection = _open_db(db_file)
    except sqlite3.Error:
        logger.debug("cache db corrupt, recreating %s", db_file,
                     exc_info=True)
        try:
            os.unlink(db_file)
        except OSError:
            # The file may already be gone or not removable; the retry
            # below reports the real problem if there is one.
            pass
        connection = _open_db(db_file)
    return connection
61+
62+
63+
# Byte string hashed ahead of every other field of a cache key.
_CACHE_VERSION = b''


def cache_key(compiler_version: str, sm_arch: str, opt_level: int,
              bytecode: bytes) -> str:
    """Return the SHA-256 hex digest that identifies one compilation.

    The digest covers, in order: ``_CACHE_VERSION``, the compiler version,
    the target architecture, the optimization level and the bytecode.
    Every variable-length field is preceded by its length as a 4-byte
    big-endian unsigned integer, so field boundaries are unambiguous.

    Raises:
        OverflowError: If ``opt_level`` or a field length does not fit in
            an unsigned 32-bit integer.
    """

    def _u32(value: int) -> bytes:
        return int.to_bytes(value, 4, byteorder='big', signed=False)

    digest = hashlib.sha256(_CACHE_VERSION)
    for text_field in (compiler_version.encode(), sm_arch.encode()):
        digest.update(_u32(len(text_field)))
        digest.update(text_field)
    digest.update(_u32(opt_level))
    digest.update(_u32(len(bytecode)))
    digest.update(bytecode)
    return digest.hexdigest()
85+
86+
87+
def cache_lookup(cache_dir: str, key: str,
                 temp_dir: str) -> Optional[Path]:
    """Fetch the cached cubin for *key* and materialize it as a file.

    On a hit the entry's ``atime`` is set to the current time and the blob
    is written to a new ``.cubin`` file inside *temp_dir*. That file is
    created with ``delete=False``, so it is not removed automatically.

    Args:
        cache_dir: Directory holding the cache database.
        key: Cache key, as produced by ``cache_key``.
        temp_dir: Directory in which the cubin file is created.

    Returns:
        Path of the written cubin file, or ``None`` on a cache miss or when
        any SQLite or OS error occurs (errors are logged at debug level).
    """
    conn = None
    try:
        conn = _connect(cache_dir)
        row = conn.execute(
            "SELECT blob FROM cache WHERE key = ?", (key,)
        ).fetchone()
        if row is None:
            return None
        conn.execute(
            "UPDATE cache SET atime = ? WHERE key = ?",
            (time.time(), key)
        )
        conn.commit()
        blob = row[0]
    except (sqlite3.Error, OSError):
        logger.debug("cache lookup failed for %s", key, exc_info=True)
        return None
    finally:
        _close(conn)

    cubin_name = None
    try:
        with tempfile.NamedTemporaryFile(dir=temp_dir, suffix=".cubin", delete=False) as f:
            cubin_name = f.name
            f.write(blob)
        return Path(cubin_name)
    except OSError:
        logger.debug("cache lookup failed for %s", key, exc_info=True)
        # delete=False means nobody else cleans up: remove the partially
        # written file so a failed write does not leave junk in temp_dir.
        if cubin_name is not None:
            try:
                os.unlink(cubin_name)
            except OSError:
                pass
        return None
116+
117+
118+
def cache_store(cache_dir: str, key: str, cubin_path) -> None:
    """Insert the contents of the file at *cubin_path* into the cache under *key*.

    An existing entry with the same key is left untouched
    (``INSERT OR IGNORE``). SQLite and OS errors are logged at debug level
    and otherwise ignored, so a failed store never propagates to the caller.
    """
    conn = None
    try:
        payload = Path(cubin_path).read_bytes()
        conn = _connect(cache_dir)
        row = (key, payload, len(payload), time.time())
        conn.execute(
            "INSERT OR IGNORE INTO cache"
            " (key, blob, blob_size, atime) VALUES (?, ?, ?, ?)",
            row,
        )
        conn.commit()
    except (sqlite3.Error, OSError):
        logger.debug("cache store failed for %s", key, exc_info=True)
    finally:
        _close(conn)
133+
134+
135+
def evict_lru(cache_dir: str, size_limit: int) -> None:
    """Delete the least recently accessed entries that exceed *size_limit*.

    Rows are visited in ascending ``(atime, key)`` order. A row is deleted
    when the running total of ``blob_size`` up to and including that row
    does not exceed the current excess (total ``blob_size`` minus
    *size_limit*).

    Each DELETE only inspects the first ``row_limit`` rows. When every
    inspected row was deleted, the window is widened tenfold and the
    statement is run again. All deletions are committed together at the end.

    Args:
        cache_dir: Directory holding the cache database.
        size_limit: Target total of ``blob_size`` in bytes.

    SQLite and OS errors are logged at debug level and otherwise ignored,
    matching ``cache_lookup`` and ``cache_store``.
    """
    conn = None
    try:
        conn = _connect(cache_dir)
        row_limit = 100
        while True:
            res = conn.execute("""
                DELETE FROM cache WHERE key IN (SELECT key FROM
                    (SELECT key, SUM(blob_size) OVER (ORDER BY atime, key) as cumul_size
                     FROM cache ORDER BY atime, key limit ?)
                    WHERE cumul_size <= (SELECT SUM(blob_size) - ? FROM cache)
                )
            """, (row_limit, size_limit))
            if res.rowcount < row_limit:
                break
            row_limit *= 10
        conn.commit()
    except (sqlite3.Error, OSError):
        # _connect can raise OSError (e.g. the cache directory cannot be
        # created); eviction is best effort and must not fail the caller.
        logger.debug("cache evict failed", exc_info=True)
    finally:
        _close(conn)

src/cuda/tile/_compile.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from cuda.tile._passes.check_ampere_fp8 import check_ampere_fp8
5252
from cuda.tile._passes.dce import dead_code_elimination_pass
5353
from cuda.tile._passes.token_order import token_order_pass
54+
from cuda.tile._cache import cache_key, cache_lookup, cache_store, evict_lru
5455
from cuda.tile._ir2bytecode import generate_bytecode_for_kernel
5556
from cuda.tile._version import __version__ as cutile_version
5657
import cuda.tile._bytecode as bc
@@ -248,6 +249,21 @@ def compile_tile(pyfunc,
248249
print("Can't print MLIR because the internal extension is missing. "
249250
"This is currently not a public feature.", file=sys.stderr)
250251

252+
# Check disk cache before invoking tileiras
253+
cache_dir = context.config.cache_dir
254+
compiler_ver = _get_compiler_version_string()
255+
key = None
256+
if cache_dir is None:
257+
logger.debug("disk cache disabled: context.config.cache_dir is not set")
258+
elif compiler_ver is None:
259+
logger.warning("disk cache disabled: compiler version is unknown")
260+
else:
261+
opt_level = compiler_options.specialize_for_target(sm_arch).opt_level
262+
key = cache_key(compiler_ver, sm_arch, opt_level, bytecode_buf)
263+
cubin_path = cache_lookup(cache_dir, key, context.config.temp_dir)
264+
if cubin_path is not None:
265+
return TileLibrary(func_ir.name, cubin_path, bytecode_buf, func_ir.body)
266+
251267
# Compile MLIR module and generate cubin
252268
with tempfile.NamedTemporaryFile(suffix='.bytecode', prefix=func_ir.name,
253269
dir=context.config.temp_dir, delete=False) as f:
@@ -265,6 +281,10 @@ def compile_tile(pyfunc,
265281

266282
raise e
267283

284+
if cache_dir is not None and key is not None:
285+
cache_store(cache_dir, key, cubin_file)
286+
evict_lru(cache_dir, context.config.cache_size_limit)
287+
268288
return TileLibrary(func_ir.name, cubin_file, bytecode_buf, func_ir.body)
269289

270290

@@ -436,6 +456,13 @@ def _try_get_compiler_version(compiler_bin) -> Optional[str]:
436456
return None
437457

438458

459+
@cache
def _get_compiler_version_string() -> str | None:
    """Return the version string of the located compiler binary.

    The result of ``_try_get_compiler_version`` is passed through
    unchanged and memoized by ``@cache``.
    """
    return _try_get_compiler_version(_find_compiler_bin().path)
464+
465+
439466
@cache
440467
def get_sm_arch() -> str:
441468
major, minor = get_compute_capability()

src/cuda/tile/_context.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import atexit
66
import os
77
import shutil
8+
import sys
89
import tempfile
910
from dataclasses import dataclass
1011
from typing import Optional
@@ -16,14 +17,18 @@ class TileContextConfig:
1617
log_keys: list[str]
1718
compiler_timeout_sec: Optional[int]
1819
enable_crash_dump: bool
20+
cache_dir: Optional[str]
21+
cache_size_limit: int
1922

2023

2124
def init_context_config_from_env():
2225
config = TileContextConfig(
2326
temp_dir=get_temp_dir_from_env(),
2427
log_keys=get_log_keys_from_env(),
2528
compiler_timeout_sec=get_compile_timeout_from_env(),
26-
enable_crash_dump=get_enable_crash_dump_from_env()
29+
enable_crash_dump=get_enable_crash_dump_from_env(),
30+
cache_dir=get_cache_dir_from_env(),
31+
cache_size_limit=get_cache_size_limit_from_env()
2732
)
2833
return config
2934

@@ -71,3 +76,20 @@ def get_enable_crash_dump_from_env() -> bool:
7176
key = "CUDA_TILE_ENABLE_CRASH_DUMP"
7277
env = os.environ.get(key, "0").lower()
7378
return env in ("1", "true", "yes", "on")
79+
80+
81+
def get_cache_dir_from_env() -> Optional[str]:
    """Return the disk cache directory, or ``None`` when caching is disabled.

    ``CUDA_TILE_CACHE_DIR`` takes precedence. The values ``0``, ``off``,
    ``none`` and the empty string (case-insensitive, surrounding whitespace
    ignored) disable the cache. When the variable is unset, the default is
    ``<base>/cutile-python``, where ``<base>`` is ``LOCALAPPDATA`` on
    Windows and ``XDG_CACHE_HOME`` elsewhere, falling back to ``~/.cache``
    when that variable is unset or empty.
    """
    home_cache = os.path.join(os.path.expanduser("~"), ".cache")
    base_var = "LOCALAPPDATA" if sys.platform == "win32" else "XDG_CACHE_HOME"
    # `or` rather than a .get() default: a variable that is set but empty
    # must count as unset, otherwise the default collapses to the relative
    # path "cutile-python" in the current working directory.
    base = os.environ.get(base_var) or home_cache
    default = os.path.join(base, "cutile-python")
    env = os.environ.get("CUDA_TILE_CACHE_DIR", default)
    if env.strip().lower() in ("0", "off", "none", ""):
        return None
    return env
92+
93+
94+
def get_cache_size_limit_from_env() -> int:
    """Return the disk cache size limit in bytes from ``CUDA_TILE_CACHE_SIZE``.

    Defaults to ``1 << 31`` (2 GiB) when the variable is unset, empty or
    only whitespace.

    Raises:
        ValueError: If the variable is set to something that is not an
            integer; the message names the variable and the bad value.
    """
    key = "CUDA_TILE_CACHE_SIZE"
    env = os.environ.get(key, "").strip()
    if not env:
        return 1 << 31  # 2 GiB
    try:
        return int(env)
    except ValueError:
        raise ValueError(
            f"{key} must be an integer number of bytes, got {env!r}"
        ) from None

test/test_cache.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import sqlite3
6+
import time
7+
8+
import pytest
9+
10+
from cuda.tile._cache import cache_key, cache_lookup, cache_store, evict_lru
11+
12+
13+
def test_cache_key_equal():
    """Identical inputs must produce identical keys."""
    args = ("v1", "sm_90", 3, b"data")
    first = cache_key(*args)
    second = cache_key(*args)
    assert first == second
17+
18+
19+
def test_cache_key_differs():
    """Changing any single input must change the key."""
    reference = cache_key("v1", "sm_90", 3, b"data")
    variants = (
        ("v2", "sm_90", 3, b"data"),    # compiler version
        ("v1", "sm_80", 3, b"data"),    # target architecture
        ("v1", "sm_90", 2, b"data"),    # optimization level
        ("v1", "sm_90", 3, b"other"),   # bytecode
    )
    for variant in variants:
        assert cache_key(*variant) != reference
25+
26+
27+
def make_cubin(tmp_path, name, content):
    """Write *content* to ``tmp_path / name`` and return that path."""
    cubin_file = tmp_path / name
    cubin_file.write_bytes(content)
    return cubin_file
31+
32+
33+
@pytest.fixture
def cache_env(tmp_path):
    """Provide ``(cache_dir, tmp_path)``; cache_dir is ``tmp_path / "cache"`` as a string."""
    return str(tmp_path / "cache"), tmp_path
37+
38+
39+
def test_store_then_lookup(cache_env):
    """A stored cubin is returned byte-for-byte by a later lookup."""
    cache_dir, tmp_path = cache_env
    payload = b"\x7fELF_fake_cubin_data"
    key = cache_key("v1", "sm_90", 3, b"data")
    cubin = make_cubin(tmp_path, "kernel.cubin", payload)

    cache_store(cache_dir, key, cubin)
    found = cache_lookup(cache_dir, key, str(tmp_path))

    assert found is not None
    assert found.read_bytes() == payload
49+
50+
51+
def test_lookup_updates_atime(cache_env):
    """A successful lookup moves the entry's atime forward."""
    import os

    cache_dir, tmp_path = cache_env
    key = cache_key("v1", "sm_90", 3, b"data")
    cache_store(cache_dir, key, make_cubin(tmp_path, "kernel.cubin", b"data"))

    db_path = os.path.join(cache_dir, "cache.db")

    # Backdate the entry directly in the database.
    stale = time.time() - 1000
    db = sqlite3.connect(db_path)
    db.execute("UPDATE cache SET atime = ? WHERE key = ?", (stale, key))
    db.commit()
    db.close()

    cache_lookup(cache_dir, key, str(tmp_path))

    db = sqlite3.connect(db_path)
    row = db.execute(
        "SELECT atime FROM cache WHERE key = ?", (key,)
    ).fetchone()
    db.close()
    assert row[0] > stale
74+
75+
76+
def test_lookup_miss(cache_env):
    """Looking up a key that was never stored yields ``None``."""
    cache_dir, _ = cache_env
    missing_key = "a" * 64
    assert cache_lookup(cache_dir, missing_key, str(cache_dir)) is None
81+
82+
83+
def test_evict_lru(cache_env):
    """Eviction removes the oldest entries until the size limit is met."""
    import os

    cache_dir, tmp_path = cache_env
    db_path = os.path.join(cache_dir, "cache.db")

    # Five entries of 1000 bytes each (5000 bytes in total).
    keys = []
    for i in range(5):
        key = cache_key(str(i), "sm_90", 3, b"data")
        keys.append(key)
        cubin = make_cubin(tmp_path, f"k{i}.cubin", b"x" * 1000)
        cache_store(cache_dir, key, cubin)

    # Pin atimes to 0.0 .. 4.0 so the eviction order is deterministic.
    db = sqlite3.connect(db_path)
    for i, key in enumerate(keys):
        db.execute(
            "UPDATE cache SET atime = ? WHERE key = ?",
            (float(i), key)
        )
    db.commit()
    db.close()

    # A 3000 byte limit keeps only the three newest entries (indices 2-4).
    evict_lru(cache_dir, 3000)

    remaining = []
    for k in keys:
        if cache_lookup(cache_dir, k, str(tmp_path)) is not None:
            remaining.append(k)
    assert remaining == keys[2:]

0 commit comments

Comments
 (0)