3 changes: 3 additions & 0 deletions include/infinicore/ops.hpp
@@ -4,12 +4,15 @@
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/flash_attention.hpp"
#include "ops/kv_caching.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/random_sample_batched.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
#include "ops/rope.hpp"
16 changes: 16 additions & 0 deletions include/infinicore/ops/flash_attention.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class FlashAttention {
public:
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float, bool);
    static void execute(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal);
    static common::OpDispatcher<schema> &dispatcher();
};

Tensor flash_attention(Tensor q, Tensor k, Tensor v, float scale, bool is_causal);
void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal);
} // namespace infinicore::op
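A minimal usage sketch of the two C++ entry points. The helper name, the shared device, and the [num_heads, seq_len, head_dim] layout are illustrative assumptions; the header itself only fixes the signatures:

#include <cmath>

#include "infinicore/ops.hpp"

using infinicore::Tensor;

// Sketch only: q, k, v are assumed to be same-device attention tensors.
Tensor run_flash_attention(Tensor q, Tensor k, Tensor v, int head_dim) {
    float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));

    // Out-of-place variant: allocates `out` with q's shape, dtype and device.
    Tensor out = infinicore::op::flash_attention(q, k, v, scale, /*is_causal=*/true);

    // In-place variant: writes into a preallocated output buffer instead.
    infinicore::op::flash_attention_(out, q, k, v, scale, /*is_causal=*/true);
    return out;
}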
23 changes: 23 additions & 0 deletions include/infinicore/ops/kv_caching.hpp
@@ -0,0 +1,23 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class KVCaching {
public:
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
    static void execute(Tensor k_cache,
                        Tensor v_cache,
                        Tensor k,
                        Tensor v,
                        Tensor past_kv_lengths);
    static common::OpDispatcher<schema> &dispatcher();
};

void kv_caching_(Tensor k_cache,
                 Tensor v_cache,
                 Tensor k,
                 Tensor v,
                 Tensor past_kv_lengths);
} // namespace infinicore::op
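The operator is in-place only: it appends the current step's K/V into preallocated caches, using past_kv_lengths as the per-sequence write offsets. A sketch, with the cache layout and the length semantics as assumptions:

#include "infinicore/ops.hpp"

using infinicore::Tensor;

// Sketch only: assumes k_cache/v_cache were preallocated up to the maximum
// sequence length and past_kv_lengths holds the tokens already cached per sequence.
void append_decode_step(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor past_kv_lengths) {
    infinicore::op::kv_caching_(k_cache, v_cache, k, v, past_kv_lengths);
}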
20 changes: 20 additions & 0 deletions include/infinicore/ops/random_sample_batched.hpp
@@ -0,0 +1,20 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

class RandomSampleBatched {
public:
    using schema = void (*)(Tensor, Tensor, const float *, const float *, const int *, const float *, int);
    static void execute(Tensor result, Tensor probs, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
    static common::OpDispatcher<schema> &dispatcher();
};

// Out-of-place API
Tensor random_sample_batched(Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
// In-place API
void random_sample_batched_(Tensor indices, Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);

} // namespace infinicore::op
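The sampling parameters are plain host arrays with one entry per batch row. A sketch of the out-of-place form; the [batch_size, vocab_size] logits layout and the [0, 1) range of random_val are assumptions:

#include <vector>

#include "infinicore/ops.hpp"

using infinicore::Tensor;

// Sketch only: one sampling-parameter entry per row of `logits`.
Tensor sample_batch(Tensor logits, int batch_size) {
    std::vector<float> random_val(batch_size, 0.5f);  // per-row uniform draws, assumed in [0, 1)
    std::vector<float> topp(batch_size, 0.9f);        // nucleus (top-p) threshold
    std::vector<int> topk(batch_size, 50);            // top-k cutoff
    std::vector<float> temperature(batch_size, 1.0f); // logit temperature

    return infinicore::op::random_sample_batched(
        logits, random_val.data(), topp.data(), topk.data(), temperature.data(), batch_size);
}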
3 changes: 3 additions & 0 deletions include/infiniop.h
@@ -9,8 +9,10 @@
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/flash_attention.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/kv_caching.h"
#include "infiniop/ops/layer_norm.h"
#include "infiniop/ops/logsoftmax.h"
#include "infiniop/ops/lp_norm.h"
@@ -20,6 +22,7 @@
#include "infiniop/ops/paged_attention_prefill.h"
#include "infiniop/ops/paged_caching.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/random_sample_batched.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
34 changes: 34 additions & 0 deletions include/infiniop/ops/flash_attention.h
@@ -0,0 +1,34 @@
#ifndef __INFINIOP_FLASH_ATTENTION_API_H__
#define __INFINIOP_FLASH_ATTENTION_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t;

__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopFlashAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    float scale,
    char is_causal);

__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
    infiniopFlashAttentionDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopFlashAttention(
    infiniopFlashAttentionDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k,
    const void *v,
    void *stream);

__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
    infiniopFlashAttentionDescriptor_t desc);
#endif
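Like the existing infiniop operators, the C API follows a create, query-workspace, run, destroy lifecycle. A sketch with status checks elided; the handle, tensor descriptors, device buffers, and stream are assumed to exist:

infiniopFlashAttentionDescriptor_t desc;
infiniopCreateFlashAttentionDescriptor(handle, &desc, out_desc, q_desc, k_desc, v_desc,
                                       scale, /*is_causal=*/1);

size_t workspace_size = 0;
infiniopGetFlashAttentionWorkspaceSize(desc, &workspace_size);
/* ... allocate workspace_size bytes of device memory as workspace ... */

infiniopFlashAttention(desc, workspace, workspace_size, out, q, k, v, stream);
infiniopDestroyFlashAttentionDescriptor(desc);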
31 changes: 31 additions & 0 deletions include/infiniop/ops/kv_caching.h
@@ -0,0 +1,31 @@
#ifndef __INFINIOP_KV_CACHING_API_H__
#define __INFINIOP_KV_CACHING_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopKVCachingDescriptor_t;

__C __export infiniStatus_t infiniopCreateKVCachingDescriptor(
    infiniopHandle_t handle,
    infiniopKVCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_cache,
    infiniopTensorDescriptor_t v_cache,
    infiniopTensorDescriptor_t k,
    infiniopTensorDescriptor_t v,
    infiniopTensorDescriptor_t past_kv_lengths);

__C __export infiniStatus_t infiniopGetKVCachingWorkspaceSize(infiniopKVCachingDescriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopKVCaching(infiniopKVCachingDescriptor_t desc,
                                              void *workspace,
                                              size_t workspace_size,
                                              void *k_cache,
                                              void *v_cache,
                                              const void *k,
                                              const void *v,
                                              const void *past_kv_lengths,
                                              void *stream);

__C __export infiniStatus_t infiniopDestroyKVCachingDescriptor(infiniopKVCachingDescriptor_t desc);

#endif
6 changes: 0 additions & 6 deletions include/infiniop/ops/random_sample.h
@@ -15,12 +15,6 @@ __C __export infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
     infiniopRandomSampleDescriptor_t desc,
     size_t *size);
 
-__C __export infiniStatus_t infiniopCreateRandomSampleBatchDescriptor(
-    infiniopHandle_t handle,
-    infiniopRandomSampleDescriptor_t *desc_ptr,
-    infiniopTensorDescriptor_t result,
-    infiniopTensorDescriptor_t probs);
-
 __C __export infiniStatus_t infiniopRandomSample(
     infiniopRandomSampleDescriptor_t desc,
     void *workspace,
34 changes: 34 additions & 0 deletions include/infiniop/ops/random_sample_batched.h
@@ -0,0 +1,34 @@
#ifndef __INFINIOP_RANDOM_SAMPLE_BATCHED_API_H__
#define __INFINIOP_RANDOM_SAMPLE_BATCHED_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopRandomSampleBatchedDescriptor_t;

__C __export infiniStatus_t infiniopCreateRandomSampleBatchedDescriptor(
    infiniopHandle_t handle,
    infiniopRandomSampleBatchedDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t result,
    infiniopTensorDescriptor_t probs);

__C __export infiniStatus_t infiniopGetRandomSampleBatchedWorkspaceSize(
    infiniopRandomSampleBatchedDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopRandomSampleBatched(
    infiniopRandomSampleBatchedDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *result,
    const void *probs,
    const float *random_val,
    const float *topp,
    const int *topk,
    const float *temperature,
    int batch_size,
    void *stream);

__C __export infiniStatus_t infiniopDestroyRandomSampleBatchedDescriptor(
    infiniopRandomSampleBatchedDescriptor_t desc);

#endif
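Unlike the flash-attention descriptor, the per-row sampling parameters travel with each launch rather than being baked into the descriptor. A sketch under the same assumptions (status checks elided, buffers preallocated):

infiniopRandomSampleBatchedDescriptor_t desc;
infiniopCreateRandomSampleBatchedDescriptor(handle, &desc, result_desc, probs_desc);

size_t workspace_size = 0;
infiniopGetRandomSampleBatchedWorkspaceSize(desc, &workspace_size);
/* ... allocate workspace on the device ... */

/* random_val, topp, topk and temperature are host arrays of batch_size entries. */
infiniopRandomSampleBatched(desc, workspace, workspace_size, result, probs,
                            random_val, topp, topk, temperature, batch_size, stream);

infiniopDestroyRandomSampleBatchedDescriptor(desc);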
2 changes: 2 additions & 0 deletions python/infinicore/__init__.py
@@ -45,6 +45,7 @@
from infinicore.ops.add import add
from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
from infinicore.ops.attention import attention
from infinicore.ops.kv_caching import kv_caching
from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow
@@ -115,6 +116,7 @@
"add_rms_norm",
"add_rms_norm_",
"attention",
"kv_caching",
"matmul",
"mul",
"narrow",
2 changes: 2 additions & 0 deletions python/infinicore/nn/functional/__init__.py
@@ -4,13 +4,15 @@
from .random_sample import random_sample
from .rms_norm import rms_norm
from .rope import RopeAlgo, rope
from .scaled_dot_product_attention import scaled_dot_product_attention
from .silu import silu
from .swiglu import swiglu

__all__ = [
    "causal_softmax",
    "random_sample",
    "rms_norm",
    "scaled_dot_product_attention",
    "silu",
    "swiglu",
    "linear",
28 changes: 28 additions & 0 deletions python/infinicore/nn/functional/scaled_dot_product_attention.py
@@ -0,0 +1,28 @@
import math

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
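    # Fast path only: attn_mask, dropout and grouped-query attention are not
    # supported, since the underlying flash_attention op takes no such arguments.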
    assert attn_mask is None and dropout_p == 0 and not enable_gqa

    emb_dim = query.shape[-1]

    if scale is None:
        scale = 1 / math.sqrt(emb_dim)

    return Tensor(
        _infinicore.flash_attention(
            query._underlying, key._underlying, value._underlying, scale, is_causal
        )
    )
13 changes: 13 additions & 0 deletions python/infinicore/ops/kv_caching.py
@@ -0,0 +1,13 @@
from infinicore.lib import _infinicore


def kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
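    # Thin wrapper over the in-place op: the caches are mutated and also
    # returned for call-site convenience.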
    _infinicore.kv_caching_(
        k_cache._underlying,
        v_cache._underlying,
        k._underlying,
        v._underlying,
        past_kv_lengths._underlying,
    )

    return k_cache, v_cache
33 changes: 25 additions & 8 deletions scripts/build_ntops.py
@@ -1,3 +1,4 @@
+import concurrent.futures
 import importlib
 import pathlib

@@ -11,16 +12,32 @@
 def _find_and_build_ops():
     ops_path = SRC_DIR_PATH / "infiniop" / "ops"
 
-    for op_dir in ops_path.iterdir():
-        ninetoothed_path = op_dir / "ninetoothed"
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        futures = []
 
-        if ninetoothed_path.is_dir():
-            module_path = ninetoothed_path / "build"
-            relative_path = module_path.relative_to(SRC_DIR_PATH)
-            import_name = ".".join(relative_path.parts)
-            module = importlib.import_module(import_name)
+        for op_dir in ops_path.iterdir():
+            ninetoothed_path = op_dir / "ninetoothed"
 
-            module.build()
+            if not ninetoothed_path.is_dir():
+                continue
 
+            build_file = ninetoothed_path / "build.py"
+            if not build_file.exists():
+                continue
 
+            futures.append(executor.submit(_build, ninetoothed_path))
 
+        for future in concurrent.futures.as_completed(futures):
+            future.result()
 
 
+def _build(ninetoothed_path):
+    module_path = ninetoothed_path / "build"
+    relative_path = module_path.relative_to(SRC_DIR_PATH)
+    import_name = ".".join(relative_path.parts)
+    module = importlib.import_module(import_name)
 
+    module.build()
 
 
 if __name__ == "__main__":
Expand Down
29 changes: 29 additions & 0 deletions src/infinicore/ops/flash_attention/flash_attention.cc
@@ -0,0 +1,29 @@
#include "infinicore/ops/flash_attention.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

common::OpDispatcher<FlashAttention::schema> &FlashAttention::dispatcher() {
static common::OpDispatcher<FlashAttention::schema> dispatcher_;
return dispatcher_;
};

void FlashAttention::execute(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v);
infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(
out, q, k, v, scale, is_causal);
}

Tensor flash_attention(Tensor q, Tensor k, Tensor v, float scale, bool is_causal) {
Shape shape = q->shape();
auto out = Tensor::empty(shape, q->dtype(), q->device());
flash_attention_(out, q, k, v, scale, is_causal);
return out;
}

void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal) {
FlashAttention::execute(out, q, k, v, scale, is_causal);
}
} // namespace infinicore::op