Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,9 @@ extern "C" {
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
GGML_TYPE_TQ3_0 = 41, // TurboQuant 3-bit polar + QJL (no per-block scale)
GGML_TYPE_COUNT = 42,
GGML_TYPE_Q1_0 = 42, // PrismML 1-bit ternary (32-element blocks)
GGML_TYPE_Q1_0_G128 = 43, // PrismML 1-bit ternary (128-element blocks)
GGML_TYPE_COUNT = 44,
};

// precision
Expand Down Expand Up @@ -466,6 +468,8 @@ extern "C" {
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors
GGML_FTYPE_MOSTLY_Q1_0 = 27, // except 1d tensors (PrismML 1-bit)
GGML_FTYPE_MOSTLY_Q1_0_G128 = 28, // except 1d tensors (PrismML 1-bit g128)
};

// available tensor operations:
Expand Down
18 changes: 18 additions & 0 deletions ggml/src/ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,24 @@ typedef struct {
} block_tq3_0;
static_assert(sizeof(block_tq3_0) == QK_TQ3_0/4 + QK_TQ3_0/8 + sizeof(ggml_half), "wrong tq3_0 block size/padding");

// PrismML Q1_0: 1-bit ternary quantization (32-element blocks)
// Each value quantized as sign bit: bit=1 → +scale, bit=0 → −scale
// scale = mean(abs(values)) per block
#define QK1_0 32
typedef struct {
ggml_half d; // scale (mean absolute value)
uint8_t qs[QK1_0 / 8]; // sign bits: 32 × 1 bit = 4 bytes
} block_q1_0;
static_assert(sizeof(block_q1_0) == sizeof(ggml_half) + QK1_0/8, "wrong q1_0 block size/padding");

// PrismML Q1_0_G128: 1-bit ternary quantization (128-element blocks)
#define QK1_0_G128 128
typedef struct {
ggml_half d; // scale
uint8_t qs[QK1_0_G128 / 8]; // sign bits: 128 × 1 bit = 16 bytes
} block_q1_0_g128;
static_assert(sizeof(block_q1_0_g128) == sizeof(ggml_half) + QK1_0_G128/8, "wrong q1_0_g128 block size/padding");

//
// Super-block quantization structures
//
Expand Down
12 changes: 12 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,18 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.from_float = quantize_row_tq3_0,
.nrows = 1,
},
[GGML_TYPE_Q1_0] = {
.from_float = quantize_row_q1_0,
.vec_dot = ggml_vec_dot_q1_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_Q1_0_G128] = {
.from_float = quantize_row_q1_0_g128,
.vec_dot = ggml_vec_dot_q1_0_g128_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_I32] = {
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
},
Expand Down
14 changes: 14 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,8 @@ void ggml_compute_forward_add(
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ3_0:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q1_0_G128:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down Expand Up @@ -1130,6 +1132,8 @@ void ggml_compute_forward_add1(
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ3_0:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q1_0_G128:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down Expand Up @@ -1260,6 +1264,8 @@ void ggml_compute_forward_acc(
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ3_0:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q1_0_G128:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down Expand Up @@ -4349,6 +4355,8 @@ void ggml_compute_forward_out_prod(
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ3_0:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q1_0_G128:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down Expand Up @@ -4626,6 +4634,8 @@ void ggml_compute_forward_set(
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ3_0:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q1_0_G128:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down Expand Up @@ -4850,6 +4860,8 @@ void ggml_compute_forward_get_rows(
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ3_0:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q1_0_G128:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down Expand Up @@ -5576,6 +5588,8 @@ void ggml_compute_forward_clamp(
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ3_0:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q1_0_G128:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
Expand Down
77 changes: 77 additions & 0 deletions ggml/src/ggml-cpu/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,83 @@ void quantize_row_tq3_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy,
quantize_row_tq3_0_ref(x, y, k);
}

void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(k % QK1_0 == 0);
block_q1_0 * GGML_RESTRICT y = vy;
quantize_row_q1_0_ref(x, y, k);
}

void quantize_row_q1_0_g128(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(k % QK1_0_G128 == 0);
block_q1_0_g128 * GGML_RESTRICT y = vy;
quantize_row_q1_0_g128_ref(x, y, k);
}

//===================================== Q1_0 vec_dot =================================

void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;

assert(n % qk == 0);
assert(nrc == 1);
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);

const block_q1_0 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;

float sumf = 0.0f;

for (int i = 0; i < nb; i++) {
const float d0 = GGML_FP16_TO_FP32(x[i].d);
const float d1 = GGML_FP16_TO_FP32(y[i].d);

int sumi = 0;
for (int j = 0; j < QK1_0; j++) {
const int xi = ((x[i].qs[j / 8] >> (j % 8)) & 1) ? 1 : -1;
sumi += xi * (int)y[i].qs[j];
}

sumf += d0 * d1 * (float)sumi;
}

*s = sumf;
}

void ggml_vec_dot_q1_0_g128_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK1_0_G128;
const int nb = n / qk;

assert(n % qk == 0);
assert(nrc == 1);
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);

const block_q1_0_g128 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;

float sumf = 0.0f;

for (int i = 0; i < nb; i++) {
const float d0 = GGML_FP16_TO_FP32(x[i].d);

// Each Q1_0_g128 block spans 4 Q8_0 blocks (4 × 32 = 128)
for (int k = 0; k < 4; k++) {
const float d1 = GGML_FP16_TO_FP32(y[i * 4 + k].d);
int sumi = 0;

for (int j = 0; j < QK8_0; j++) {
const int bit_index = k * QK8_0 + j;
const int xi = ((x[i].qs[bit_index / 8] >> (bit_index % 8)) & 1) ? 1 : -1;
sumi += xi * (int)y[i * 4 + k].qs[j];
}

sumf += d0 * d1 * (float)sumi;
}
}

*s = sumf;
}

//===================================== Q8_K ==============================================

void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
Expand Down
6 changes: 6 additions & 0 deletions ggml/src/ggml-cpu/quants.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i
void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_tq3_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q1_0_g128(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

Expand All @@ -55,6 +58,9 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q1_0_g128_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
Expand Down
86 changes: 86 additions & 0 deletions ggml/src/ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -2496,6 +2496,92 @@ size_t quantize_tq3_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
return nrow * row_size;
}

// ====================== PrismML Q1_0 1-bit ternary quantization ======================

void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k) {
assert(k % QK1_0 == 0);
const int64_t nb = k / QK1_0;

for (int64_t i = 0; i < nb; i++) {
float amax = 0.0f;
for (int j = 0; j < QK1_0; j++) {
amax += fabsf(x[i * QK1_0 + j]);
}
const float d = amax / QK1_0;
y[i].d = GGML_FP32_TO_FP16(d);

memset(y[i].qs, 0, sizeof(y[i].qs));
for (int j = 0; j < QK1_0; j++) {
if (x[i * QK1_0 + j] >= 0.0f) {
y[i].qs[j / 8] |= (1 << (j % 8));
}
}
}
}

void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
assert(k % QK1_0 == 0);
const int64_t nb = k / QK1_0;

for (int64_t i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
for (int j = 0; j < QK1_0; j++) {
const int bit = (x[i].qs[j / 8] >> (j % 8)) & 1;
y[i * QK1_0 + j] = bit ? d : -d;
}
}
}

size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
(void)quant_weights;
const size_t row_size = ggml_row_size(GGML_TYPE_Q1_0, n_per_row);
quantize_row_q1_0_ref(src, dst, (int64_t)nrow * n_per_row);
return nrow * row_size;
}

// ====================== PrismML Q1_0_G128 1-bit ternary (128-element blocks) ======================

void quantize_row_q1_0_g128_ref(const float * GGML_RESTRICT x, block_q1_0_g128 * GGML_RESTRICT y, int64_t k) {
assert(k % QK1_0_G128 == 0);
const int64_t nb = k / QK1_0_G128;

for (int64_t i = 0; i < nb; i++) {
float amax = 0.0f;
for (int j = 0; j < QK1_0_G128; j++) {
amax += fabsf(x[i * QK1_0_G128 + j]);
}
const float d = amax / QK1_0_G128;
y[i].d = GGML_FP32_TO_FP16(d);

memset(y[i].qs, 0, sizeof(y[i].qs));
for (int j = 0; j < QK1_0_G128; j++) {
if (x[i * QK1_0_G128 + j] >= 0.0f) {
y[i].qs[j / 8] |= (1 << (j % 8));
}
}
}
}

void dequantize_row_q1_0_g128(const block_q1_0_g128 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
assert(k % QK1_0_G128 == 0);
const int64_t nb = k / QK1_0_G128;

for (int64_t i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
for (int j = 0; j < QK1_0_G128; j++) {
const int bit = (x[i].qs[j / 8] >> (j % 8)) & 1;
y[i * QK1_0_G128 + j] = bit ? d : -d;
}
}
}

size_t quantize_q1_0_g128(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
(void)quant_weights;
const size_t row_size = ggml_row_size(GGML_TYPE_Q1_0_G128, n_per_row);
quantize_row_q1_0_g128_ref(src, dst, (int64_t)nrow * n_per_row);
return nrow * row_size;
}

// ====================== "True" 2-bit (de)-quantization

void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
Expand Down
9 changes: 9 additions & 0 deletions ggml/src/ggml-quants.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ GGML_API void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0
GGML_API void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_tq3_0_ref(const float * GGML_RESTRICT x, block_tq3_0 * GGML_RESTRICT y, int64_t k);

GGML_API void quantize_row_q1_0_ref (const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q1_0_g128_ref(const float * GGML_RESTRICT x, block_q1_0_g128 * GGML_RESTRICT y, int64_t k);

GGML_API void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
Expand Down Expand Up @@ -63,6 +66,9 @@ GGML_API void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float *
GGML_API void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_tq3_0(const block_tq3_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);

GGML_API void dequantize_row_q1_0 (const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q1_0_g128(const block_q1_0_g128 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);

GGML_API void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
Expand All @@ -88,6 +94,9 @@ GGML_API size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_REST
GGML_API size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_tq3_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);

GGML_API size_t quantize_q1_0 (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_q1_0_g128(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);

GGML_API size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
Expand Down
20 changes: 20 additions & 0 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,22 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
.to_float = (ggml_to_float_t) dequantize_row_tq3_0,
.from_float_ref = (ggml_from_float_t) quantize_row_tq3_0_ref,
},
[GGML_TYPE_Q1_0] = {
.type_name = "q1_0",
.blck_size = QK1_0,
.type_size = sizeof(block_q1_0),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q1_0,
.from_float_ref = (ggml_from_float_t) quantize_row_q1_0_ref,
},
[GGML_TYPE_Q1_0_G128] = {
.type_name = "q1_0_g128",
.blck_size = QK1_0_G128,
.type_size = sizeof(block_q1_0_g128),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q1_0_g128,
.from_float_ref = (ggml_from_float_t) quantize_row_q1_0_g128_ref,
},
};

const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
Expand Down Expand Up @@ -1397,6 +1413,8 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break;
case GGML_FTYPE_MOSTLY_NVFP4: wtype = GGML_TYPE_NVFP4; break;
case GGML_FTYPE_MOSTLY_Q1_0: wtype = GGML_TYPE_Q1_0; break;
case GGML_FTYPE_MOSTLY_Q1_0_G128: wtype = GGML_TYPE_Q1_0_G128; break;
case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
Expand Down Expand Up @@ -7673,6 +7691,8 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_TQ3_0: result = quantize_tq3_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q1_0: result = quantize_q1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q1_0_G128: result = quantize_q1_0_g128(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
Expand Down
Loading