diff --git a/CHANGELOG.md b/CHANGELOG.md index 226f098..5a0675f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Unreleased +- support `BFloat16` https://github.com/plausible/ch/pull/321 - **Breaking:** replace DBConnection with NimblePool. - **Breaking:** require Elixir 1.18 or later for the built-in `JSON` module and Erlang/OTP 28 or later for `:zstd`. - **Breaking:** `Ch.start_link/1` no longer accepts DBConnection options or connection-level ClickHouse options such as `:database`, `:username`, `:password`, `:settings`, `:timeout`, `:scheme`, `:hostname`, `:port`, and `:transport_opts`. Use `:url` for the endpoint, pass ClickHouse settings per query with `Ch.query/4`'s `:settings` option, and pass ClickHouse/database/auth headers per query with `:headers`. diff --git a/lib/ch.ex b/lib/ch.ex index 92f7e41..e8a8704 100644 --- a/lib/ch.ex +++ b/lib/ch.ex @@ -484,6 +484,8 @@ defmodule Ch do def type(unquote(:"f#{size}")), do: :float end + def type(:bf16), do: :float + def type({:decimal, _p, _s}), do: :decimal for size <- [32, 64, 128, 256] do @@ -547,6 +549,8 @@ defmodule Ch do def cast(value, unquote(:"f#{size}")), do: Ecto.Type.cast(:float, value) end + def cast(value, :bf16), do: Ecto.Type.cast(:float, value) + def cast(value, {:decimal = type, _p, _s}), do: Ecto.Type.cast(type, value) for size <- [32, 64, 128, 256] do diff --git a/lib/ch/row_binary.ex b/lib/ch/row_binary.ex index 0f4a719..cf8ec6a 100644 --- a/lib/ch/row_binary.ex +++ b/lib/ch/row_binary.ex @@ -102,7 +102,8 @@ defmodule Ch.RowBinary do :ipv4, :ipv6, :point, - :nothing + :nothing, + :bf16 ], do: t @@ -258,12 +259,18 @@ defmodule Ch.RowBinary do type = :"f#{size}" def encode(unquote(type), f) when is_number(f) do - <> + <> end def encode(unquote(type), nil), do: <<0::unquote(size)>> end + def encode(:bf16, f) when is_number(f) do + <> + end + + def encode(:bf16, nil), do: <<0::16>> + def encode({:decimal, precision, scale}, decimal) do type = case decimal_size(precision) do @@ -474,6 +481,26 @@ defmodule Ch.RowBinary do end end + defp float_to_bf16(f) do + <> = <> + + upper = bits >>> 16 + lower = bits &&& 0xFFFF + + if lower > 0x8000 or (lower == 0x8000 and (upper &&& 1) == 1) do + upper + 1 + else + upper + end + end + + defp bfloat16_to_float(bits) when (bits &&& 0x7F80) == 0x7F80, do: nil + + defp bfloat16_to_float(bits) do + <> = <> + f + end + defp encode_varint_cont(i) when i < 128, do: <> defp encode_varint_cont(i) do @@ -725,7 +752,8 @@ defmodule Ch.RowBinary do :ipv4, :ipv6, :point, - :nothing + :nothing, + :bf16 ], do: t @@ -947,7 +975,8 @@ defmodule Ch.RowBinary do uuid: 0x1D, ipv4: 0x28, ipv6: 0x29, - boolean: 0x2D + boolean: 0x2D, + bf16: 0x31 ] # TODO compile inline? @@ -1179,6 +1208,9 @@ defmodule Ch.RowBinary do %{pattern: quote(do: <>), value: quote(do: f)}, %{pattern: quote(do: <<_nan_or_inf::64>>), value: quote(do: nil)} ], + bf16: [ + %{pattern: quote(do: <>), value: quote(do: bfloat16_to_float(bits))} + ], uuid: %{ pattern: quote(do: <>), value: quote(do: <>) @@ -1273,6 +1305,9 @@ defmodule Ch.RowBinary do :f64 -> decode_f64_decode_rows(bin, types_rest, row, rows, types) + :bf16 -> + decode_bf16_decode_rows(bin, types_rest, row, rows, types) + :string -> decode_string_decode_rows(bin, types_rest, row, rows, types) diff --git a/lib/ch/types.ex b/lib/ch/types.ex index fc28032..d1d2cde 100644 --- a/lib/ch/types.ex +++ b/lib/ch/types.ex @@ -16,6 +16,7 @@ defmodule Ch.Types do for size <- [32, 64] do {"Float#{size}", :"f#{size}", []} end, + {"BFloat16", :bf16, []}, {"Array", :array, [:type]}, {"Tuple", :tuple, [:maybe_named_column]}, {"Variant", :variant, [:type]}, diff --git a/test/ch/bfloat16_test.exs b/test/ch/bfloat16_test.exs new file mode 100644 index 0000000..9e57d6b --- /dev/null +++ b/test/ch/bfloat16_test.exs @@ -0,0 +1,159 @@ +defmodule Ch.BFloat16Test do + use ExUnit.Case, async: true + use ExUnitProperties + + import Bitwise + + @moduletag :bf16 + + @bf16_edges [ + 0x0000, + 0x8000, + 0x0001, + 0x8001, + 0x007F, + 0x807F, + 0x0080, + 0x8080, + 0x3F80, + 0xBF80, + 0x3FE0, + 0x7F7F, + 0xFF7F + ] + + setup do + {:ok, pool: start_supervised!(Ch)} + end + + property "plain finite values", %{pool: pool} do + check all value <- bounded_bfloat16() do + assert Ch.query!( + pool, + "select #{Float.to_string(value)}::BFloat16", + _no_params = %{} + ).rows == + [[value]] + end + end + + property "finite params round-trip through ClickHouse casts", %{ + pool: pool + } do + check all value <- bounded_bfloat16() do + assert Ch.query!( + pool, + "select {value:BFloat16} as value", + %{"value" => value} + ).rows == + [[value]] + end + end + + test "special values decode as nil", %{pool: pool} do + assert Ch.query!( + pool, + "select 'nan'::BFloat16, 'inf'::BFloat16, '-inf'::BFloat16", + %{} + ).rows == + [[nil, nil, nil]] + end + + test "RowBinary edge values round-trip through ClickHouse", %{ + pool: pool + } do + table = "bf16_edges" + insert = "insert into bf16_edges (idx, bf16) format RowBinary" + + create_table(pool, table) + + on_exit(fn -> + Help.query!("drop table if exists {table:Identifier}", %{"table" => table}) + end) + + values = Enum.map(@bf16_edges, &bfloat16_to_float/1) + + assert_rowbinary_round_trip(pool, table, insert, values) + end + + property "finite RowBinary values round-trip through ClickHouse", %{ + pool: pool + } do + table = "bf16_finite" + insert = "insert into bf16_finite (idx, bf16) format RowBinary" + + create_table(pool, table) + + on_exit(fn -> + Help.query!("drop table if exists {table:Identifier}", %{"table" => table}) + end) + + check all bits <- list_of(finite_bfloat16_bits(), length: 20) do + create_table(pool, table, if_not_exists: true) + + values = Enum.map(bits, &bfloat16_to_float/1) + + assert_rowbinary_round_trip(pool, table, insert, values) + end + end + + defp bounded_bfloat16 do + integer(-1_000_000..1_000_000) + |> map(&(&1 / 16)) + |> map(&float_to_bfloat16/1) + |> map(&bfloat16_to_float/1) + end + + defp finite_bfloat16_bits do + gen all sign <- integer(0..1), + exponent <- integer(0..0xFE), + fraction <- integer(0..0x7F) do + sign <<< 15 ||| exponent <<< 7 ||| fraction + end + end + + defp float_to_bfloat16(float) do + <> = <> + + upper = bits >>> 16 + lower = bits &&& 0xFFFF + + if lower > 0x8000 or (lower == 0x8000 and (upper &&& 1) == 1) do + upper + 1 + else + upper + end + end + + defp bfloat16_to_float(bits) do + <> = <> + float + end + + defp create_table(pool, table, opts \\ []) do + exists = if opts[:if_not_exists], do: " if not exists", else: "" + + Ch.query!( + pool, + "create table#{exists} {table:Identifier} (idx UInt8, bf16 BFloat16) engine Memory", + %{"table" => table} + ) + end + + defp assert_rowbinary_round_trip(pool, table, insert, values) do + Ch.query!(pool, "truncate table {table:Identifier}", %{"table" => table}) + + rows = + values + |> Enum.with_index() + |> Enum.map(fn {value, idx} -> [idx, value] end) + + rowbinary = Ch.RowBinary.encode_rows(rows, ["UInt8", "BFloat16"]) + + assert %Ch.Result{names: nil, rows: nil, data: nil} = + Ch.query!(pool, [insert, ?\n | rowbinary]) + + assert Ch.query!(pool, "select bf16 from {table:Identifier} order by idx", %{"table" => table}).rows == + Enum.map(values, &[&1]) + end +end diff --git a/test/ch/ecto_type_test.exs b/test/ch/ecto_type_test.exs index 32cabcb..ef68aab 100644 --- a/test/ch/ecto_type_test.exs +++ b/test/ch/ecto_type_test.exs @@ -263,6 +263,24 @@ defmodule Ch.EctoTypeTest do end end + test "BFloat16" do + assert {:parameterized, {Ch, :bf16}} = + type = Ecto.ParameterizedType.init(Ch, type: unquote("BFloat16")) + + assert Ecto.Type.type(type) == :float + assert Ecto.Type.format(type) == "#Ch" + + assert {:ok, 1.0} = Ecto.Type.cast(type, 1.0) + assert {:ok, 1.0} = Ecto.Type.cast(type, 1) + assert {:ok, 1.0} = Ecto.Type.cast(type, "1.0") + assert {:ok, nil} = Ecto.Type.cast(type, nil) + + assert :error = Ecto.Type.cast(type, "asdf") + + assert {:ok, 1.0} = Ecto.Type.dump(type, 1.0) + assert {:ok, 1.0} = Ecto.Type.load(type, 1.0) + end + test "Date" do assert {:parameterized, {Ch, :date}} = type = Ecto.ParameterizedType.init(Ch, type: "Date") diff --git a/test/ch/row_binary_test.exs b/test/ch/row_binary_test.exs index ad0bb11..d7953d8 100644 --- a/test/ch/row_binary_test.exs +++ b/test/ch/row_binary_test.exs @@ -45,6 +45,7 @@ defmodule Ch.RowBinaryTest do {:i256, 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF}, {:f32, 1.2345678806304932}, {:f64, 1.234567898762738492830000503040030202020433}, + {:bf16, 1.75}, {:date, ~D[2022-01-01]}, {:date, ~D[2042-01-01]}, {:date, ~D[1970-01-01]}, @@ -76,6 +77,7 @@ defmodule Ch.RowBinaryTest do 0, 1.234567898762738492830000503040030202020433 ]}, + {{:array, :bf16}, [-1.75, 0, 1.75]}, {{:array, :date}, [~D[2022-01-01], ~D[2042-01-01], ~D[1970-01-01]]}, {{:array, :datetime}, [~N[1970-01-01 12:23:34], ~N[2022-01-01 22:12:59], ~N[2042-01-01 04:23:01]]}, @@ -187,6 +189,7 @@ defmodule Ch.RowBinaryTest do assert encode(:i64, nil) == <<0, 0, 0, 0, 0, 0, 0, 0>> assert encode(:f32, nil) == <<0, 0, 0, 0>> assert encode(:f64, nil) == <<0, 0, 0, 0, 0, 0, 0, 0>> + assert encode(:bf16, nil) == <<0, 0>> assert encode(:boolean, nil) == 0 assert encode({:array, :string}, nil) == 0 assert encode(:date, nil) == <<0, 0>> @@ -271,6 +274,7 @@ defmodule Ch.RowBinaryTest do {"Int256", :i256}, {"Float32", :f32}, {"Float64", :f64}, + {"BFloat16", :bf16}, {"Decimal(9, 4)", {:decimal, _size = 32, _scale = 4}}, {"Decimal(23, 11)", {:decimal, _size = 128, _scale = 11}}, {"Bool", :boolean}, diff --git a/test/ch/types_test.exs b/test/ch/types_test.exs index 9c8b236..a17e95f 100644 --- a/test/ch/types_test.exs +++ b/test/ch/types_test.exs @@ -25,6 +25,7 @@ defmodule Ch.TypesTest do assert decode("Float32") == :f32 assert decode("Float64") == :f64 + assert decode("BFloat16") == :bf16 assert decode("Date") == :date assert decode("DateTime") == :datetime diff --git a/test/test_helper.exs b/test/test_helper.exs index f4c80f8..f90f24e 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -23,8 +23,8 @@ exclude = if version >= "25" do [] else - # Time, Variant, JSON, and Dynamic types are not supported in older ClickHouse versions we have in the CI - [:time, :variant, :json, :dynamic] + # Time, Variant, JSON, BFloat16, and Dynamic types are not supported in older ClickHouse versions we have in the CI + [:time, :variant, :json, :bf16, :dynamic] end assert_receive_timeout =