Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/kirin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# re-exports the public API of the kirin package
from . import ir as ir, types as types, lowering as lowering
from . import ir as ir, types as types, stdlib as stdlib, lowering as lowering
from .exception import enable_stracetrace, disable_stracetrace

__all__ = ["ir", "types", "lowering", "enable_stracetrace", "disable_stracetrace"]
1 change: 1 addition & 0 deletions src/kirin/stdlib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import bits as bits
32 changes: 32 additions & 0 deletions src/kirin/stdlib/bits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Bit-oriented helpers implemented as reusable Kirin kernels."""

from kirin.prelude import basic
from kirin.dialects import ilist


@basic
def _bit_length_rec(x: int, i: int) -> int:
y = x >> i
if y:
return _bit_length_rec(x, i + 1)
else:
return i


@basic
def bit_length(x: int) -> int:
"""Return the number of bits required to represent ``x``."""
x = abs(x)
if x == 0:
return 0
return _bit_length_rec(x, 1)


@basic
def convert_bits(x: int, length: int):
"""Return the low ``length`` bits of ``x`` in least-significant-bit order. Note that the return type puts the least-significant-bit in the earliest index."""

def _shift(i: int):
return (x >> i) & 1

return ilist.map(_shift, ilist.range(length))
99 changes: 99 additions & 0 deletions test/stdlib/test_bits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from kirin.dialects import ilist
from kirin.stdlib.bits import bit_length, convert_bits


def test_bit_length_negative():
assert bit_length(-13) == 4


def test_bit_length_zero():
assert bit_length(0) == 0


def test_bit_length_positive():
assert bit_length(7) == 3


def test_bit_length_large():
x = (1 << 80) + 12345
assert bit_length(x) == x.bit_length()


def test_bit_length_large_power_of_two():
x = 1 << 80
assert bit_length(x) == 81


def test_bit_length_small():
assert bit_length(3) == 2


def test_bit_length_small_single_bit():
assert bit_length(1) == 1


def test_convert_bits_length_greater_than_bit_length():
out = convert_bits(5, 5)
assert isinstance(out, ilist.IList)
assert out.data == [1, 0, 1, 0, 0]


def test_convert_bits_length_equal_to_bit_length():
out = convert_bits(5, 3)
assert isinstance(out, ilist.IList)
assert out.data == [1, 0, 1]


def test_convert_bits_length_less_than_bit_length():
out = convert_bits(13, 2)
assert isinstance(out, ilist.IList)
assert out.data == [1, 0]


def test_convert_bits_negative_x():
out = convert_bits(-1, 4)
assert isinstance(out, ilist.IList)
assert out.data == [1, 1, 1, 1]


def test_convert_bits_negative_length():
out = convert_bits(5, -3)
assert isinstance(out, ilist.IList)
assert out.data == []


def test_convert_bits_zero_x():
out = convert_bits(0, 4)
assert isinstance(out, ilist.IList)
assert out.data == [0, 0, 0, 0]


def test_convert_bits_zero_length():
out = convert_bits(7, 0)
assert isinstance(out, ilist.IList)
assert out.data == []


def test_convert_bits_small_x():
out = convert_bits(2, 3)
assert isinstance(out, ilist.IList)
assert out.data == [0, 1, 0]


def test_convert_bits_small_length():
out = convert_bits(7, 1)
assert isinstance(out, ilist.IList)
assert out.data == [1]


def test_convert_bits_large_x():
x = (1 << 12) + (1 << 5) + 1
out = convert_bits(x, 13)
assert isinstance(out, ilist.IList)
assert out.data == [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1]


def test_convert_bits_large_length():
out = convert_bits(3, 10)
assert isinstance(out, ilist.IList)
assert out.data == [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
Loading