diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 54f998649..9d2180873 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -627,6 +627,278 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) ggml_aligned_free(buf_rp, row_size_rp); } +// ======== Q4_1x4x2 ==================== + +static void dump_block_q4_1(const block_q4_1 * b, int i) { + HEX_VERBOSE("ggml-hex: repack q4_1 %d: %d %d %d %d ... %d %d %d %d : %.6f %.6f\n", i, unpack_q4(b->qs[0]).v[0], + unpack_q4(b->qs[1]).v[0], unpack_q4(b->qs[2]).v[0], unpack_q4(b->qs[3]).v[0], unpack_q4(b->qs[12]).v[1], + unpack_q4(b->qs[13]).v[1], unpack_q4(b->qs[14]).v[1], unpack_q4(b->qs[15]).v[1], + GGML_FP16_TO_FP32(b->d), GGML_FP16_TO_FP32(b->m)); +} + +static void dump_packed_block_q4_1x4x2(const uint8_t * v, unsigned int i, size_t k) { + static const int qk = QK_Q4_1x4x2; + const int dblk_size = 8 * 2; // 8x __fp16 + const int qblk_size = qk / 2; // int4 + const int qrow_size = k / 2; // int4 (not padded) + const int drow_size = (k / 32) * 2; // fp16 + + const uint8_t * v_q = v + 0; // quants first + const uint8_t * v_d = v + qrow_size; // then scales + const uint8_t * v_m = v + qrow_size + drow_size; // then mins + + const uint8_t * q = v_q + i * qblk_size; + const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size); + const ggml_half * m = (const ggml_half *) (v_m + i * dblk_size); + + HEX_VERBOSE("ggml-hex: repack q4_1x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f ... %.6f %.6f\n", i, + unpack_q4(q[0]).v[0], unpack_q4(q[1]).v[0], unpack_q4(q[2]).v[0], unpack_q4(q[3]).v[0], + unpack_q4(q[60]).v[0], unpack_q4(q[61]).v[0], unpack_q4(q[62]).v[0], unpack_q4(q[63]).v[0], + unpack_q4(q[124]).v[0], unpack_q4(q[125]).v[0], unpack_q4(q[126]).v[0], unpack_q4(q[127]).v[0], + GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(m[0]), GGML_FP16_TO_FP32(d[3]), GGML_FP16_TO_FP32(m[3])); +} + +static void unpack_q4_1_quants(uint8_t * qs, const block_q4_1 * x, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const int x0 = (x->qs[i] & 0x0F); + const int x1 = (x->qs[i] >> 4); + qs[bi * qk + i + 0] = x0; + qs[bi * qk + i + qk / 2] = x1; + } +} + +static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const uint8_t x0 = qs[bi * qk + i + 0]; + const uint8_t x1 = qs[bi * qk + i + qk / 2]; + x->qs[i] = x0 | (x1 << 4); + } +} + +static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_1x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + const int dblk_size = 8 * 2; // 8x __fp16 + const int qblk_size = qk / 2; // int4 + const int qrow_size = k / 2; // int4 (not padded to blocks) + const int drow_size = (k / 32) * 2; // fp16 (not padded) + + uint8_t * y_q = y + 0; // quants first + uint8_t * y_d = y + qrow_size; // then scales + uint8_t * y_m = y + qrow_size + drow_size;// then mins + + if (opt_verbose > 2) { + for (int i = 0; i < nb; i++) { + for (int j = 0; j < 8; ++j) dump_block_q4_1(&x[i * 8 + j], j); + } + } + + // Repack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_1x4x2]; // unpacked quants + for (int j = 0; j < 8; ++j) { + unpack_q4_1_quants(qs, &x[i * 8 + j], j); + } + + uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + q[j] = (qs[j + 128] << 4) | qs[j]; + } + } + + // Repack the scales and mins + for (int i = 0; i < nb; i++) { + ggml_half * d = (ggml_half *) (y_d + i * dblk_size); + ggml_half * m = (ggml_half *) (y_m + i * dblk_size); + for (int j = 0; j < 8; ++j) { + d[j] = x[i * 8 + j].d; + m[j] = x[i * 8 + j].m; + } + } + + if (opt_verbose > 1) { + for (int i = 0; i < nb; i++) { + dump_packed_block_q4_1x4x2(y, i, k); + } + } +} + +static void unpack_row_q4_1x4x2(block_q4_1 * x, const uint8_t * y, int64_t k) { + static const int qk = QK_Q4_1x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + const int dblk_size = 8 * 2; // 8x __fp16 + const int qblk_size = qk / 2; // int4 + const int qrow_size = k / 2; // int4 (not padded to blocks) + const int drow_size = (k / 32) * 2; // fp16 (not padded) + + const uint8_t * y_q = y + 0; // quants first + const uint8_t * y_d = y + qrow_size; // then scales + const uint8_t * y_m = y + qrow_size + drow_size; // then mins + + if (opt_verbose > 1) { + for (int i = 0; i < nb; i++) { + dump_packed_block_q4_1x4x2(y, i, k); + } + } + + // Unpack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_1x4x2]; // unpacked quants + + const uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + qs[j] = q[j] & 0xf; + qs[j + 128] = q[j] >> 4; + } + + for (int j = 0; j < 8; ++j) { + pack_q4_1_quants(&x[i * 8 + j], qs, j); + } + } + + // Unpack the scales and mins + for (int i = 0; i < nb; i++) { + const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); + const ggml_half * m = (const ggml_half *) (y_m + i * dblk_size); + for (int j = 0; j < 8; ++j) { + x[i * 8 + j].d = d[j]; + x[i * 8 + j].m = m[j]; + } + } + + if (opt_verbose > 2) { + for (int i = 0; i < nb; i++) { + for (int j = 0; j < 8; ++j) dump_block_q4_1(&x[i * 8 + j], j); + } + } +} + +static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_1x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + // Init the quants such that they unpack into zeros + uint8_t qs[QK_Q4_1x4x2]; // unpacked quants + memset(qs, 0, sizeof(qs)); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < 8; ++j) { + pack_q4_1_quants(&x[i * 8 + j], qs, j); + } + } + + // Init scales and mins + for (int i = 0; i < nb; i++) { + for (int j = 0; j < 8; ++j) { + x[i * 8 + j].d = 0; + x[i * 8 + j].m = 0; + } + } +} + +// repack q4_1 data into q4_1x4x2 tensor +static void repack_q4_1_q4_1x4x2(ggml_tensor * t, const void * data, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_1x4x2)); // extra elements for the pad + size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4_1-q4_1x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + memcpy(buf_pd, src, row_size); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + memcpy(buf_pd, src, n_rem_bytes); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +// repack q4_1x4x2 tensor into q4_1 data +static void repack_q4_1x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_1x4x2)); + size_t row_size_rp = row_size * 2; + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4_1x4x2-q4_1 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + memset(buf_pd, 0, row_size_pd); + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + memcpy(buf_pd, src, row_size); + unpack_row_q4_1x4x2((block_q4_1 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + memcpy(buf_pd, src, row_size); + unpack_row_q4_1x4x2((block_q4_1 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + // repack q4x4x2 tensor into q4_0 data static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) { int64_t nrows = ggml_nrows(t); @@ -1377,6 +1649,12 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, repack_q4_0_q4x4x2(tensor, data, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4_1_q4_1x4x2(tensor, data, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1413,6 +1691,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, repack_q4x4x2_q4_0(data, tensor, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4_1x4x2_q4_1(data, tensor, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1849,6 +2133,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: if (src0->ne[0] % 32) { @@ -3206,6 +3491,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { // Basic sanity checks to make sure definitions match static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_Q4_1 == (unsigned int) GGML_TYPE_Q4_1, + "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0, "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index 950d836ad..21bd4050a 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -69,27 +69,45 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0, - const struct htp_tensor * src1, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * src1_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +struct htp_act_context { + struct htp_ops_context * octx; + + // Precomputed values + const uint8_t * data_src0; + const uint8_t * data_src1; + uint8_t * data_dst; + + size_t src0_row_size; + size_t src1_row_size; + size_t dst_row_size; + + size_t src0_row_size_aligned; + size_t src1_row_size_aligned; + size_t dst_row_size_aligned; + + size_t src0_spad_half_size; + size_t src1_spad_half_size; + size_t dst_spad_half_size; + + uint32_t block; + uint32_t src0_nrows; + uint32_t src0_nrows_per_thread; + int nc; +}; + +static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * src1 = &actx->octx->src1; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble3; - size_t src0_row_size = nb01; - size_t src1_row_size = nb11; - size_t dst_row_size = nb1; - - - - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + size_t src0_row_size = actx->src0_row_size; + size_t src1_row_size = actx->src1_row_size; + size_t dst_row_size = actx->dst_row_size; + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -101,43 +119,34 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; - - const bool src1_valid = src1->ne[0]; - const int nc = (src1_valid) ? ne00 : ne00 / 2; - if (!src1_valid) { - const int32_t swapped = op_params[1]; - data_src1 = data_src0; - src1_row_size = src0_row_size; + const uint8_t * restrict data_src0 = actx->data_src0; + const uint8_t * restrict data_src1 = actx->data_src1; + uint8_t * restrict data_dst = actx->data_dst; - const size_t nc_in_bytes = nc * SIZEOF_FP32; - data_src0 += swapped ? nc_in_bytes : 0; - data_src1 += swapped ? 0 : nc_in_bytes; - } + const int nc = actx->nc; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t src1_row_size_aligned = actx->src1_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread); + uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t src1_spad_half_size = src1_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t src1_spad_half_size = actx->src1_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -196,27 +205,22 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0, - const struct htp_tensor * src1, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * src1_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * src1 = &actx->octx->src1; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble3; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - size_t src0_row_size = nb01; - size_t src1_row_size = nb11; - size_t dst_row_size = nb1; + size_t src0_row_size = actx->src0_row_size; + size_t src1_row_size = actx->src1_row_size; + size_t dst_row_size = actx->dst_row_size; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -226,45 +230,36 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; - - const bool src1_valid = src1->ne[0]; - const int nc = (src1_valid) ? ne00 : ne00 / 2; - if (!src1_valid) { - const int32_t swapped = op_params[1]; - data_src1 = data_src0; - src1_row_size = src0_row_size; + const uint8_t * restrict data_src0 = actx->data_src0; + const uint8_t * restrict data_src1 = actx->data_src1; + uint8_t * restrict data_dst = actx->data_dst; - const size_t nc_in_bytes = nc * SIZEOF_FP32; - data_src0 += swapped ? nc_in_bytes : 0; - data_src1 += swapped ? 0 : nc_in_bytes; - } + const int nc = actx->nc; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t src1_row_size_aligned = actx->src1_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread); + uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t src1_spad_half_size = src1_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t src1_spad_half_size = actx->src1_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least " "%zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } - const float alpha = ((const float *) (op_params))[2]; - const float limit = ((const float *) (op_params))[3]; + const float alpha = ((const float *) (actx->octx->op_params))[2]; + const float limit = ((const float *) (actx->octx->op_params))[3]; + + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { @@ -335,26 +330,22 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0, } -static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble2; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size = actx->src0_row_size; + const size_t dst_row_size = actx->dst_row_size; + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -364,25 +355,29 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * data_src0 = (const uint8_t *) src0->data; - uint8_t * data_dst = (uint8_t *) dst->data; + const uint8_t * data_src0 = actx->data_src0; + uint8_t * data_dst = actx->data_dst; - uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + // nc/ne0 matches. + const int ne0_val = actx->nc; // == dst->ne[0] - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); + + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; // In gelu = x*sigmoid(x*1.702) - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -408,9 +403,9 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // gelu = x * sigmoid(1.702 * x) // current implementation - hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0); - hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); - hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -435,34 +430,23 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - - -static void unary_silu_f32_per_thread(const struct htp_tensor * src0, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble2; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size = actx->src0_row_size; + const size_t dst_row_size = actx->dst_row_size; + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -472,24 +456,27 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * data_src0 = (const uint8_t *) src0->data; - uint8_t * data_dst = (uint8_t *) dst->data; + const uint8_t * data_src0 = actx->data_src0; + uint8_t * data_dst = actx->data_dst; - uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + const int ne0_val = actx->nc; // == dst->ne[0] - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; + + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -515,8 +502,8 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // silu = x * sigmoid(x) - hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0); - hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -544,27 +531,22 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, static const float GELU_COEF_A = 0.044715f; static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; -static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, - const struct htp_tensor * src1, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * src1_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * src1 = &actx->octx->src1; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble3; - size_t src0_row_size = nb01; - size_t src1_row_size = nb11; - size_t dst_row_size = nb1; + size_t src0_row_size = actx->src0_row_size; + size_t src1_row_size = actx->src1_row_size; + size_t dst_row_size = actx->dst_row_size; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -574,43 +556,34 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + const uint8_t * restrict data_src0 = actx->data_src0; + const uint8_t * restrict data_src1 = actx->data_src1; + uint8_t * restrict data_dst = actx->data_dst; - const bool src1_valid = src1->ne[0]; - const int nc = (src1_valid) ? ne00 : ne00 / 2; - if (!src1_valid) { - const int32_t swapped = op_params[1]; - data_src1 = data_src0; - src1_row_size = src0_row_size; - - const size_t nc_in_bytes = nc * SIZEOF_FP32; - data_src0 += swapped ? nc_in_bytes : 0; - data_src1 += swapped ? 0 : nc_in_bytes; - } + const int nc = actx->nc; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t src1_row_size_aligned = actx->src1_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread); + uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t src1_spad_half_size = src1_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t src1_spad_half_size = actx->src1_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -678,33 +651,7 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_silu_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - -static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - -static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - -static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - static int execute_op_activations_f32(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; struct htp_tensor * dst = &octx->dst; @@ -719,26 +666,26 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { switch (octx->op) { case HTP_OP_UNARY_SILU: - act_op_func = unary_silu_f32; + act_op_func = (worker_callback_t)unary_silu_f32_per_thread; op_type = "silu-f32"; break; case HTP_OP_GLU_SWIGLU: - act_op_func = glu_swiglu_f32; + act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread; op_type = "swiglu-f32"; break; case HTP_OP_GLU_SWIGLU_OAI: - act_op_func = glu_swiglu_oai_f32; + act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread; op_type = "swiglu-oai-f32"; break; case HTP_OP_UNARY_GELU: - act_op_func = unary_gelu_f32; + act_op_func = (worker_callback_t)unary_gelu_f32_per_thread; op_type = "gelu-f32"; break; case HTP_OP_GLU_GEGLU: - act_op_func = glu_geglu_f32; + act_op_func = (worker_callback_t)glu_geglu_f32_per_thread; op_type = "geglu-f32"; break; default: @@ -797,13 +744,58 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); } - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs); + if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + return HTP_STATUS_OK; } - return err; + uint32_t n_jobs = MIN(n_threads, src0_nrows); + + // Prepare context + struct htp_act_context actx; + actx.octx = octx; + + actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + + actx.src0_row_size = src0_row_size; + actx.src1_row_size = src1_row_size; + actx.dst_row_size = dst_row_size; + + actx.src0_row_size_aligned = src0_row_size_aligned; + actx.src1_row_size_aligned = src1_row_size_aligned; + actx.dst_row_size_aligned = dst_row_size_aligned; + + actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2; + actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2; + actx.dst_spad_half_size = octx->dst_spad.size_per_thread / 2; + + actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned; + actx.src0_nrows = src0_nrows; + + actx.nc = dst->ne[0]; + + // Pointers and GLU logic + const uint8_t * data_src0 = (const uint8_t *) src0->data; + const uint8_t * data_src1 = (const uint8_t *) src1->data; + + if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) { + const int32_t swapped = octx->op_params[1]; + data_src1 = data_src0; + actx.src1_row_size = actx.src0_row_size; + + size_t nc_in_bytes = actx.nc * SIZEOF_FP32; + if (swapped) { + data_src0 += nc_in_bytes; + } else { + data_src1 += nc_in_bytes; + } + } + + actx.data_src0 = data_src0; + actx.data_src1 = data_src1; + actx.data_dst = (uint8_t *) dst->data; + + worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs); + return HTP_STATUS_OK; } int op_activations(struct htp_ops_context * octx) { diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index a657cd2dc..bf24bbda7 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -15,6 +15,13 @@ #include "htp-ops.h" #include "hvx-utils.h" +struct get_rows_context { + struct htp_ops_context * octx; + uint32_t src1_nrows_per_thread; + struct fastdiv_values get_rows_div_ne10; + struct fastdiv_values get_rows_div_ne10_ne11; +}; + #define get_rows_preamble \ const uint32_t ne00 = octx->src0.ne[0]; \ const uint32_t ne01 = octx->src0.ne[1]; \ @@ -39,20 +46,22 @@ \ const uint32_t nr = ne10 * ne11 * ne12; -static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { +static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { + struct get_rows_context * grctx = (struct get_rows_context *)data; + struct htp_ops_context * octx = grctx->octx; get_rows_preamble; // parallelize by src1 elements (which correspond to dst rows) - const uint32_t dr = octx->src1_nrows_per_thread; + const uint32_t dr = grctx->src1_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); for (uint32_t i = ir0; i < ir1; ++i) { - const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11); + const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11); const uint32_t rem = i - i12 * ne11 * ne10; - const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10); + const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10); const uint32_t i10 = rem - i11 * ne10; const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; @@ -68,12 +77,6 @@ static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3; hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); } - - return HTP_STATUS_OK; -} - -static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { - get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); } int op_get_rows(struct htp_ops_context * octx) { @@ -95,12 +98,14 @@ int op_get_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); - octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); + struct get_rows_context grctx; + grctx.octx = octx; + grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); + grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); const uint32_t n_jobs = MIN(nr, octx->n_threads); - octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index d1ddb0ecb..350ab9d96 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -102,7 +102,7 @@ static inline bool dma_queue_push(dma_queue * q, dmlink(q->tail, desc); q->tail = desc; - // FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src); + // FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src); q->push_idx = (q->push_idx + 1) & q->idx_mask; return true; } @@ -144,11 +144,37 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) { dptr = q->dptr[q->pop_idx]; - // FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst); + // FARF(ERROR, "dma-pop: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src); q->pop_idx = (q->pop_idx + 1) & q->idx_mask; return dptr; } +static inline dma_ptr dma_queue_pop_nowait(dma_queue * q) { + dma_ptr dptr = { NULL }; + + if (q->push_idx == q->pop_idx) { + return dptr; + } + + dptr = q->dptr[q->pop_idx]; + + // FARF(ERROR, "dma-pop-nowait: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src); + q->pop_idx = (q->pop_idx + 1) & q->idx_mask; + return dptr; +} + +static inline bool dma_queue_empty(dma_queue * q) { + return q->push_idx == q->pop_idx; +} + +static inline uint32_t dma_queue_depth(dma_queue * q) { + return (q->push_idx - q->pop_idx) & q->idx_mask; +} + +static inline uint32_t dma_queue_capacity(dma_queue * q) { + return q->capacity; +} + #ifdef __cplusplus } // extern "C" #endif diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 25403bb11..b1316f80d 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -35,7 +35,9 @@ enum htp_data_type { HTP_TYPE_F32 = 0, HTP_TYPE_F16 = 1, HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q4_1 = 3, HTP_TYPE_Q8_0 = 8, + HTP_TYPE_Q8_1 = 9, HTP_TYPE_I32 = 26, HTP_TYPE_I64 = 27, HTP_TYPE_MXFP4 = 39, @@ -79,8 +81,12 @@ static inline size_t htp_t_block_size(uint32_t t) { return 1; case HTP_TYPE_Q4_0: return QK4_0; + case HTP_TYPE_Q4_1: + return QK4_1; case HTP_TYPE_Q8_0: return QK8_0; + case HTP_TYPE_Q8_1: + return QK8_1; case HTP_TYPE_MXFP4: return QK_MXFP4; default: @@ -97,8 +103,12 @@ static inline size_t htp_type_nbytes(uint32_t t) { return 2; case HTP_TYPE_Q4_0: return sizeof(block_q4_0); + case HTP_TYPE_Q4_1: + return sizeof(block_q4_1); case HTP_TYPE_Q8_0: return sizeof(block_q8_0); + case HTP_TYPE_Q8_1: + return sizeof(block_q8_1); case HTP_TYPE_MXFP4: return sizeof(block_mxfp4); default: @@ -109,7 +119,9 @@ static inline size_t htp_type_nbytes(uint32_t t) { // Internal types #define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) +#define QK_Q4_1x4x2 256 // 4x Q4_1 blocks packed with next 4x Q4_1 blocks (size in bytes 128) #define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks +#define QK_Q8_1x4x2 256 // 4x Q8_1 blocks concat with next 4x Q8_1 blocks #define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks #define HTP_MAX_DIMS 4 diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index f1ad24dbf..127ab1d66 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -44,32 +44,6 @@ struct htp_ops_context { uint32_t src0_nrows_per_thread; uint32_t src1_nrows_per_thread; - struct fastdiv_values src0_div1; // fastdiv values for ne1 - struct fastdiv_values src0_div2; // fastdiv values for ne2 - struct fastdiv_values src0_div3; // fastdiv values for ne3 - struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1 - - struct fastdiv_values src1_div1; // fastdiv values for ne1 - struct fastdiv_values src1_div2; // fastdiv values for ne2 - struct fastdiv_values src1_div3; // fastdiv values for ne3 - struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1 - - struct fastdiv_values src3_div1; // fastdiv values for ne1 - struct fastdiv_values src3_div2; // fastdiv values for ne2 - struct fastdiv_values src3_div3; // fastdiv values for ne3 - struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1 - - struct fastdiv_values broadcast_rk2; - struct fastdiv_values broadcast_rk3; - struct fastdiv_values broadcast_rv2; - struct fastdiv_values broadcast_rv3; - - struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 - struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 - - struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 - struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 - uint32_t flags; }; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index c360abe8d..966e16da6 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -1283,6 +1283,515 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } + +static inline HVX_Vector_x8 hvx_vec_load_q4_1x4x8(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; + HVX_Vector v2_3 = vptr[1]; + HVX_Vector v4_5 = vptr[2]; + HVX_Vector v6_7 = vptr[3]; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} +// q8_1x4x2 is similar to q8_0x4x2 but adds precomputed sum for q4_1 dot product optimization +static inline void quantize_block_f32_q8_1x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d, uint8_t * restrict y_s) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_q % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + + // Load and convert into QF32 + HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + + // Convert into fp16 + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + // Compute max and scale + HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); + vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); + + // Replicate first fp16 scale across all lanes + HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16; + vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl); + + HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); + + *(HVX_UVector *) y_d = vd_hf; + + // Divide input by the scale + HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf)); + + // Convert to int8 + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + *(HVX_Vector *) y_q = vx_i8; + + // Compute sum of quants + HVX_Vector vsum_i16 = hvx_vec_reduce_sum_i8(vx_i8); // 32x int16 sums (one per 32-elem block, replicated 4x) + HVX_Vector vsum_hf = Q6_Vhf_equals_Vh(vsum_i16); + + // Scale the sum by d (s = d * sum(qs)) + HVX_Vector vs_qf16 = Q6_Vqf16_vmpy_VhfVhf(vd_hf, vsum_hf); + HVX_Vector vs_hf = Q6_Vhf_equals_Vqf16(vs_qf16); + + *(HVX_UVector *) y_s = vs_hf; +} +static inline size_t q8_1x4x2_row_size(uint32_t ne) { + const uint32_t qk = QK_Q8_1x4x2; + const uint32_t nb = (ne + qk - 1) / qk; + // quants + d + s + return hex_round_up(ne + nb * 8 * sizeof(__fp16) + nb * 8 * sizeof(__fp16), 128); +} + +// Overrides input x +static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_1x4x2; + const uint32_t nb = (k + qk - 1) / qk; + + const uint32_t qrow_size = k; // int8 + + const uint32_t qblk_size = QK_Q8_1x4x2; // int8 + + const uint32_t drow_size = nb * 8 * 2; + + uint8_t * restrict y_q = (y + 0); // quants first + uint8_t * restrict y_d = (y + qrow_size); // then scales + uint8_t * restrict y_s = (y + qrow_size + drow_size);// then sums + + // We use stack buffers for temporary scale/sum storage to avoid corrupting input x + // which shares the same VTCM buffer in the caller. + HVX_Vector tmp_d, tmp_s; + uint8_t * t_d = (uint8_t *)&tmp_d; + uint8_t * t_s = (uint8_t *)&tmp_s; + + for (uint32_t i = 0; i < nb; i++) { + // Block 0 (first 128 elements) + quantize_block_f32_q8_1x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d, t_s); + // Copy 4 fp16 scales (8 bytes) and 4 fp16 sums (8 bytes) to destination + // We can't use hvx_copy here efficiently for small sizes, so just memcpy or cast + *((uint64_t *)(y_d + (i*8 + 0)*2)) = *((uint64_t *)t_d); + *((uint64_t *)(y_s + (i*8 + 0)*2)) = *((uint64_t *)t_s); + + // Block 1 (next 128 elements) + quantize_block_f32_q8_1x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d, t_s); + *((uint64_t *)(y_d + (i*8 + 4)*2)) = *((uint64_t *)t_d); + *((uint64_t *)(y_s + (i*8 + 4)*2)) = *((uint64_t *)t_s); + } + +// Overrides input x +} + +static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = src->nb[1]; + const size_t dst_row_size = q8_1x4x2_row_size(ne0); + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); + uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); + + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_1x4x2 * sizeof(float)); + memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding + + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_1x4x2((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static void vec_dot_q4_1x4x2_q8_1x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_1x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + const uint32_t x_drow_size = (n / 32) * 2; // fp16 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_drow_size = (n / 32) * 2; // fp16 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + const uint8_t * restrict r0_x_m = r0_x_d + x_drow_size; // then mins + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + const uint8_t * restrict y_s = y_d + y_drow_size; // then sums + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8(r0_x_q + i * x_qblk_size); + + // sum(q * y_q) + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_s + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_dblk_size)); + + // scales: d * dy + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + // mins: m * sy (sy = dy * sum(y_q)) + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_ms, r0_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_s + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_ms, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); +} + +static void vec_dot_q4_1x4x2_q8_1x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_1x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; + const uint32_t x_qblk_size = qk / 2; + const uint32_t x_qrow_size = n / 2; + const uint32_t x_drow_size = (n / 32) * 2; + + const uint32_t y_dblk_size = 8 * 4 * 2; + const uint32_t y_qblk_size = qk; + const uint32_t y_qrow_size = n; + const uint32_t y_drow_size = (n / 32) * 2; + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; + const uint8_t * restrict r0_x_m = r0_x_d + x_drow_size; + + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; + const uint8_t * restrict r1_x_m = r1_x_d + x_drow_size; + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); + const uint8_t * restrict y_s = y_d + y_drow_size; + + HVX_Vector r0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_s + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_m + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_ms, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_ms, r1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_s + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_m + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); r0_ms = Q6_V_vand_QV(bmask, r0_ms); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); r1_ms = Q6_V_vand_QV(bmask, r1_ms); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_ms, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_ms, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_q4_1x4x2_q8_1x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_1x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; + const uint32_t x_qblk_size = qk / 2; + const uint32_t x_qrow_size = n / 2; + const uint32_t x_drow_size = (n / 32) * 2; + + const uint32_t y_dblk_size = 8 * 4 * 2; + const uint32_t y_qblk_size = qk; + const uint32_t y_qrow_size = n; + const uint32_t y_drow_size = (n / 32) * 2; + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; + const uint8_t * restrict r0_x_m = r0_x_d + x_drow_size; + + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; + const uint8_t * restrict r1_x_m = r1_x_d + x_drow_size; + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; + const uint8_t * restrict y0_s = y0_d + y_drow_size; + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; + const uint8_t * restrict y1_s = y1_d + y_drow_size; + + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_s + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_s + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_m + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); + + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); + + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); + + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); + + // Accumulate + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd), r0_c0_sum)); + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_ms, r0_c0_sum)); + + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd), r0_c1_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_ms, r0_c1_sum)); + + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd), r1_c0_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_ms, r1_c0_sum)); + + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd), r1_c1_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_ms, r1_c1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_s + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_s + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_m + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); + + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); + + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); + + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); r0_c0_ms = Q6_V_vand_QV(bmask, r0_c0_ms); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); r0_c1_ms = Q6_V_vand_QV(bmask, r0_c1_ms); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); r1_c0_ms = Q6_V_vand_QV(bmask, r1_c0_ms); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); r1_c1_ms = Q6_V_vand_QV(bmask, r1_c1_ms); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_ms, r0_c0_sum)); + + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_ms, r0_c1_sum)); + + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_ms, r1_c0_sum)); + + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_ms, r1_c1_sum)); + } + + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -2382,6 +2891,12 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; return 0; + case HTP_TYPE_Q4_1: + mmctx->type = "q4_1x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8_1x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8_1x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8_1x4x2_2x2; + return 0; case HTP_TYPE_Q8_0: mmctx->type = "q8x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; @@ -2415,7 +2930,9 @@ static void htp_mminit_spad(struct htp_ops_context * octx, } // src0 spad is also used in dynamic quantizer to store padded src1 rows - size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + // Q8_1 row size is larger than Q8_0 so we use Q8_1 constant to be safe + // We pad to 1024 bytes (4 * 256) to allow safe 4x unrolled vector loading in dot kernels + size_t src1_row_size_padded = hex_round_up(src1_row_size + 1024, QK_Q8_1x4x2 * sizeof(float)); if (octx->src0_spad.size_per_thread < src1_row_size_padded) { octx->src0_spad.size_per_thread = src1_row_size_padded; } @@ -2477,7 +2994,9 @@ int op_matmul(struct htp_ops_context * octx) { octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + // We add 1024 bytes of padding at the end of src1_spad to allow safe speculative + // loading of 4x vectors (1024 bytes) in the dot product kernels, even for the last block. + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows + 1024, 256); octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; @@ -2518,8 +3037,13 @@ int op_matmul(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = q8_1x4x2_row_size(ne10); + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); } diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 943ca5c95..aa6a6c900 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -10,6 +10,7 @@ #include "hex-dma.h" #include "hvx-utils.h" +#include "hex-fastdiv.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -21,6 +22,9 @@ #define HTP_ROPE_TYPE_NORMAL 0 #define HTP_ROPE_TYPE_NEOX 2 +#define HTP_ROPE_SPAD_NROWS 16 +#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2) + #define htp_rope_preamble \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ @@ -42,7 +46,7 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -struct rope_th_ctx { +struct htp_rope_context { int32_t n_dims; int32_t mode; int32_t n_ctx_orig; @@ -57,7 +61,19 @@ struct rope_th_ctx { float theta_scale; float corr_dims[2]; + uint32_t src0_nrows_per_thread; + size_t spad_stride; + struct htp_ops_context * octx; + + size_t src0_row_size; + size_t dst_row_size; + size_t src0_row_size_aligned; + size_t dst_row_size_aligned; + size_t theta_cache_offset; + uint32_t src0_nrows; + + uint64_t t_start; }; static float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -117,64 +133,23 @@ static void rope_corr_dims(int n_dims, dims[1] = MIN(n_dims - 1, end); } -static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) { - memset(rope_ctx, 0, sizeof(struct rope_th_ctx)); - - const int32_t * op_params = &octx->op_params[0]; - - rope_ctx->n_dims = ((const int32_t *) op_params)[1]; - rope_ctx->mode = ((const int32_t *) op_params)[2]; - rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4]; - - memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float)); - memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float)); - memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float)); - memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float)); - memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float)); - memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float)); - memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4); - - rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims); - - rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast, - rope_ctx->beta_slow, rope_ctx->corr_dims); - - rope_ctx->octx = octx; - FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims, - rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor); -} +static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { + const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; + const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; + HVX_Vector * restrict vdst = (HVX_Vector *) dst; -static void hvx_calc_rope_neox_f32(const float * restrict src0, - float * restrict dst, - const int num_elems, - const float * restrict theta_cache) { - // for (int i = 0; i < num_elems; i += 2) { - //const float cos_theta = theta_cache[i + 0]; - //const float sin_theta = theta_cache[i + 1]; + uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2 - //const float x0 = src[0]; - //const float x1 = src[num_elems/2]; + uint32_t he = ne / 2; // half_dims offset in elements + uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors - //dst[0] = x0*cos_theta - x1*sin_theta; - //dst[num_elems/2] = x0*sin_theta + x1*cos_theta; + #pragma unroll(2) + for (uint32_t i = 0; i < nvec; i += 2) { + HVX_Vector v0 = vsrc[i/2+0]; + HVX_Vector v1 = vsrc[i/2+hv]; - //src += 1; - //dst += 1; - // } - - const uint8_t * restrict src0_curr = (const uint8_t *) src0; - const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; - uint8_t * restrict dst_curr = (uint8_t *) dst; - - int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once - int half_size = (sizeof(float) * (num_elems / 2)); - - for (int i = 0; i < step_of_1; i++) { - HVX_Vector v0 = *(HVX_Vector *) src0_curr; - HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size); - - HVX_Vector v2 = *(HVX_Vector *) theta_curr; - HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); + HVX_Vector v2 = vtheta[i+0]; + HVX_Vector v3 = vtheta[i+1]; HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta @@ -186,45 +161,34 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0, HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4); - *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5); + vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4); + vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5); + } - src0_curr += VLEN; - theta_curr += 2 * VLEN; - dst_curr += VLEN; + for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { + const float cos_theta = theta_cache[i+0]; + const float sin_theta = theta_cache[i+1]; + float x0 = src0[i/2]; + float x1 = src0[i/2 + he]; + dst[i/2] = x0 * cos_theta - x1 * sin_theta; + dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta; } } -static void hvx_calc_rope_f32(const float * restrict src0, - float * restrict dst, - const int num_elems, - const float * restrict theta_cache) { - // for (int i = 0; i < num_elems; i += 2) { - //const float cos_theta = theta_cache[i + 0]; - //const float sin_theta = theta_cache[i + 1]; - - //const float x0 = src[0]; - //const float x1 = src[1]; +static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { + const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; + const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; + HVX_Vector * restrict vdst = (HVX_Vector *) dst; - //dst[0] = x0*cos_theta - x1*sin_theta; - //dst[1] = x0*sin_theta + x1*cos_theta; + uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two - //src += 2; - //dst += 2; - // } + #pragma unroll(2) + for (uint32_t i = 0; i < nvec; i+=2) { + HVX_Vector v0 = vsrc[i+0]; + HVX_Vector v1 = vsrc[i+1]; - const uint8_t * restrict src0_curr = (const uint8_t *) src0; - const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; - uint8_t * restrict dst_curr = (uint8_t *) dst; - - int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once - - for (int i = 0; i < step_of_1; i++) { - HVX_Vector v0 = *(HVX_Vector *) src0_curr; - HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v2 = *(HVX_Vector *) theta_curr; - HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); + HVX_Vector v2 = vtheta[i+0]; + HVX_Vector v3 = vtheta[i+1]; HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1 HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta @@ -239,151 +203,182 @@ static void hvx_calc_rope_f32(const float * restrict src0, HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); - *(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore); - *(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore); + vdst[i+0] = Q6_V_lo_W(vstore); + vdst[i+1] = Q6_V_hi_W(vstore); + } + + for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { + const float cos_theta = theta_cache[i+0]; + const float sin_theta = theta_cache[i+1]; + float x0 = src0[i+0]; + float x1 = src0[i+1]; + dst[i+0] = x0 * cos_theta - x1 * sin_theta; + dst[i+1] = x0 * sin_theta + x1 * cos_theta; + } +} + +static void inline rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src, + uint32_t nr, uint32_t ne0, const float * restrict theta_cache) { + #pragma unroll(4) + for (uint32_t i = 0; i < nr; i++) { + float * d = (float *) (dst + i * rctx->dst_row_size_aligned); + float * s = (float *) (src + i * rctx->src0_row_size_aligned); + + hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache); + + // fill the remain channels with data from src tensor + if (rctx->n_dims < ne0) { + hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims); + } + } +} + +static void inline rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src, + uint32_t nr, uint32_t ne0, const float * restrict theta_cache) { + #pragma unroll(4) + for (uint32_t i = 0; i < nr; i++) { + float * d = (float *) (dst + i * rctx->dst_row_size_aligned); + float * s = (float *) (src + i * rctx->src0_row_size_aligned); - src0_curr += 2 * VLEN; - theta_curr += 2 * VLEN; - dst_curr += 2 * VLEN; + hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache); + + // fill the remain channels with data from src tensor + if (rctx->n_dims < ne0) { + hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims); + } } } -static void rope_hex_f32(struct rope_th_ctx * rope_ctx, - const uint32_t ir0, - const uint32_t ir1, - int nth, - int ith, - const int opt_path) { - struct htp_ops_context * octx = rope_ctx->octx; +static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_rope_context * rctx = (struct htp_rope_context *) data; + struct htp_ops_context * octx = rctx->octx; const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; const struct htp_tensor * src2 = &octx->src2; struct htp_tensor * dst = &octx->dst; - const int32_t mode = rope_ctx->mode; - const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; - htp_rope_preamble; - const int32_t * pos = (const int32_t *) src1->data; + const uint32_t src0_nrows = rctx->src0_nrows; + const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread; - float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01)); + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - const float * freq_factors = NULL; - if (src2 != NULL) { - freq_factors = (const float *) src2->data; + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; } - const uint32_t i1_end = MIN(ir1, ne1); - const int32_t half_dims = rope_ctx->n_dims / 2; - const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float); - for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch - for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len - const int32_t p = pos[i2]; + uint64_t tt = HAP_perf_get_qtimer_count(); - rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor, - rope_ctx->attn_factor, wp0, rope_ctx->theta_scale); + const int32_t mode = rctx->mode; + const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; - for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads - const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01); - float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1); + // VTCM setup + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + float * theta_cache = (float *) (src0_spad_base); + src0_spad_base = src0_spad_base + rctx->theta_cache_offset; + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - const float * src_loc = src; - float * dst_data_loc = dst_data; + dma_queue * dma_queue = octx->ctx->dma[ith]; + const int32_t * pos = (const int32_t *) src1->data; + const float * freq_factors = src2->data ? (const float *) src2->data : NULL; - if (1 == opt_path) { - if (is_neox) { - hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); - } else { - hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); - } + uint32_t ir = 0; + uint32_t prev_i2 = (uint32_t) -1; - src_loc += rope_ctx->n_dims; - dst_data_loc += rope_ctx->n_dims; - } else { - for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) { - const float cos_theta = wp0[i0 + 0]; - const float sin_theta = wp0[i0 + 1]; + for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch + for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len + for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads + if (ir < src0_start_row) { ir++; i1++; continue; } + if (ir >= src0_end_row) goto done; - if (is_neox) { - const float x0 = src_loc[0]; - const float x1 = src_loc[half_dims]; + // Rows in this block + const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1); - dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; - dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta; + // Depth before prefetch + uint32_t dma_depth = dma_queue_depth(dma_queue); - src_loc += 1; - dst_data_loc += 1; - } else { - const float x0 = src_loc[0]; - const float x1 = src_loc[1]; + // FARF(HIGH, "rope-block %u: ir %u n-rows %u dma-depth %u : usec %u", ith, ir, nrows, dma_depth, + // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); - dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; - dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta; + // Prefetch loop + for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) { + pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK); - src_loc += 2; - dst_data_loc += 2; - } - } + uint32_t pi1 = i1 + pr; + uint32_t pir = ir + pr; + + // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0); - src_loc += (is_neox ? half_dims : 0); - dst_data_loc += (is_neox ? half_dims : 0); + const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01; + uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr), + rctx->src0_row_size_aligned, rctx->src0_row_size, pnr); + + // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr); } - // TODO: use simd to speed up the remaining elements copy - memcpy(dst_data_loc, src_loc, remain_bytes); - } - } - } -} + // Update theta cache + if (i2 != prev_i2) { + prev_i2 = i2; -static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) { - struct htp_ops_context * octx = rope_ctx->octx; + const int32_t p = pos[i2]; + rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale); - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache, + // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); + } - htp_rope_preamble; + // Skip DMA transactions from prev block (if any) + // No need to wait for these since the DMA is setup for in-order processing + for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); } - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + // Compute loop + for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) { + // Number of rows to compute + cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK); - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src; + uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst; - // no work for this thread - if (src0_start_row >= src0_end_row) { - return; - } + // FARF(HIGH, "rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u", ith, ir, i1, i2, i3, src_spad, cnr, + // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + if (is_neox) { + rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache); + } else { + rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache); + } - int is_aligned = 1; - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) || - (0 == hex_is_aligned((void *) dst->data, VLEN))) { - FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n"); - is_aligned = 0; - } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; - } + uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1; + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), rctx->dst_row_size, rctx->dst_row_size_aligned, cnr); - rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path); + // Prefetch more rows (if any) + if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) { + uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK); + uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS; + uint32_t pir = ir + HTP_ROPE_SPAD_NROWS; - t2 = HAP_perf_get_qtimer_count(); + const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr), + rctx->src0_row_size_aligned, rctx->src0_row_size, pnr); - FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row, - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); -} + // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr); + } + } + } + } + } -static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data; +done: + dma_queue_flush(dma_queue); + tt = HAP_perf_get_qtimer_count() - tt; - rope_job_f32_per_thread(rope_ctx, n, i); + FARF(HIGH, "rope-f32: %d/%d: (%u:%u) usec %u\n", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt)); } static int execute_op_rope_f32(struct htp_ops_context * octx) { @@ -394,17 +389,10 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { const struct htp_tensor * src2 = &octx->src2; struct htp_tensor * dst = &octx->dst; - worker_callback_t op_func; - const char * op_type = NULL; - - struct rope_th_ctx rope_ctx; + const char * op_type = "rope-f32"; switch (octx->op) { case HTP_OP_ROPE: - op_func = rope_job_dispatcher_f32; - op_type = "rope-f32"; - - init_rope_ctx(&rope_ctx, octx); break; default: @@ -415,49 +403,79 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { const uint32_t n_threads = octx->n_threads; const size_t src0_row_size = src0->nb[1]; - const size_t src1_row_size = src0_row_size; const size_t dst_row_size = dst->nb[1]; - // VTCM scratchpads for all tensors - // N rows per thread, padded to HVX vector size - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; - - size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; - - if (src2->ne[0]) { - FARF(HIGH, - "%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u " - "dst-spad-size %u\n", - op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], - src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], - dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); - } else { - FARF(HIGH, - "%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", - op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], - src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, - octx->dst_spad.size); - } + // Aligned row sizes for VTCM + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128); + + // Calculate spad sizes per thread + size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned; + size_t dst_spad_per_thread = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned; + size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread; - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); + // Check if we fit in VTCM + size_t total_vtcm_needed = spad_per_thread * n_threads; + if (octx->ctx->vtcm_size < total_vtcm_needed) { + FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, total_vtcm_needed); return HTP_STATUS_VTCM_TOO_SMALL; } + // Assign sizes + octx->src0_spad.size_per_thread = src0_spad_per_thread; + octx->dst_spad.size_per_thread = dst_spad_per_thread; + octx->src0_spad.size = n_threads * src0_spad_per_thread; + octx->dst_spad.size = n_threads * dst_spad_per_thread; + octx->src1_spad.size = 0; + + // Assign pointers octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src1_spad.data = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + // Fill context + struct htp_rope_context rctx; + memset(&rctx, 0, sizeof(struct htp_rope_context)); + + rctx.t_start = HAP_perf_get_qtimer_count(); + + rctx.octx = octx; + + const int32_t * op_params = &octx->op_params[0]; + rctx.n_dims = ((const int32_t *) op_params)[1]; + rctx.mode = ((const int32_t *) op_params)[2]; + rctx.n_ctx_orig = ((const int32_t *) op_params)[4]; + memcpy(&rctx.freq_base, (int32_t *) op_params + 5, sizeof(float)); + memcpy(&rctx.freq_scale, (int32_t *) op_params + 6, sizeof(float)); + memcpy(&rctx.ext_factor, (int32_t *) op_params + 7, sizeof(float)); + memcpy(&rctx.attn_factor, (int32_t *) op_params + 8, sizeof(float)); + memcpy(&rctx.beta_fast, (int32_t *) op_params + 9, sizeof(float)); + memcpy(&rctx.beta_slow, (int32_t *) op_params + 10, sizeof(float)); + memcpy(&rctx.sections, (int32_t *) op_params + 11, sizeof(int) * 4); + + rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims); + + rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims); + + rctx.src0_row_size = src0_row_size; + rctx.dst_row_size = dst_row_size; + rctx.src0_row_size_aligned = src0_row_size_aligned; + rctx.dst_row_size_aligned = dst_row_size_aligned; + rctx.theta_cache_offset = theta_cache_size_aligned; + + uint32_t ne0 = dst->ne[0]; uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + rctx.src0_nrows = src0_nrows; + + FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0, + rctx.ext_factor, rctx.theta_scale, rctx.attn_factor); if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs); + uint32_t n_jobs = MIN(n_threads, src0_nrows); + rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs); } return err; diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 904484da9..2fd6c9077 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -43,11 +43,21 @@ \ const uint32_t nr = ne01; -static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { +struct htp_set_rows_context { + struct htp_ops_context * octx; + struct fastdiv_values div_ne12; + struct fastdiv_values div_ne11; + uint32_t src0_nrows_per_thread; +}; + +static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { + struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data; + struct htp_ops_context * octx = srctx->octx; + set_rows_preamble; // parallelize by rows of src0 - const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; @@ -56,8 +66,8 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { for (uint32_t i = ir0; i < ir1; ++i) { - const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); - const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; @@ -76,15 +86,16 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, } } } - - return HTP_STATUS_OK; } -static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) { +static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) { + struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data; + struct htp_ops_context * octx = srctx->octx; + set_rows_preamble; // parallelize by rows of src0 - const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; @@ -93,8 +104,8 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { for (uint32_t i = ir0; i < ir1; ++i) { - const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); - const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; @@ -112,16 +123,6 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, } } } - - return HTP_STATUS_OK; -} - -static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) { - set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i); -} - -static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { - set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); } int op_set_rows(struct htp_ops_context * octx) { @@ -143,18 +144,20 @@ int op_set_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - octx->set_rows_div_ne12 = init_fastdiv_values(ne12); - octx->set_rows_div_ne11 = init_fastdiv_values(ne11); + struct htp_set_rows_context srctx; + srctx.octx = octx; + srctx.div_ne12 = init_fastdiv_values(ne12); + srctx.div_ne11 = init_fastdiv_values(ne11); const uint32_t n_jobs = MIN(nr, octx->n_threads); - octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; switch(octx->dst.type) { case HTP_TYPE_F32: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs); break; case HTP_TYPE_F16: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs); break; default: return HTP_STATUS_NO_SUPPORT; diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index e91a16d94..58278079f 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -10,6 +10,7 @@ #include "hex-dma.h" #include "hvx-utils.h" +#include "hex-fastdiv.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -48,7 +49,7 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -struct softmax_th_ctx { +struct htp_softmax_context { bool use_f16; bool use_src1; uint32_t n_head; @@ -59,28 +60,48 @@ struct softmax_th_ctx { float m0; float m1; + uint32_t src0_nrows_per_thread; + struct fastdiv_values fastdiv_ne01; + struct fastdiv_values fastdiv_ne02; + struct fastdiv_values fastdiv_ne12; // For mask broadcasting + struct fastdiv_values fastdiv_ne13; // For mask broadcasting + size_t spad_stride; + struct htp_ops_context * octx; }; -static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) { +static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) { const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; - memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx)); + memset(smctx, 0, sizeof(struct htp_softmax_context)); + + memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float)); + memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); + + smctx->n_head = src0->ne[2]; + smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head)); + + smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2); + smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2); - memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float)); - memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); + smctx->use_src1 = (src1->ne[0] != 0); + smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16); - softmax_ctx->n_head = src0->ne[2]; - softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head)); + smctx->octx = octx; - softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2); - softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2); + // Initialize fastdiv values + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; - softmax_ctx->use_src1 = (src1->ne[0] != 0); - softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16); + if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01); + if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02); - softmax_ctx->octx = octx; + const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; + const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; + + if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12); + if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13); } static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src, @@ -183,83 +204,9 @@ static float hvx_softmax_f32(const uint8_t * restrict src, return sum; } -static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) { - struct htp_ops_context * octx = softmax_ctx->octx; - - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * dst = &octx->dst; - - htp_softmax_preamble3; - - uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01); - uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01); - uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1); - - float * wp0 = (float *) src0_spad_data; - float * wp1 = (float *) src1_spad_data; - float * wp2 = (float *) dst_spad_data; - - for (uint32_t i03 = 0; i03 < ne03; i03++) { - for (uint32_t i02 = 0; i02 < ne02; i02++) { - for (uint32_t i01 = ith; i01 < ne01; i01 += nth) { - const uint32_t i11 = i01; - const uint32_t i12 = i02 % ne12; - const uint32_t i13 = i03 % ne13; - - // ALiBi - const uint32_t h = i02; // head - - const float slope = (softmax_ctx->max_bias > 0.0f) ? - h < softmax_ctx->n_head_log2 ? - powf(softmax_ctx->m0, h + 1) : - powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) : - 1.0f; - - float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03); - float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3); - - // broadcast the mask across rows - __fp16 * mp_f16 = (softmax_ctx->use_src1) ? - (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; - float * mp_f32 = (softmax_ctx->use_src1) ? - (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; - - if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) { - hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale, - (const uint8_t *) mp_f32, slope); - } else { - hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale); - if (mp_f32) { - if (softmax_ctx->use_f16) { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * (float) mp_f16[i]; - } - } else { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * mp_f32[i]; - } - } - } - } - - if (1 == opt_path) { - hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); - } else { - float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00); - float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); - sum = sum > 0.0 ? (1.0 / sum) : 1; - hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); - } - } - } - } -} - -static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) { - struct htp_ops_context * octx = softmax_ctx->octx; +static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_softmax_context * smctx = (struct htp_softmax_context *) data; + struct htp_ops_context * octx = smctx->octx; const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; @@ -268,7 +215,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int htp_softmax_preamble3; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -291,20 +238,103 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int opt_path = 1; } - softmax_htp_f32(nth, ith, softmax_ctx, opt_path); + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride); + + float * wp0 = (float *) src0_spad_data; + float * wp1 = (float *) src1_spad_data; + float * wp2 = (float *) dst_spad_data; + + uint32_t prev_i2 = (uint32_t)-1; + float slope = 1.0f; + + for (uint32_t r = src0_start_row; r < src0_end_row; ++r) { + uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01); + uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01); + uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02); + uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02); + + // Map to original logic indices + // i01 = i1 + // i02 = i2 + // i03 = i3 + + const uint32_t i11 = i1; + // const uint32_t i12 = i2 % ne12; + // const uint32_t i13 = i3 % ne13; + + uint32_t i12, i13; + if (ne12 == ne02) { + i12 = i2; + } else { + i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12); + } + + if (ne13 == ne03) { + i13 = i3; + } else { + i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13); + } + + // ALiBi + if (i2 != prev_i2) { + const uint32_t h = i2; // head + + slope = (smctx->max_bias > 0.0f) ? + h < smctx->n_head_log2 ? + powf(smctx->m0, h + 1) : + powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) : + 1.0f; + prev_i2 = i2; + } + + float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03); + float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3); + + // broadcast the mask across rows + __fp16 * mp_f16 = (smctx->use_src1) ? + (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : + NULL; + float * mp_f32 = (smctx->use_src1) ? + (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : + NULL; + + if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) { + hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale, + (const uint8_t *) mp_f32, slope); + } else { + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale); + if (mp_f32) { + if (smctx->use_f16) { + for (int i = 0; i < ne00; ++i) { + wp0[i] += slope * (float) mp_f16[i]; + } + } else { + for (int i = 0; i < ne00; ++i) { + wp0[i] += slope * mp_f32[i]; + } + } + } + } + + if (1 == opt_path) { + hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); + } else { + float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00); + float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); + sum = sum > 0.0 ? (1.0 / sum) : 1; + hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); + } + } t2 = HAP_perf_get_qtimer_count(); FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, + smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) { - struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data; - softmax_job_f32_per_thread(p_softmax_ctx, n, i); -} - static int execute_op_softmax_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; @@ -312,17 +342,12 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { const struct htp_tensor * src1 = &octx->src1; struct htp_tensor * dst = &octx->dst; - worker_callback_t op_func; - const char * op_type = NULL; - - struct softmax_th_ctx softmax_ctx; + struct htp_softmax_context smctx; + const char * op_type = "softmax-f32"; switch (octx->op) { case HTP_OP_SOFTMAX: - op_func = softmax_job_dispatcher_f32; - op_type = "softmax-f32"; - - init_softmax_ctx(&softmax_ctx, octx); + init_softmax_ctx(&smctx, octx); break; default: @@ -342,6 +367,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + // Use stride for calculating offset + smctx.spad_stride = hex_round_up(src0_row_size, 128); + size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; if (src1->ne[0]) { @@ -371,8 +399,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs); + smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs); } return err; diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c index 62e45da2b..04fa72182 100644 --- a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -17,7 +17,6 @@ #include "htp-msg.h" #include "htp-ops.h" - #define sum_rows_preamble \ struct htp_tensor *src0 = &octx->src0;\ struct htp_tensor *dst = &octx->dst; \ @@ -42,53 +41,54 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; \ -static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) { - sum_rows_preamble; +struct sum_rows_context { + const uint8_t * src_data; + uint8_t * dst_data; + uint32_t ne00; + size_t src_stride; + size_t dst_stride; + uint32_t rows_per_thread; + uint32_t total_rows; + bool opt_path; +}; - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; +static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) { + const struct sum_rows_context * smctx = (const struct sum_rows_context *) data; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t rows_per_thread = smctx->rows_per_thread; + const uint32_t total_rows = smctx->total_rows; - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t start_row = rows_per_thread * ith; + const uint32_t end_row = MIN(start_row + rows_per_thread, total_rows); - // no work for this thread - if (src0_start_row >= src0_end_row) { - return HTP_STATUS_OK; + if (start_row >= end_row) { + return; } - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { - opt_path = 1; - } + const size_t src_stride = smctx->src_stride; + const size_t dst_stride = smctx->dst_stride; + const uint32_t ne00 = smctx->ne00; + const bool opt_path = smctx->opt_path; - const uint8_t * restrict data_src = (const uint8_t *) src0->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + const float * restrict src_th = (const float *) (smctx->src_data + (start_row * src_stride)); + float * restrict dst_th = (float *) (smctx->dst_data + (start_row * dst_stride)); - const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); - float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size)); + // Calculate actual number of rows for this thread + const uint32_t n_rows = end_row - start_row; - for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) { - const float * restrict src_local = src_th + (ir * ne00); + for (uint32_t ir = 0; ir < n_rows; ir++) { + const float * restrict src_local = src_th + (ir * (src_stride / sizeof(float))); - if (ir + 1 < src0_nrows_per_thread) { - hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1); + if (ir + 1 < n_rows) { + hex_l2fetch(src_local + (src_stride / sizeof(float)), src_stride, src_stride, 1); } - if (1 == opt_path) { + if (opt_path) { dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00); } else { dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00); } } - - return HTP_STATUS_OK; -} - -static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) { - sum_rows_thread_f32((struct htp_ops_context *) data, n, i); } int op_sum_rows(struct htp_ops_context * octx) { @@ -106,10 +106,25 @@ int op_sum_rows(struct htp_ops_context * octx) { const uint32_t src0_nrows = ne01 * ne02 * ne03; uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs); + bool opt_path = false; + if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { + opt_path = true; + } + + struct sum_rows_context smctx = { + .src_data = (const uint8_t *) src0->data, + .dst_data = (uint8_t *) dst->data, + .ne00 = ne00, + .src_stride = nb01, + .dst_stride = nb1, + .rows_per_thread = rows_per_thread, + .total_rows = src0_nrows, + .opt_path = opt_path, + }; + + worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs); return HTP_STATUS_OK; } - diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index ce879bf03..8cc1f582f 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -17,6 +17,28 @@ #include "htp-msg.h" #include "htp-ops.h" +struct htp_unary_context { + struct htp_ops_context * octx; + + // Precomputed values + const uint8_t * data_src0; + uint8_t * data_dst; + + size_t src0_row_size; + size_t dst_row_size; + + size_t src0_row_size_aligned; + size_t dst_row_size_aligned; + + size_t src0_spad_half_size; + size_t dst_spad_half_size; + + uint32_t block; + uint32_t src0_nrows; + uint32_t src0_nrows_per_thread; + uint32_t nc; +}; + #define htp_unary_preamble \ const uint32_t ne00 = src->ne[0]; \ const uint32_t ne01 = src->ne[1]; \ @@ -75,128 +97,95 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } -static void scale_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void scale_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { float scale = 0.f; float bias = 0.f; memcpy(&scale, &op_params[0], sizeof(float)); memcpy(&bias, &op_params[1], sizeof(float)); for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); - - if (ir + 1 < num_rows) { - hex_l2fetch(src_local + row_elems, row_size, row_size, 1); - } + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); + hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); } } -static void rms_norm_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void rms_norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { float epsilon = 0.f; memcpy(&epsilon, op_params, sizeof(float)); for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - if (ir + 1 < num_rows) { - hex_l2fetch(src_local + row_elems, row_size, row_size, 1); - } - - if (1 == opt_path) { - hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); - } else { - float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems); - - const float mean = sum / row_elems; - const float scale = 1.0f / sqrtf(mean + epsilon); - - hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale); - } + hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); } } -static void sqr_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void sqr_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); - - if (ir + 1 < num_rows) { - hex_l2fetch(src_local + row_elems, row_size, row_size, 1); - } + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - if (1 == opt_path) { - hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); - } else { - hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); - } + hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); } } -static void sqrt_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void sqrt_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - if (ir + 1 < num_rows) { - hex_l2fetch(src_local + row_elems, row_size, row_size, 1); - } - - if (1 == opt_path) { - hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); - } else { - hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); - } + hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); } } -static void unary_job_f32_per_thread(const struct htp_tensor * src, - struct htp_tensor * dst, - uint8_t * spad, - int htp_op, - int32_t * op_params, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread) { +static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; + struct htp_ops_context * octx = uctx->octx; + const struct htp_tensor * src = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + htp_unary_preamble; - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; + int htp_op = octx->op; + int32_t * op_params = octx->op_params; + uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const size_t src0_row_size = uctx->src0_row_size; + const size_t dst_row_size = uctx->dst_row_size; + const size_t src0_row_size_aligned = uctx->src0_row_size_aligned; + const size_t dst_row_size_aligned = uctx->dst_row_size_aligned; + + const uint32_t src0_nrows = uctx->src0_nrows; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -208,79 +197,104 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - int is_aligned = 1; - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) { - is_aligned = 0; - } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; + const uint8_t * restrict data_src = uctx->data_src0; + uint8_t * restrict data_dst = uctx->data_dst; + + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + + size_t src0_spad_half_size = uctx->src0_spad_half_size; + size_t dst_spad_half_size = uctx->dst_spad_half_size; + + const int BLOCK = uctx->block; + if (BLOCK == 0) { + FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + octx->src0_spad.size_per_thread, src0_row_size_aligned); + return; } - const uint8_t * restrict data_src = (const uint8_t *) src->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + dma_queue * dma_queue = octx->ctx->dma[ith]; - const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); - float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size)); - uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01); + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); - switch (htp_op) { - case HTP_OP_RMS_NORM: - rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; - case HTP_OP_SCALE: - scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; - case HTP_OP_SQR: - sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; - case HTP_OP_SQRT: - sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; + // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); - default: - break; + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); } + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + + // Process block in VTCM + switch (htp_op) { + case HTP_OP_RMS_NORM: + rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_SCALE: + scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_SQR: + sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_SQRT: + sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + default: + break; + } + + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), + dst_row_size, dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + } + } + + dma_queue_flush(dma_queue); + t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0], + FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - - unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i, - octx->src0_nrows_per_thread); -} - static int execute_op_unary_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; const struct htp_tensor * src0 = &octx->src0; struct htp_tensor * dst = &octx->dst; - worker_callback_t unary_op_func; - const char * op_type = NULL; + const char * op_type = NULL; switch (octx->op) { case HTP_OP_RMS_NORM: - unary_op_func = unary_job_dispatcher_f32; - op_type = "rmsnorm-f32"; + op_type = "rmsnorm-f32"; break; case HTP_OP_SCALE: - unary_op_func = unary_job_dispatcher_f32; - op_type = "scale-f32"; + op_type = "scale-f32"; break; case HTP_OP_SQR: - unary_op_func = unary_job_dispatcher_f32; - op_type = "sqr-f32"; + op_type = "sqr-f32"; break; case HTP_OP_SQRT: - unary_op_func = unary_job_dispatcher_f32; - op_type = "sqrt-f32"; + op_type = "sqrt-f32"; break; default: @@ -294,32 +308,61 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const size_t src0_row_size = src0->nb[1]; const size_t dst_row_size = dst->nb[1]; - // VTCM scratchpads for all tensors - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); - size_t spad_size = octx->src0_spad.size + octx->dst_spad.size; + // VTCM scratchpads for all tensors + // N rows per thread, padded to HVX vector size + // Double buffering requires 2x size per buffer - FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); + size_t spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); + size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row); // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { + if (vtcm_row_per_thread == 0) { FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); + spad_size_per_row * n_threads); return HTP_STATUS_VTCM_TOO_SMALL; } + octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2; + octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread * 2; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + struct htp_unary_context uctx = { + .octx = octx, + .src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs, + .src0_nrows = src0_nrows, + + .data_src0 = (const uint8_t *)src0->data, + .data_dst = (uint8_t *)dst->data, + + .src0_row_size = src0_row_size, + .dst_row_size = dst_row_size, + + .src0_row_size_aligned = src0_row_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + + .src0_spad_half_size = octx->src0_spad.size_per_thread / 2, + .dst_spad_half_size = octx->dst_spad.size_per_thread / 2, + + .block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned, + .nc = src0->ne[0], + }; - worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs); } return err;