@@ -13,7 +13,7 @@
namespace MNN {

-// @brief Translate an address to a hex number string
+// Translate an address to a hex number string
static inline std::string addrToHex(void *addr) {
    std::string result = "";
    uint64_t n = (uint64_t)addr;
@@ -106,11 +106,27 @@ void KVCacheManager::unmapKVCache(size_t keySize, size_t valueSize)
*/
void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
    /*=================================== Key ===================================*/
-    if (mConfig.mQuantKey) {
+    if (mConfig.mUseInt8Kernel) {
+        auto new_key = Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8});
+        mBackend->onAcquireBuffer(new_key, Backend::STATIC);
+        for (int h = 0; h < mKvNumHead; h++) {
+            memcpy(
+                new_key->host<char>() + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
+                mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
+                UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8
+            );
+        }
+        mPastKey.reset(new_key);
+    }
+    else if (mConfig.mQuantKey) {
        auto new_key = Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP});
        mBackend->onAcquireBuffer(new_key, Backend::STATIC);
        for (int h = 0; h < mKvNumHead; h++) {
-            memcpy(new_key->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP, mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP, UP_DIV(oldMaxLength, hP) * mHeadDim * hP);
+            memcpy(
+                new_key->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP,
+                mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP,
+                UP_DIV(oldMaxLength, hP) * mHeadDim * hP
+            );
        }
        mPastKey.reset(new_key);
    }
@@ -118,7 +134,11 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
        auto new_key = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP});
        mBackend->onAcquireBuffer(new_key, Backend::STATIC);
        for (int h = 0; h < mKvNumHead; h++) {
-            memcpy(new_key->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes, mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes, UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes);
+            memcpy(
+                new_key->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes,
+                mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes,
+                UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes
+            );
        }
        mPastKey.reset(new_key);
    }
@@ -128,7 +148,11 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
        mBackend->onAcquireBuffer(new_value, Backend::STATIC);
        for (int h = 0; h < mKvNumHead; h++) {
            for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
-                memcpy(new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP, mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP, oldMaxLength * hP);
+                memcpy(
+                    new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP,
+                    mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP,
+                    oldMaxLength * hP
+                );
            }
        }
        mPastValue.reset(new_value);
@@ -138,7 +162,11 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
        mBackend->onAcquireBuffer(new_value, Backend::STATIC);
        for (int h = 0; h < mKvNumHead; h++) {
            for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
-                memcpy(new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes, mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes, oldMaxLength * hP * mBytes);
+                memcpy(
+                    new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes,
+                    mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes,
+                    oldMaxLength * hP * mBytes
+                );
            }
        }
mPastValue.reset(new_value);
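[note] The expansion copies head by head rather than with one bulk memcpy because the per-head stride is derived from mMaxLength, so source and destination strides differ once the cache grows. A minimal sketch of the arithmetic, using the quantized-key case from the hunks above:

    // Per-head byte strides before and after expansion (quantized key, 1 byte/elem):
    size_t oldStride = UP_DIV(oldMaxLength, hP) * mHeadDim * hP;   // source
    size_t newStride = UP_DIV(mMaxLength, hP) * mHeadDim * hP;     // destination
    // head h relocates from src + h * oldStride to dst + h * newStride,
    // and only oldStride bytes per head hold valid data.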
@@ -151,16 +179,35 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
*/
void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
    /*=================================== Key ===================================*/
+    if (mConfig.mUseInt8Kernel) {
+        for (int h = 0; h < mKvNumHead; h++) {
+            memcpy(
+                mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
+                mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
+                UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8
+            );
+        }
+        mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC);
+        mPastKey.reset();
+    }
-    if (mConfig.mQuantKey) {
+    else if (mConfig.mQuantKey) {
        for (int h = 0; h < mKvNumHead; h++) {
-            memcpy(mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP, mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP, UP_DIV(oldMaxLength, hP) * mHeadDim * hP);
+            memcpy(
+                mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP,
+                mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP,
+                UP_DIV(oldMaxLength, hP) * mHeadDim * hP
+            );
        }
        mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC);
        mPastKey.reset();
    }
    else {
        for (int h = 0; h < mKvNumHead; h++) {
-            memcpy(mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes, mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes, UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes);
+            memcpy(
+                mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes,
+                mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes,
+                UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes
+            );
        }
        mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC);
        mPastKey.reset();
@@ -169,7 +216,11 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
    if (mConfig.mQuantValue) {
        for (int h = 0; h < mKvNumHead; h++) {
            for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
-                memcpy(mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP, mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP, oldMaxLength * hP);
+                memcpy(
+                    mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP,
+                    mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP,
+                    oldMaxLength * hP
+                );
            }
        }
        mBackend->onReleaseBuffer(mPastValue.get(), Backend::STATIC);
@@ -178,7 +229,11 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
    else {
        for (int h = 0; h < mKvNumHead; h++) {
            for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
-                memcpy(mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes, mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes, oldMaxLength * hP * mBytes);
+                memcpy(
+                    mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes,
+                    mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes,
+                    oldMaxLength * hP * mBytes
+                );
            }
        }
        mBackend->onReleaseBuffer(mPastValue.get(), Backend::STATIC);
@@ -189,14 +244,12 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
/*
** @brief Expand the size of kvcache files in disk
*/
-void KVCacheManager::expandKVCacheInDisk(int oldMaxLength) {
-    size_t oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes);
-    size_t oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * oldMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
-    size_t keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes);
-    size_t valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
+void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int oldValueSize, int keySize, int valueSize) {
    // Step 1: Copy the old kvcache from files to temporary buffers in memory
    std::shared_ptr<Tensor> old_key, old_value;
-    if (mConfig.mQuantKey) {
+    if (mConfig.mUseInt8Kernel) {
+        old_key.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(oldMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8}));
+    } else if (mConfig.mQuantKey) {
        old_key.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(oldMaxLength, hP), mHeadDim, hP}));
    } else {
old_key.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(oldMaxLength, hP), mHeadDim, hP}));
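[note] The disk path is a copy-out / remap / copy-back sequence: the mmap'd file cannot be grown in place while preserving the layout, since every per-head offset changes with mMaxLength. Roughly, with the names used in this function (the Step 2 ordering is an assumption; only resetKVCacheFileSize and mmapKVCache are visible in the next hunk):

    // Step 1: stage old contents in STATIC tensors (old_key / old_value above)
    // Step 2: unmapKVCache(oldKeySize, oldValueSize);
    //         resetKVCacheFileSize(keySize, valueSize);
    //         mmapKVCache(keySize, valueSize);
    // Step 3: copy back head by head at the new mMaxLength-based strides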
@@ -216,25 +269,49 @@
    resetKVCacheFileSize(keySize, valueSize);
    mmapKVCache(keySize, valueSize);
    // Step 3: Move the kvcache from temporary buffers in memory to disk
-    if (mConfig.mQuantKey) {
+    if (mConfig.mUseInt8Kernel) {
+        for (int h = 0; h < mKvNumHead; h++) {
+            memcpy(
+                mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
+                old_key->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
+                UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8
+            );
+        }
+    } else if (mConfig.mQuantKey) {
        for (int h = 0; h < mKvNumHead; h++) {
-            memcpy(mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP, old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP, UP_DIV(oldMaxLength, hP) * mHeadDim * hP);
+            memcpy(
+                mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP,
+                old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP,
+                UP_DIV(oldMaxLength, hP) * mHeadDim * hP
+            );
        }
    } else {
        for (int h = 0; h < mKvNumHead; h++) {
-            memcpy(mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes, old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes, UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes);
+            memcpy(
+                mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes,
+                old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes,
+                UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes
+            );
        }
    }
    if (mConfig.mQuantValue) {
        for (int h = 0; h < mKvNumHead; h++) {
            for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
-                memcpy(mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP, old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP, oldMaxLength * hP);
+                memcpy(
+                    mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP,
+                    old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP,
+                    oldMaxLength * hP
+                );
            }
        }
    } else {
        for (int h = 0; h < mKvNumHead; h++) {
            for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
-                memcpy(mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes, old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes, oldMaxLength * hP * mBytes);
+                memcpy(
+                    mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes,
+                    old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes,
+                    oldMaxLength * hP * mBytes
+                );
            }
        }
}
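[note] Throughout these hunks the per-head int8 key block is UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 bytes, i.e. both the sequence and head-dim axes are padded up to whole GEMM tiles. A quick worked example with illustrative tile sizes (the real hP8/lP8 come from MNNGetGemmUnit in the next hunk):

    // Assume hP8 = 4, lP8 = 4, mMaxLength = 10, mHeadDim = 6 (illustrative only):
    // UP_DIV(10, 4) * UP_DIV(6, 4) * 4 * 4 = 3 * 2 * 16 = 96 bytes per head,
    // versus 10 * 6 = 60 bytes of raw int8 data; the padding keeps every
    // hP8 x lP8 tile fully addressable by the int8 GEMM kernel.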
@@ -253,12 +330,22 @@ void KVCacheManager::onResize(int kv_num_head, int head_dim) {
    if (mThreadNum > mKvNumHead) {
        mThreadNum = mKvNumHead;
    }
+    if (mConfig.mUseInt8Kernel) {
+        static_cast<CPUBackend *>(mBackend)->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8);
+    }
}

void KVCacheManager::onAlloc(int kv_seq_len) {
    mMaxLength = kv_seq_len + mConfig.mExpandChunk;
-    size_t keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes);
-    size_t valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
+    size_t keySize = 0, valueSize = 0;
+    if (mConfig.mUseInt8Kernel) {
+        keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
+    } else if (mConfig.mQuantKey) {
+        keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP;
+    } else {
+        keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes;
+    }
+    valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
    /*============== Put the kvcache in disk ===========*/
    if (mConfig.mKVCacheSizeLimit != -1 && keySize + valueSize > mConfig.mKVCacheSizeLimit) {
        createKVCacheFile();
@@ -268,7 +355,9 @@ void KVCacheManager::onAlloc(int kv_seq_len) {
    }
    /*============== Put the kvcache in memory ===========*/
    else {
-        if (mConfig.mQuantKey) {
+        if (mConfig.mUseInt8Kernel) {
+            mPastKey.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8}));
+        } else if (mConfig.mQuantKey) {
            mPastKey.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}));
        } else {
            mPastKey.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}));
@@ -278,15 +367,22 @@ void KVCacheManager::onAlloc(int kv_seq_len) {
        } else {
            mPastValue.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), mMaxLength, hP}));
        }
-        mBackend->onAcquireBuffer(mPastKey.get(), Backend::STATIC);
-        mBackend->onAcquireBuffer(mPastValue.get(), Backend::STATIC);
-    }
-    /* No matter where is the kvcache, the scales and zero points are always in memory, since their size is very small */
-    if (mConfig.mQuantKey) {
-        mDequantKeyScale.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), 1, hP}));
-        mDequantKeyZeroPoint.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), 1, hP}));
-        mBackend->onAcquireBuffer(mDequantKeyScale.get(), Backend::STATIC);
-        mBackend->onAcquireBuffer(mDequantKeyZeroPoint.get(), Backend::STATIC);
+        mBackend->onAcquireBuffer(mPastKey.get(), Backend::STATIC);
+        mBackend->onAcquireBuffer(mPastValue.get(), Backend::STATIC);
+    }
+    // scale, zero point and sum of key for quantization
+    if (mConfig.mUseInt8Kernel) {
+        mKeyScale.reset(Tensor::createDevice<int32_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}));
+        mKeyZeroPoint.reset(Tensor::createDevice<int32_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}));
+        mKeySum.reset(Tensor::createDevice<int32_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}));
+        mBackend->onAcquireBuffer(mKeyScale.get(), Backend::STATIC);
+        mBackend->onAcquireBuffer(mKeyZeroPoint.get(), Backend::STATIC);
+        mBackend->onAcquireBuffer(mKeySum.get(), Backend::STATIC);
+    } else if (mConfig.mQuantKey) {
+        mKeyScale.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), hP}));
+        mKeyZeroPoint.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), hP}));
+        mBackend->onAcquireBuffer(mKeyScale.get(), Backend::STATIC);
+        mBackend->onAcquireBuffer(mKeyZeroPoint.get(), Backend::STATIC);
    }
}
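[note] The int8-kernel path allocates mKeyScale / mKeyZeroPoint / mKeySum with createDevice<int32_t>, yet pack_key later in this diff writes float through them; the int32_t type appears to be used only to fix the element size at 4 bytes regardless of mBytes. A sketch of the aliasing, under that assumption:

    // createDevice<int32_t> => 4-byte elements; pack_key then reinterprets
    // the same storage as float (see the pack_key hunk below):
    //     float * scale_dst = reinterpret_cast<float*>(addrOfScale(kv_h));
    //     scale_dst[out_index * hP8 + in_index] = (maxKey - minKey) / 255.0f;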
@@ -296,10 +392,19 @@ void KVCacheManager::onRealloc(int kv_seq_len) {
    }
    int oldMaxLength = mMaxLength;
    mMaxLength = kv_seq_len + mConfig.mExpandChunk;
-    size_t oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes);
-    size_t oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * oldMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
-    size_t keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes);
-    size_t valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
+    size_t oldKeySize, oldValueSize, keySize, valueSize;
+    if (mConfig.mUseInt8Kernel) {
+        oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
+        keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
+    } else if (mConfig.mQuantKey) {
+        oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * mHeadDim * hP;
+        keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP;
+    } else {
+        oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes;
+        keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes;
+    }
+    oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * oldMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
+    valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
    /*==== No limit for kvcache ====*/
    if (mConfig.mKVCacheSizeLimit == -1) {
expandKVCacheInMem(oldMaxLength);
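[note] These sizes are size_t locals, but expandKVCacheInDisk (see the signature change earlier in this diff) receives them as int, which narrows for caches over INT_MAX bytes. A worked bound, assuming 32-bit int:

    // keySize = 32 heads * UP_DIV(65536, 8) * UP_DIV(128, 4) * 8 * 4
    //         = 32 * 8192 * 32 * 32 = 268,435,456 bytes  -- still fits in int,
    // but an 8x larger configuration exceeds INT_MAX (2,147,483,647) and
    // would truncate; size_t parameters would be the safer signature.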
@@ -318,51 +423,100 @@ void KVCacheManager::onRealloc(int kv_seq_len) {
    }
    /*==== Last time the kvcache is disk, now it should be in disk too ====*/
    else {
-        expandKVCacheInDisk(oldMaxLength);
+        expandKVCacheInDisk(oldMaxLength, oldKeySize, oldValueSize, keySize, valueSize);
    }
    /* No matter where is the kvcache, the scales and zero points are always in memory, since their size is very small */
-    if (mConfig.mQuantKey) {
+    if (mConfig.mUseInt8Kernel) {
+        auto new_scale = Tensor::createDevice<int32_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8});
+        auto new_zeroPoint = Tensor::createDevice<int32_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8});
+        auto new_sum = Tensor::createDevice<int32_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8});
+        mBackend->onAcquireBuffer(new_scale, Backend::STATIC);
+        mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC);
+        mBackend->onAcquireBuffer(new_sum, Backend::STATIC);
+        for (int h = 0; h < mKvNumHead; h++) {
+            memcpy(new_scale->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyScale->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
+            memcpy(new_zeroPoint->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyZeroPoint->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
+            memcpy(new_sum->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
+        }
+        mKeyScale.reset(new_scale);
+        mKeyZeroPoint.reset(new_zeroPoint);
+        mKeySum.reset(new_sum);
+    } else if (mConfig.mQuantKey) {
        auto new_scale = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), 1, hP});
        auto new_zeroPoint = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), 1, hP});
        mBackend->onAcquireBuffer(new_scale, Backend::STATIC);
        mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC);
        for (int h = 0; h < mKvNumHead; h++) {
-            memcpy(new_scale->host<char>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mDequantKeyScale->host<char>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
-            memcpy(new_zeroPoint->host<char>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mDequantKeyZeroPoint->host<char>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
+            memcpy(new_scale->host<char>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyScale->host<char>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
+            memcpy(new_zeroPoint->host<char>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyZeroPoint->host<char>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
        }
-        mDequantKeyScale.reset(new_scale);
-        mDequantKeyZeroPoint.reset(new_zeroPoint);
+        mKeyScale.reset(new_scale);
+        mKeyZeroPoint.reset(new_zeroPoint);
    }
}

void KVCacheManager::onClear() {
    if (mKVCacheInDisk) {
-        size_t oldKeySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes);
-        size_t oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
-        unmapKVCache(oldKeySize, oldValueSize);
+        size_t keySize = 0, valueSize = 0;
+        if (mConfig.mUseInt8Kernel) {
+            keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
+        } else if (mConfig.mQuantKey) {
+            keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP;
+        } else {
+            keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes;
+        }
+        valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
+        unmapKVCache(keySize, valueSize);
        removeKVCacheFile();
        mKVCacheInDisk = false;
    }
-    else {
-        mPastKey.reset();
-        mPastValue.reset();
-    }
+    mPastKey.reset();
+    mPastValue.reset();
+    mKeyScale.reset();
+    mKeyZeroPoint.reset();
+    mKeySum.reset();
    mMaxLength = mPastLength = 0;
}

template <typename T>
-static void pack_key(const Tensor* key, char* pack_key, int mPastLength, int seq_len, int mKvNumHead, int mHeadDim,
-    int hP, int kv_h, bool quantKey, char* scale, char* zero_point, const MNN::CoreFunctions * core) {
-    if (quantKey) {
-        int8_t * key_dst = reinterpret_cast<int8_t*>(pack_key);
-        T * scale_dst = reinterpret_cast<T*>(scale);
-        T * zeroPoint_dst = reinterpret_cast<T*>(zero_point);
+void KVCacheManager::pack_key(const Tensor* key, int seq_len, int kv_h) {
+    if (mConfig.mUseInt8Kernel) { // [maxlen/hP8, headdim/lP8, hP8, lP8]
+        int8_t * key_dst = reinterpret_cast<int8_t*>(addrOfKey(kv_h));
+        float * scale_dst = reinterpret_cast<float*>(addrOfScale(kv_h));
+        float * zeroPoint_dst = reinterpret_cast<float*>(addrOfZeroPoint(kv_h));
+        float * sum_dst = reinterpret_cast<float*>(addrOfKeySum(kv_h));
+        for (int s = 0; s < seq_len; s++) {
+            T * key_src = key->host<T>() + s * mKvNumHead * mHeadDim + kv_h * mHeadDim;
+            float minKey = key_src[0];
+            float maxKey = key_src[0];
+            float sumKey = key_src[0];
+            for (int d = 1; d < mHeadDim; d++) {
+                minKey = ALIMIN(minKey, key_src[d]);
+                maxKey = ALIMAX(maxKey, key_src[d]);
+                sumKey += key_src[d];
+            }
+            int out_index = (mPastLength + s) / hP8;
+            int in_index = (mPastLength + s) % hP8;
+            scale_dst[out_index * hP8 + in_index] = (maxKey - minKey) / 255.0f;
+            zeroPoint_dst[out_index * hP8 + in_index] = -255.0f * minKey / (maxKey - minKey) - 128.0f;
+            sum_dst[out_index * hP8 + in_index] = sumKey;
+            for (int d = 0; d < mHeadDim; d++) {
+                int i = d / lP8;
+                int j = d % lP8;
+                key_dst[out_index * UP_DIV(mHeadDim, lP8) * hP8 * lP8 + i * hP8 * lP8 + in_index * lP8 + j] = roundf((key_src[d] - minKey) / (maxKey - minKey) * 255.0f - 128.0f);
+            }
+        }
+    }
+    else if (mConfig.mQuantKey) { // [maxlen/hP, headdim, hP]
+        int8_t * key_dst = reinterpret_cast<int8_t*>(addrOfKey(kv_h));
+        T * scale_dst = reinterpret_cast<T*>(addrOfScale(kv_h));
+        T * zeroPoint_dst = reinterpret_cast<T*>(addrOfZeroPoint(kv_h));
        for (int i = 0; i < seq_len; i++) {
            T * key_src = key->host<T>() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim;
            int out_index = (mPastLength + i) / hP;
            int in_index = (mPastLength + i) % hP;
            T minKey, maxKey;
-            core->MNNCountMaxMinValue((float*)key_src, (float*)&minKey, (float*)&maxKey, mHeadDim);
+            static_cast<CPUBackend*>(mBackend)->functions()->MNNCountMaxMinValue((float*)key_src, (float*)&minKey, (float*)&maxKey, mHeadDim);
            scale_dst[out_index * hP + in_index] = (maxKey - minKey) / 255.0f;
            zeroPoint_dst[out_index * hP + in_index] = 128.0f * (maxKey - minKey) / 255.0f + minKey;
for (int j = 0; j < mHeadDim; j++) {
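[note] A worked instance of the asymmetric quantization in the mUseInt8Kernel branch above, assuming one key row with minKey = -1.0f and maxKey = 3.0f (the formulas assume maxKey > minKey; a constant row would divide by zero):

    // scale     = (3 - (-1)) / 255                = 4/255 ~ 0.01569
    // zeroPoint = -255 * (-1) / (3 - (-1)) - 128  = 63.75 - 128 = -64.25
    // q(x)      = roundf((x - minKey) / (maxKey - minKey) * 255 - 128)
    //   q(-1.0) = -128 and q(3.0) = 127, so the full int8 range is used;
    // dequantization is x ~ (q - zeroPoint) * scale, presumably applied in
    // the int8 GEMM epilogue together with the per-row sum_dst.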
@@ -370,8 +524,8 @@ static void pack_key(const Tensor* key, char* pack_key, int mPastLength, int seq
            }
        }
    }
-    else {
-        T * key_dst = reinterpret_cast<T*>(pack_key);
+    else { // [maxlen/hP, headdim, hP]
+        T * key_dst = reinterpret_cast<T*>(addrOfKey(kv_h));
        for (int i = 0; i < seq_len; i++) {
            T * key_src = key->host<T>() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim;
            int out_index = (mPastLength + i) / hP;
@@ -384,16 +538,16 @@ static void pack_key(const Tensor* key, char* pack_key, int mPastLength, int seq
    }

template <typename T>
-static void pack_value(const Tensor* value, char* pack_value, int mMaxLength, int mPastLength, int seq_len, int mKvNumHead, int mHeadDim, int hP, int kv_h, bool quantValue, const MNN::CoreFunctions * core) {
-    if (quantValue) {
-        fp8_t * value_dst = reinterpret_cast<fp8_t*>(pack_value);
+void KVCacheManager::pack_value(const Tensor* value, int seq_len, int kv_h) { // [headdim/hP, maxlen, hP]
+    if (mConfig.mQuantValue) {
+        fp8_t * value_dst = reinterpret_cast<fp8_t*>(addrOfValue(kv_h));
        uint8_t * buf = (uint8_t *)MNNMemoryAllocAlign(mHeadDim, MNN_MEMORY_ALIGN_DEFAULT);
        for (int i = 0; i < seq_len; i++) {
            T * value_src = value->host<T>() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim;
            if (sizeof(T) == 2) {
-                core->MNNFp16ToFp8(buf, (uint16_t*)value_src, mHeadDim);
+                static_cast<CPUBackend*>(mBackend)->functions()->MNNFp16ToFp8(buf, (uint16_t*)value_src, mHeadDim);
            } else {
-                core->MNNFp32ToFp8(buf, (float*)value_src, mHeadDim);
+                static_cast<CPUBackend*>(mBackend)->functions()->MNNFp32ToFp8(buf, (float*)value_src, mHeadDim);
            }
            for (int j = 0; j < mHeadDim; j++) {
int out_index = j / hP;
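[note] pack_value stores V transposed into [headdim/hP, maxlen, hP] (per the comment on the new signature), so each incoming token scatters into UP_DIV(mHeadDim, hP) strided slots rather than one contiguous row. The destination index implied by the loop above is, roughly:

    // dst = out_index * mMaxLength * hP      // out_index = j / hP
    //     + (mPastLength + i) * hP
    //     + (j % hP);
    // this mMaxLength-dependent stride is also why values are copied
    // row by row in the expand/move functions earlier in the diff.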
@@ -404,7 +558,7 @@ static void pack_value(const Tensor* value, char* pack_value, int mMaxLength, in
        MNNMemoryFreeAlign(buf);
    }
    else {
-        T * value_dst = reinterpret_cast<T*>(pack_value);
+        T * value_dst = reinterpret_cast<T*>(addrOfValue(kv_h));
        for (int i = 0; i < seq_len; i++) {
            T * value_src = value->host<T>() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim;
            for (int j = 0; j < mHeadDim; j++) {
@@ -423,11 +577,11 @@ void KVCacheManager::onPushBack(const Tensor * key, const Tensor * value) {
    std::function<void(int)> packKV = [=](int tid) {
        for (int kv_h = tid * tileCount; kv_h < (tid+1) * tileCount && kv_h < mKvNumHead; kv_h++) {
            if (mBytes == 2) {
-                pack_key<FLOAT16_T>(key, addrOfKey(kv_h), mPastLength, seq_len, mKvNumHead, mHeadDim, hP, kv_h, mConfig.mQuantKey, addrOfScale(kv_h), addrOfZeroPoint(kv_h), core);
-                pack_value<FLOAT16_T>(value, addrOfValue(kv_h), mMaxLength, mPastLength, seq_len, mKvNumHead, mHeadDim, hP, kv_h, mConfig.mQuantValue, core);
+                pack_key<FLOAT16_T>(key, seq_len, kv_h);
+                pack_value<FLOAT16_T>(value, seq_len, kv_h);
            } else {
-                pack_key<float>(key, addrOfKey(kv_h), mPastLength, seq_len, mKvNumHead, mHeadDim, hP, kv_h, mConfig.mQuantKey, addrOfScale(kv_h), addrOfZeroPoint(kv_h), core);
-                pack_value<float>(value, addrOfValue(kv_h), mMaxLength, mPastLength, seq_len, mKvNumHead, mHeadDim, hP, kv_h, mConfig.mQuantValue, core);
+                pack_key<float>(key, seq_len, kv_h);
+                pack_value<float>(value, seq_len, kv_h);
            }
        }
};
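[note] With pack_key / pack_value now member templates, each worker only needs (tensor, seq_len, head): mPastLength, hP/hP8/lP8 and the addrOf* helpers are all manager state, which is what shrank these call sites. A hedged sketch of how the packKV lambda above is typically driven (MNN_CONCURRENCY is the backend's usual pattern; the exact surrounding code is not shown in this diff):

    MNN_CONCURRENCY_BEGIN(tid, mThreadNum) {
        packKV((int)tid);   // each thread packs its contiguous range of KV heads
    }
    MNN_CONCURRENCY_END();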