diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5fd07c5..6e046ed 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,8 +20,8 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, windows-latest, macos-13] - python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + os: [ubuntu-latest, windows-latest, macos-15-intel] + python_version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14", "3.14t"] runs-on: ${{ matrix.os }} steps: @@ -29,7 +29,7 @@ jobs: uses: actions/checkout@v3 - name: Configure Python version - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python_version }} architecture: x64 diff --git a/README.md b/README.md index b1091ea..5852a6e 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,12 @@ Simple but high performance Cython hash table mapping pre-randomized keys to `void*` values. Inspired by [Jeff Preshing](http://preshing.com/20130107/this-hash-table-is-faster-than-a-judy-array/). +All Python APIs provded by the `BloomFilter` and `PreshMap` classes are +thread-safe on both the GIL-enabled build and the free-threaded build of Python +3.14 and newer. If you use the C API or the `PreshCounter` class, you must +provide external synchronization if you use the data structures by this library +in a multithreaded environment. + [![tests](https://github.com/explosion/preshed/actions/workflows/tests.yml/badge.svg)](https://github.com/explosion/preshed/actions/workflows/tests.yml) [![pypi Version](https://img.shields.io/pypi/v/preshed.svg?style=flat-square&logo=pypi&logoColor=white)](https://pypi.python.org/pypi/preshed) [![conda Version](https://img.shields.io/conda/vn/conda-forge/preshed.svg?style=flat-square&logo=conda-forge&logoColor=white)](https://anaconda.org/conda-forge/preshed) diff --git a/preshed/bloom.pxd b/preshed/bloom.pxd index 6396715..f3f178a 100644 --- a/preshed/bloom.pxd +++ b/preshed/bloom.pxd @@ -13,13 +13,15 @@ cdef struct BloomStruct: cdef class BloomFilter: cdef Pool mem cdef BloomStruct* c_bloom + # Thread-unsafe variant of __contains__ cdef inline bint contains(self, key_t item) nogil +# Low-level thread-unsafe C API. +# If you use this API and expose it to Python, you must provide external +# synchronization (e.g. with a lock or critical section). cdef void bloom_init(Pool mem, BloomStruct* bloom, key_t hcount, key_t length, uint32_t seed) except * -cdef void bloom_add(BloomStruct* bloom, key_t item) nogil - cdef bint bloom_contains(const BloomStruct* bloom, key_t item) nogil cdef void bloom_add(BloomStruct* bloom, key_t item) nogil diff --git a/preshed/bloom.pyx b/preshed/bloom.pyx index 54455a3..9b9ff4e 100644 --- a/preshed/bloom.pyx +++ b/preshed/bloom.pyx @@ -5,6 +5,10 @@ from murmurhash.mrmr cimport hash128_x86 import math from array import array +cimport cython + +from libcpp.vector cimport vector + try: import copy_reg except ImportError: @@ -37,48 +41,56 @@ cdef class BloomFilter: return cls(*params) def add(self, key_t item): - bloom_add(self.c_bloom, item) + with cython.critical_section(self): + bloom_add(self.c_bloom, item) - def __contains__(self, item): - return bloom_contains(self.c_bloom, item) + def __contains__(self, key_t item): + with cython.critical_section(self): + return bloom_contains(self.c_bloom, item) + # Requires external synchronization (e.g. a critical section) cdef inline bint contains(self, key_t item) nogil: return bloom_contains(self.c_bloom, item) def to_bytes(self): - return bloom_to_bytes(self.c_bloom) + with cython.critical_section(self): + return bloom_to_bytes(self.c_bloom) def from_bytes(self, bytes byte_string): - bloom_from_bytes(self.mem, self.c_bloom, byte_string) - return self + with cython.critical_section(self): + bloom_from_bytes(self.mem, self.c_bloom, byte_string) + return self + + def _roundtrip(self): + # Purely for testing, since this operation can't be done atomically + # without holding a critical section the entire time. + # Entering the same critical section recursively doesn't release it. + # (see cpython commit 180d417) + with cython.critical_section(self): + self.from_bytes(self.to_bytes()) cdef bytes bloom_to_bytes(const BloomStruct* bloom): - py = array("L") - py.append(bloom.hcount) - py.append(bloom.length) - py.append(bloom.seed) + # local scratch buffer + cdef vector[key_t] ret = vector[key_t]() + ret.push_back(bloom.hcount) + ret.push_back(bloom.length) + ret.push_back(bloom.seed) for i in range(bloom.length // sizeof(key_t)): - py.append(bloom.bitfield[i]) - if hasattr(py, "tobytes"): - return py.tobytes() - else: - # Python 2 :( - return py.tostring() + ret.push_back(bloom.bitfield[i]) + # copy data in the scratch buffer into a new bytes object + return (ret.data())[:3*sizeof(key_t) + bloom.length] cdef void bloom_from_bytes(Pool mem, BloomStruct* bloom, bytes data): - py = array("L") - if hasattr(py, "frombytes"): - py.frombytes(data) - else: - py.fromstring(data) - bloom.hcount = py[0] - bloom.length = py[1] - bloom.seed = py[2] + cdef char* c_data = data; + cdef key_t* i_data = c_data; + bloom.hcount = i_data[0] + bloom.length = i_data[1] + bloom.seed = i_data[2] bloom.bitfield = mem.alloc(bloom.length // sizeof(key_t), sizeof(key_t)) for i in range(bloom.length // sizeof(key_t)): - bloom.bitfield[i] = py[3+i] + bloom.bitfield[i] = i_data[3+i] cdef void bloom_init(Pool mem, BloomStruct* bloom, key_t hcount, key_t length, uint32_t seed) except *: diff --git a/preshed/maps.pxd b/preshed/maps.pxd index 291d11c..dc6bd4b 100644 --- a/preshed/maps.pxd +++ b/preshed/maps.pxd @@ -2,6 +2,10 @@ from libc.stdint cimport uint64_t from cymem.cymem cimport Pool +# Low-level thread-unsafe C API. +# If you use this API and expose it to Python, you must provide external +# synchronization (e.g. with a lock or critical section). + ctypedef uint64_t key_t @@ -24,7 +28,6 @@ cdef struct MapStruct: bint is_empty_key_set bint is_del_key_set - cdef void* map_bulk_get(const MapStruct* map_, const key_t* keys, void** values, int n) nogil @@ -46,10 +49,11 @@ cdef class PreshMap: cdef MapStruct* c_map cdef Pool mem + # these methods are thread-unsafe and require external synchronization cdef inline void* get(self, key_t key) nogil cdef void set(self, key_t key, void* value) except * - +# note: this class is thread-unsafe without external synchronization cdef class PreshMapArray: cdef Pool mem cdef MapStruct* maps diff --git a/preshed/maps.pyx b/preshed/maps.pyx index 0dcebc5..4c0d09b 100644 --- a/preshed/maps.pyx +++ b/preshed/maps.pyx @@ -35,14 +35,24 @@ cdef class PreshMap: property capacity: def __get__(self): - return self.c_map.length + cdef key_t length + with cython.critical_section(self): + # This might be atomic on some architectures + # but not everywhere, so needs a lock + length = self.c_map.length + return length def items(self): cdef key_t key cdef void* value cdef int i = 0 - while map_iter(self.c_map, &i, &key, &value): - yield key, value + while True: + with cython.critical_section(self): + it = map_iter(self.c_map, &i, &key, &value) + if it: + yield key, value + else: + break def keys(self): for key, _ in self.items(): @@ -53,37 +63,51 @@ cdef class PreshMap: yield value def pop(self, key_t key, default=None): - cdef Result result = map_get_unless_missing(self.c_map, key) - map_clear(self.c_map, key) + cdef Result result + with cython.critical_section(self): + result = map_get_unless_missing(self.c_map, key) + map_clear(self.c_map, key) if result.found: return result.value else: return default def __getitem__(self, key_t key): - cdef Result result = map_get_unless_missing(self.c_map, key) + cdef Result result + with cython.critical_section(self): + result = map_get_unless_missing(self.c_map, key) if result.found: return result.value else: return None def __setitem__(self, key_t key, size_t value): - map_set(self.mem, self.c_map, key, value) + with cython.critical_section(self): + map_set(self.mem, self.c_map, key, value) def __delitem__(self, key_t key): - map_clear(self.c_map, key) + with cython.critical_section(self): + map_clear(self.c_map, key) def __len__(self): - return self.c_map.filled + cdef key_t filled + with cython.critical_section(self): + # This might be atomic on some architectures + # but not everywhere, so needs a lock + filled = self.c_map.filled + return filled def __contains__(self, key_t key): - cdef Result result = map_get_unless_missing(self.c_map, key) + cdef Result result + with cython.critical_section(self): + result = map_get_unless_missing(self.c_map, key) return True if result.found else False def __iter__(self): for key in self.keys(): yield key + # thread-unsafe low-level API cdef inline void* get(self, key_t key) nogil: return map_get(self.c_map, key) diff --git a/preshed/tests/test_multithreaded.py b/preshed/tests/test_multithreaded.py new file mode 100644 index 0000000..5d1c4e4 --- /dev/null +++ b/preshed/tests/test_multithreaded.py @@ -0,0 +1,82 @@ +import threading +import sys +from concurrent.futures import ThreadPoolExecutor + +from preshed.bloom import BloomFilter +from preshed.maps import PreshMap + + +def run_threaded(chunks, closure): + orig_interval = sys.getswitchinterval() + sys.setswitchinterval(.0000001) + n_threads = len(chunks) + with ThreadPoolExecutor(max_workers=n_threads) as tpe: + futures = [] + b = threading.Barrier(n_threads) + for i, chunk in enumerate(chunks): + futures.append(tpe.submit(closure, b, chunk)) + [f.result() for f in futures] + sys.setswitchinterval(orig_interval) + + +def test_multithreaded_bloom_sharing(): + bf = BloomFilter(size=2**16) + n_threads = 8 + vals = list(range(0, 10000, 10)) + n_vals = len(vals) + chunk_size = n_vals//n_threads + assert chunk_size * n_threads == n_vals + chunks = [] + for i in range(0, n_vals, chunk_size): + chunks.append(vals[i: i + chunk_size]) + + def worker(b, chunk): + b.wait() + for ii in chunk: + # exercises __contains__, add, and to_bytes + # all are supposed to be thread-safe + assert ii not in bf + bf.add(ii) + assert ii in bf + bf._roundtrip() + + run_threaded(chunks, worker) + + +def test_multithreaded_map_sharing(): + h = PreshMap() + n_threads = 8 + keys = list(range(0, 10000, 10)) + vals = list(range(1, 10000, 10)) + n_vals = len(vals) + chunk_size = n_vals//n_threads + assert chunk_size * n_threads == n_vals + chunks = [] + for i in range(0, n_vals, chunk_size): + chunks.append(zip(keys[i: i + chunk_size], vals[i: i + chunk_size])) + assert len(chunks) == n_threads + + def worker(b, chunk): + b.wait() + for k, v in chunk: + # __getitem__ + assert h[k] is None + # __setitem__ + h[k] = v + # __getitem__ again + assert h[k] == v + # items() + for (kk, vv) in h.items(): + # None if another thread removed it + assert h[kk] in (vv, None) + # pop + assert h.pop(k) == v + assert h[k] is None + # __delitem__ + h[k] = v + assert h[k] == v + del h[k] + assert h[k] is None + h[k] = v + + run_threaded(chunks, worker) diff --git a/pyproject.toml b/pyproject.toml index f83923b..95dc1bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [build-system] requires = [ "setuptools", - "cython>=0.28", + "cython>=3.1", "cymem>=2.0.2,<2.1.0", "murmurhash>=0.28.0,<1.1.0", ] diff --git a/setup.py b/setup.py index de0a2ad..5921f17 100755 --- a/setup.py +++ b/setup.py @@ -99,8 +99,12 @@ def setup_package(): version=about["__version__"], url=about["__uri__"], license=about["__license__"], - ext_modules=cythonize(ext_modules, language_level=2), - python_requires=">=3.6,<3.14", + ext_modules=cythonize( + ext_modules, + language_level=2, + compiler_directives={"freethreading_compatible": True}, + ), + python_requires=">=3.9,<3.15", install_requires=["cymem>=2.0.2,<2.1.0", "murmurhash>=0.28.0,<1.1.0"], classifiers=[ "Environment :: Console", @@ -111,13 +115,13 @@ def setup_package(): "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", "Programming Language :: Cython", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: Free Threading :: 2 - Beta", "Topic :: Scientific/Engineering", ], cmdclass={"build_ext": build_ext_subclass},