diff --git a/src/kirin/__init__.py b/src/kirin/__init__.py index 57dc85bdff..bc01bee243 100644 --- a/src/kirin/__init__.py +++ b/src/kirin/__init__.py @@ -1,5 +1,5 @@ # re-exports the public API of the kirin package -from . import ir as ir, types as types, lowering as lowering +from . import ir as ir, types as types, stdlib as stdlib, lowering as lowering from .exception import enable_stracetrace, disable_stracetrace __all__ = ["ir", "types", "lowering", "enable_stracetrace", "disable_stracetrace"] diff --git a/src/kirin/stdlib/__init__.py b/src/kirin/stdlib/__init__.py new file mode 100644 index 0000000000..e4505864f5 --- /dev/null +++ b/src/kirin/stdlib/__init__.py @@ -0,0 +1 @@ +from . import bits as bits diff --git a/src/kirin/stdlib/bits.py b/src/kirin/stdlib/bits.py new file mode 100644 index 0000000000..69fc6634b5 --- /dev/null +++ b/src/kirin/stdlib/bits.py @@ -0,0 +1,32 @@ +"""Bit-oriented helpers implemented as reusable Kirin kernels.""" + +from kirin.prelude import basic +from kirin.dialects import ilist + + +@basic +def _bit_length_rec(x: int, i: int) -> int: + y = x >> i + if y: + return _bit_length_rec(x, i + 1) + else: + return i + + +@basic +def bit_length(x: int) -> int: + """Return the number of bits required to represent ``x``.""" + x = abs(x) + if x == 0: + return 0 + return _bit_length_rec(x, 1) + + +@basic +def convert_bits(x: int, length: int): + """Return the low ``length`` bits of ``x`` in least-significant-bit order. Note that the return type puts the least-significant-bit in the earliest index.""" + + def _shift(i: int): + return (x >> i) & 1 + + return ilist.map(_shift, ilist.range(length)) diff --git a/test/stdlib/test_bits.py b/test/stdlib/test_bits.py new file mode 100644 index 0000000000..a93f6542f0 --- /dev/null +++ b/test/stdlib/test_bits.py @@ -0,0 +1,99 @@ +from kirin.dialects import ilist +from kirin.stdlib.bits import bit_length, convert_bits + + +def test_bit_length_negative(): + assert bit_length(-13) == 4 + + +def test_bit_length_zero(): + assert bit_length(0) == 0 + + +def test_bit_length_positive(): + assert bit_length(7) == 3 + + +def test_bit_length_large(): + x = (1 << 80) + 12345 + assert bit_length(x) == x.bit_length() + + +def test_bit_length_large_power_of_two(): + x = 1 << 80 + assert bit_length(x) == 81 + + +def test_bit_length_small(): + assert bit_length(3) == 2 + + +def test_bit_length_small_single_bit(): + assert bit_length(1) == 1 + + +def test_convert_bits_length_greater_than_bit_length(): + out = convert_bits(5, 5) + assert isinstance(out, ilist.IList) + assert out.data == [1, 0, 1, 0, 0] + + +def test_convert_bits_length_equal_to_bit_length(): + out = convert_bits(5, 3) + assert isinstance(out, ilist.IList) + assert out.data == [1, 0, 1] + + +def test_convert_bits_length_less_than_bit_length(): + out = convert_bits(13, 2) + assert isinstance(out, ilist.IList) + assert out.data == [1, 0] + + +def test_convert_bits_negative_x(): + out = convert_bits(-1, 4) + assert isinstance(out, ilist.IList) + assert out.data == [1, 1, 1, 1] + + +def test_convert_bits_negative_length(): + out = convert_bits(5, -3) + assert isinstance(out, ilist.IList) + assert out.data == [] + + +def test_convert_bits_zero_x(): + out = convert_bits(0, 4) + assert isinstance(out, ilist.IList) + assert out.data == [0, 0, 0, 0] + + +def test_convert_bits_zero_length(): + out = convert_bits(7, 0) + assert isinstance(out, ilist.IList) + assert out.data == [] + + +def test_convert_bits_small_x(): + out = convert_bits(2, 3) + assert isinstance(out, ilist.IList) + assert out.data == [0, 1, 0] + + +def test_convert_bits_small_length(): + out = convert_bits(7, 1) + assert isinstance(out, ilist.IList) + assert out.data == [1] + + +def test_convert_bits_large_x(): + x = (1 << 12) + (1 << 5) + 1 + out = convert_bits(x, 13) + assert isinstance(out, ilist.IList) + assert out.data == [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1] + + +def test_convert_bits_large_length(): + out = convert_bits(3, 10) + assert isinstance(out, ilist.IList) + assert out.data == [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]