[MNN:Sync] Sync internal Gitlab

xiaying, 3 years ago
parent commit 03c7b5347b
100 changed files with 10236 additions and 2895 deletions
  1. CMakeLists.txt (+15 -4)
  2. MNN.podspec (+1 -0)
  3. demo/exec/multithread_imgrecog.cpp (+5 -1)
  4. demo/exec/pictureRecognition_module.cpp (+28 -3)
  5. docker_run.sh (+6 -0)
  6. express/Executor.cpp (+153 -17)
  7. express/Expr.cpp (+0 -1)
  8. express/module/Module.cpp (+13 -4)
  9. express/module/PipelineModule.cpp (+12 -11)
  10. express/module/PipelineModule.hpp (+3 -3)
  11. express/module/StaticModule.cpp (+20 -15)
  12. express/module/StaticModule.hpp (+1 -1)
  13. include/MNN/expr/Executor.hpp (+38 -0)
  14. include/MNN/expr/Module.hpp (+5 -0)
  15. package_scripts/linux/build_bridge.sh (+118 -0)
  16. package_scripts/linux/build_lib.sh (+76 -27)
  17. package_scripts/mac/build_bridge.sh (+141 -0)
  18. package_scripts/mac/build_lib.sh (+100 -35)
  19. project/ios/MNN.xcodeproj/project.pbxproj (+53 -5)
  20. pymnn/CMakeLists.txt (+19 -3)
  21. schema/current/CaffeOp_generated.h (+22 -6)
  22. schema/current/MNN_generated.h (+9 -0)
  23. schema/current/TensorflowOp_generated.h (+33 -12)
  24. schema/default/CaffeOp.fbs (+1 -0)
  25. schema/default/TensorflowOp.fbs (+2 -0)
  26. source/backend/arm82/Arm82Functions.cpp (+5 -3)
  27. source/backend/arm82/Arm82Unary.cpp (+10 -7)
  28. source/backend/arm82/Arm82Vec.hpp (+167 -0)
  29. source/backend/arm82/Arm82WinogradOptFunc.cpp (+351 -0)
  30. source/backend/arm82/Arm82WinogradOptFunc.hpp (+3 -2)
  31. source/backend/cpu/BinaryUtils.hpp (+1 -1)
  32. source/backend/cpu/CPUBackend.cpp (+3 -4)
  33. source/backend/cpu/CPUCast.cpp (+25 -16)
  34. source/backend/cpu/CPUCast.hpp (+2 -2)
  35. source/backend/cpu/CPUConvolution.cpp (+10 -2)
  36. source/backend/cpu/CPUConvolution.hpp (+2 -2)
  37. source/backend/cpu/CPUDepthwiseConvInt8.cpp (+2 -10)
  38. source/backend/cpu/CPUFloatToInt8.cpp (+12 -9)
  39. source/backend/cpu/CPUInt8ToFloat.cpp (+12 -9)
  40. source/backend/cpu/CPULayerNorm.cpp (+11 -2)
  41. source/backend/cpu/CPUOPRegister.cpp (+14 -18)
  42. source/backend/cpu/CPURaster.cpp (+13 -17)
  43. source/backend/cpu/CPURaster.hpp (+1 -0)
  44. source/backend/cpu/CPUScale.cpp (+3 -2)
  45. source/backend/cpu/CPUSoftmax.cpp (+5 -1)
  46. source/backend/cpu/CPUTensorConvert.cpp (+8 -8)
  47. source/backend/cpu/CPUUnary.cpp (+12 -7)
  48. source/backend/cpu/arm/CommonOptFunctionNeon.cpp (+1 -1)
  49. source/backend/cpu/arm/arm32/MNNExpC8.S (+20 -14)
  50. source/backend/cpu/arm/arm32/MNNPackC4ForMatMul_A.S (+6 -1)
  51. source/backend/cpu/arm/arm32/MNNPackedSparseMatMulEpx4.S (+0 -4)
  52. source/backend/cpu/arm/arm32/MNNPackedSparseQuantMatMulEpx1.S (+319 -0)
  53. source/backend/cpu/arm/arm32/MNNPackedSparseQuantMatMulEpx4.S (+352 -0)
  54. source/backend/cpu/arm/arm64/MNNExpC8.S (+15 -11)
  55. source/backend/cpu/arm/arm64/MNNPackC4ForMatMul_A.S (+12 -1)
  56. source/backend/cpu/arm/arm64/MNNPackedSparseMatMulEpx4.S (+10 -10)
  57. source/backend/cpu/arm/arm64/MNNPackedSparseQuantMatMulEpx1.S (+520 -0)
  58. source/backend/cpu/arm/arm64/MNNPackedSparseQuantMatMulEpx4.S (+1086 -0)
  59. source/backend/cpu/arm/arm64/MNNSoftmax.S (+12 -8)
  60. source/backend/cpu/bf16/BF16Functions.cpp (+1 -0)
  61. source/backend/cpu/bf16/BF16Unary.cpp (+10 -7)
  62. source/backend/cpu/bf16/VecHalf.hpp (+110 -0)
  63. source/backend/cpu/bf16/WinogradOptFunctionHalf.cpp (+356 -0)
  64. source/backend/cpu/bf16/WinogradOptFunctionHalf.hpp (+2 -1)
  65. source/backend/cpu/compute/CommonOptFunction.cpp (+53 -25)
  66. source/backend/cpu/compute/CommonOptFunction.h (+28 -3)
  67. source/backend/cpu/compute/ConvInt8TiledExecutor.cpp (+115 -65)
  68. source/backend/cpu/compute/ConvInt8TiledExecutor.hpp (+33 -4)
  69. source/backend/cpu/compute/ConvInt8Winograd.cpp (+16 -16)
  70. source/backend/cpu/compute/ConvolutionFloatFactory.cpp (+10 -2)
  71. source/backend/cpu/compute/ConvolutionTiledExecutor.cpp (+1 -1)
  72. source/backend/cpu/compute/ConvolutionWinograd.cpp (+291 -100)
  73. source/backend/cpu/compute/ConvolutionWinograd.hpp (+1 -0)
  74. source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp (+3 -3)
  75. source/backend/cpu/compute/Int8FunctionsOpt.cpp (+1407 -11)
  76. source/backend/cpu/compute/Int8FunctionsOpt.h (+23 -6)
  77. source/backend/cpu/compute/SparseConvInt8TiledExecutor.cpp (+249 -0)
  78. source/backend/cpu/compute/SparseConvInt8TiledExecutor.hpp (+70 -0)
  79. source/backend/cpu/compute/SparseConvolutionTiledExecutor.cpp (+5 -4)
  80. source/backend/cpu/compute/SparseConvolutionTiledExecutor.hpp (+6 -1)
  81. source/backend/cpu/compute/WinogradInt8Helper.cpp (+14 -12)
  82. source/backend/cpu/compute/WinogradOptFunction.cpp (+413 -1)
  83. source/backend/cpu/compute/WinogradOptFunction.hpp (+2 -0)
  84. source/backend/cpu/x86_x64/AVX2Backend.cpp (+196 -51)
  85. source/backend/cpu/x86_x64/AVX2Backend.hpp (+1 -0)
  86. source/backend/cpu/x86_x64/AVX2Functions.cpp (+31 -23)
  87. source/backend/cpu/x86_x64/AVX2Functions.hpp (+10 -0)
  88. source/backend/cpu/x86_x64/CMakeLists.txt (+1 -1)
  89. source/backend/cpu/x86_x64/FunctionDispatcher.cpp (+10 -39)
  90. source/backend/cpu/x86_x64/avx/CommonOptFunction.cpp (+0 -1085)
  91. source/backend/cpu/x86_x64/avx/FunctionSummary.hpp (+6 -37)
  92. source/backend/cpu/x86_x64/avx/GemmInt8.cpp (+702 -947)
  93. source/backend/cpu/x86_x64/avx/GemmSparse.cpp (+776 -0)
  94. source/backend/cpu/x86_x64/avx/MNNMatrixAdd.cpp (+0 -46)
  95. source/backend/cpu/x86_x64/avx/MNNMatrixSub.cpp (+0 -21)
  96. source/backend/cpu/x86_x64/avx/MathFunctions.cpp (+265 -0)
  97. source/backend/cpu/x86_x64/avx/PackedFunction.cpp (+569 -0)
  98. source/backend/cpu/x86_x64/avx/ReorderFunctions.cpp (+486 -0)
  99. source/backend/cpu/x86_x64/avx/Vec8.hpp (+65 -63)
  100. source/backend/cpu/x86_x64/avx/WinogradAVX2.cpp (+0 -0)

+ 15 - 4
CMakeLists.txt

@@ -420,9 +420,7 @@ if ((NOT MSVC) AND MNN_HIDDEN)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility-inlines-hidden -fvisibility=hidden")
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=hidden")
     # Omit frame pointer may cause difficult debug
-    if ((NOT APPLE) AND (NOT WIN32))
-        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fomit-frame-pointer")
-    endif()
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fomit-frame-pointer")
 endif()
 if (NOT MSVC)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstrict-aliasing -ffunction-sections -fdata-sections -ffast-math -fno-rtti -fno-exceptions ")
@@ -528,7 +526,9 @@ ENDIF()
 IF(MNN_CUDA)
   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/source/backend/cuda/)
   list(APPEND MNN_TARGETS MNN_CUDA)
-  list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNN_CUDA>)
+  if (NOT MSVC)
+    list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNN_CUDA>)
+  endif()
   list(APPEND MNN_EXTRA_DEPENDS ${MNN_CUDA_LIBS})
 ENDIF()
 
@@ -601,6 +601,17 @@ ELSE()
 ENDIF()
 if (MSVC)
   target_link_options(MNN PRIVATE "/IGNORE:4049,4217")
+  if (MNN_CUDA)
+    if (MNN_BUILD_SHARED_LIBS)
+      target_link_options(MNN PRIVATE "/WHOLEARCHIVE:$<TARGET_FILE:MNN_CUDA>")
+    else()
+      add_custom_command(
+        TARGET MNN
+        POST_BUILD
+        COMMAND lib.exe ARGS /OUT:$<TARGET_FILE:MNN> $<TARGET_FILE:MNN> $<TARGET_FILE:MNN_CUDA>
+      )
+    endif()
+  endif()
 endif()
 if (MNN_ONEDNN)
     add_dependencies(MNN ONEDNN_COMMON ONEDNN_CPU ONEDNN_CPU_X64)

+ 1 - 0
MNN.podspec

@@ -46,6 +46,7 @@ Pod::Spec.new do |s|
   'schema/current/*.{h}',\
   '3rd_party/flatbuffers/include/flatbuffers/*.{h}',\
   'source/core/**/*.{h,c,m,mm,cc,hpp,cpp}',\
+  'source/common/**/*.{h,c,m,mm,cc,hpp,cpp}',\
   'source/utils/**/*.{h,c,m,mm,cc,hpp,cpp}',\
   'source/geometry/**/*.{h,c,m,mm,cc,hpp,cpp}',\
   'source/cv/**/*.{h,c,m,mm,cc,hpp,cpp}',\

+ 5 - 1
demo/exec/multithread_imgrecog.cpp

@@ -47,7 +47,11 @@ int main(int argc, const char* argv[]) {
         threads.emplace_back([&, i]() {
             auto newExe = Executor::newExecutor(MNN_FORWARD_CPU, bnConfig, 1);
             ExecutorScope scope(newExe);
-            std::shared_ptr<Module> tempModule(Module::clone(net.get()));
+            std::shared_ptr<Module> tempModule;
+            {
+                std::unique_lock<std::mutex> _l(printMutex);
+                tempModule.reset(Module::clone(net.get()));
+            }
             // Create Input
             auto input = MNN::Express::_Input({1, 3, 224, 224}, MNN::Express::NC4HW4);
             int size_w   = 224;
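
For reference, a minimal sketch of the per-thread pattern this change protects, assuming the demo's names (net is the shared base module) and that Module::clone must be serialized as above:

#include <memory>
#include <mutex>
#include <thread>
#include <vector>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExecutorScope.hpp>
#include <MNN/expr/Module.hpp>

using namespace MNN::Express;

// Hypothetical driver: one Executor per thread, Module::clone guarded by a shared mutex.
static void runWorkers(std::shared_ptr<Module> net, int threadNumber) {
    MNN::BackendConfig bnConfig;
    std::mutex cloneMutex; // plays the role of the demo's printMutex
    std::vector<std::thread> threads;
    for (int i = 0; i < threadNumber; ++i) {
        threads.emplace_back([&]() {
            auto exe = Executor::newExecutor(MNN_FORWARD_CPU, bnConfig, 1);
            ExecutorScope scope(exe); // expressions in this thread use their own executor
            std::shared_ptr<Module> local;
            {
                std::unique_lock<std::mutex> _l(cloneMutex);
                local.reset(Module::clone(net.get()));
            }
            // ... build inputs and run local->onForward(...) here ...
        });
    }
    for (auto& t : threads) {
        t.join();
    }
}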

+ 28 - 3
demo/exec/pictureRecognition_module.cpp

@@ -9,6 +9,7 @@
 #include <stdio.h>
 #include <MNN/ImageProcess.hpp>
 #include <MNN/expr/Module.hpp>
+#include <MNN/expr/Executor.hpp>
 #include <MNN/expr/ExprCreator.hpp>
 #define MNN_OPEN_TIME_TRACE
 #include <algorithm>
@@ -29,9 +30,31 @@ int main(int argc, const char* argv[]) {
         MNN_PRINT("Usage: ./pictureRecognition_module.out model.mnn input0.jpg input1.jpg input2.jpg ... \n");
         return 0;
     }
-    // Load module
-    std::shared_ptr<MNN::Express::Module> net(MNN::Express::Module::load(std::vector<std::string>{}, std::vector<std::string>{}, argv[1]));
-
+    // Load module with Config
+    /*
+    MNN::Express::Module::BackendInfo bnInfo;
+    bnInfo.type = MNN_FORWARD_CPU;
+    MNN::Express::Module::Config configs;
+    configs.backend = &bnInfo;
+    std::shared_ptr<MNN::Express::Module> net(MNN::Express::Module::load(std::vector<std::string>{}, std::vector<std::string>{}, argv[1], &configs));
+    */
+    
+    // Load module with Runtime
+    std::vector<MNN::ScheduleConfig> sConfigs;
+    MNN::ScheduleConfig sConfig;
+    sConfig.type = MNN_FORWARD_AUTO;
+    sConfigs.push_back(sConfig);
+    std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtmgr = std::shared_ptr<MNN::Express::Executor::RuntimeManager>(MNN::Express::Executor::RuntimeManager::createRuntimeManager(sConfigs));
+    if(rtmgr == nullptr) {
+        MNN_ERROR("Empty RuntimeManger\n");
+        return 0;
+    }
+    
+    // Provide the full path of the cache file; the location must be readable and writable
+    rtmgr->setCache(".cachefile");
+    
+    std::shared_ptr<MNN::Express::Module> net(MNN::Express::Module::load(std::vector<std::string>{}, std::vector<std::string>{}, argv[1], rtmgr));
+    
     // Create Input
     int batchSize = argc - 2;
     auto input = MNN::Express::_Input({batchSize, 3, 224, 224}, MNN::Express::NC4HW4);
@@ -81,5 +104,7 @@ int main(int argc, const char* argv[]) {
             MNN_PRINT("%d, %f\n", indice[batch * topK + i], value[batch * topK + i]);
         }
     }
+    rtmgr->updateCache();
+
     return 0;
 }

+ 6 - 0
docker_run.sh

@@ -0,0 +1,6 @@
+# Run the test suite inside the mnn_ci docker container
+docker start mnn_ci
+docker exec -i -e TEST_ID=$(pwd | awk -F "/" '{print $(NF-1)}') mnn_ci bash <<'EOF'
+cd ~/yanxing_zhaode/cise/space/$TEST_ID/source && ./test.sh linux
+exit
+EOF

+ 153 - 17
express/Executor.cpp

@@ -9,6 +9,7 @@
 #include <MNN/expr/Executor.hpp>
 #include "core/Session.hpp"
 #include "core/TensorUtils.hpp"
+#include "core/FileLoader.hpp"
 #include "Utils.hpp"
 #include <MNN/AutoTime.hpp>
 #include "core/WrapExecution.hpp"
@@ -67,23 +68,52 @@ void Executor::Profiler::addFlops(const std::string& opType, float flops) {
     iter->second += flops;
 }
 #endif
+
+struct Executor::Cache{
+    AutoStorage<uint8_t> modelBuffer;
+    AutoStorage<uint8_t> cacheBuffer;
+    size_t cacheOffset = 0;
+    std::string cacheFile;
+    size_t lastCacheSize = 0;
+};
+
 void Executor::setGlobalExecutorConfig(MNNForwardType type, const BackendConfig& config, int numberThread) {
     std::lock_guard<std::mutex> _l(mMutex);
-    auto creator = MNNGetExtraRuntimeCreator(type);
-    if (nullptr == creator) {
-        MNN_ERROR("Error to find creator of %d, set CPU default\n", type);
-        type = MNN_FORWARD_CPU;
-        creator = MNNGetExtraRuntimeCreator(type);
+    if(type == MNN_FORWARD_AUTO) {
+        ScheduleConfig sConfig;
+        sConfig.type = type;
+        type = Schedule::getApprociateType(sConfig);
+        
+        auto creator = MNNGetExtraRuntimeCreator(type);
+        MNN_ASSERT(nullptr != creator);
+        Backend::Info info;
+        info.type = type;
+        info.mode = Backend::Info::DIRECT;
+        info.numThread = numberThread;
+        if(type == MNN_FORWARD_OPENCL || type == MNN_FORWARD_METAL) {
+            info.numThread = 4;
+        }
+        info.user = (BackendConfig*)&config;
+        std::shared_ptr<Runtime> bn(creator->onCreate(info));
+        mRuntime.first = bn;
+        mRuntime.second = type;
+    } else {
+        auto creator = MNNGetExtraRuntimeCreator(type);
+        if (nullptr == creator) {
+            MNN_ERROR("Error to find creator of %d, set CPU default\n", type);
+            type = MNN_FORWARD_CPU;
+            creator = MNNGetExtraRuntimeCreator(type);
+        }
+        MNN_ASSERT(nullptr != creator);
+        Backend::Info info;
+        info.type = type;
+        info.mode = Backend::Info::DIRECT;
+        info.numThread = numberThread;
+        info.user = (BackendConfig*)&config;
+        std::shared_ptr<Runtime> bn(creator->onCreate(info));
+        mRuntime.first = bn;
+        mRuntime.second = type;
     }
-    MNN_ASSERT(nullptr != creator);
-    Backend::Info info;
-    info.type = type;
-    info.mode = Backend::Info::DIRECT;
-    info.numThread = numberThread;
-    info.user = (BackendConfig*)&config;
-    std::shared_ptr<Runtime> bn(creator->onCreate(info));
-    mRuntime.first = bn;
-    mRuntime.second = type;
 }
 
 void Executor::gc(GCFlag flag) {
@@ -141,8 +171,8 @@ Executor::Requirement Executor::getRequirement(Expr* expr) const {
 }
 
 static std::once_flag gInitFlag;
+static std::shared_ptr<Executor>* gExecutor = nullptr;
 std::shared_ptr<Executor> Executor::getGlobalExecutor() {
-    static std::shared_ptr<Executor> gExecutor;
     std::call_once(gInitFlag, [&]() {
         auto creator = MNNGetExtraRuntimeCreator(MNN_FORWARD_CPU);
 #ifdef MNN_BUILD_MINI
@@ -153,9 +183,9 @@ std::shared_ptr<Executor> Executor::getGlobalExecutor() {
         info.type = MNN_FORWARD_CPU;
         info.numThread = 1;
         std::shared_ptr<Runtime> bn(creator->onCreate(info));
-        gExecutor.reset(new Executor(bn, MNN_FORWARD_CPU));
+        gExecutor = new std::shared_ptr<Executor>(new Executor(bn, MNN_FORWARD_CPU));
     });
-    return gExecutor;
+    return *gExecutor;
 }
 
 std::shared_ptr<Executor> Executor::newExecutor(MNNForwardType type,
@@ -178,6 +208,112 @@ RuntimeInfo Executor::getRuntime() {
     return info;
 }
 
+static bool loadCache(std::shared_ptr<Runtime> &rt, const void* buffer, size_t size) {
+    auto res = rt->onSetCache(buffer, size);
+    if (res) {
+        return true;
+    }
+    return false;
+}
+static std::pair<const void*, size_t> getCache(std::shared_ptr<Runtime> &rt) {
+    auto res = rt->onGetCache();
+    if (res.first != nullptr) {
+        return res;
+    }
+    return std::make_pair(nullptr, 0);
+}
+
+static void writeCacheFile(std::shared_ptr<Executor::Cache> cache, std::pair<const void*, size_t> buffer) {
+    std::unique_ptr<FileLoader> loader(new FileLoader(cache->cacheFile.c_str()));
+    auto verifyInfo = std::make_pair((const void*)cache->modelBuffer.get(), cache->cacheOffset);
+    bool res = loader->write(verifyInfo, buffer);
+    if (!res) {
+        MNN_ERROR("Write Cache File error!\n");
+        return;
+    }
+}
+
+
+
+Executor::RuntimeManager::RuntimeManager(std::vector<ScheduleConfig> &configs) {
+    mRuntime = Interpreter::createRuntime(configs);
+    mInfo = mRuntime.first.begin()->second;
+}
+
+Executor::RuntimeManager* Executor::RuntimeManager::createRuntimeManager(std::vector<ScheduleConfig> &configs) {
+    if(configs.size() == 0) {
+        MNN_ERROR("Empty runtime config\n");
+        return nullptr;
+    }
+    return new Executor::RuntimeManager(configs);
+}
+
+
+void Executor::RuntimeManager::setCache(std::string cacheName) {
+    mCache.reset(new Cache);
+    mCache->cacheFile = cacheName;
+    if (nullptr == mCache->cacheFile.c_str()) {
+        MNN_ERROR("Empty cacheFile\n");
+        return;
+    }
+    std::unique_ptr<FileLoader> loader(new FileLoader(mCache->cacheFile.c_str()));
+    if (!loader->valid()) {
+        MNN_ERROR("Load Cache file error.\n");
+        return;
+    }
+    bool result = loader->read();
+    if (!result) {
+        MNN_ERROR("Load Cache file error.\n");
+        return;
+    }
+    if (loader->size() == 0) {
+        MNN_ERROR("Load Cache file error.\n");
+        return;
+    }
+    bool success = loader->merge(mCache->cacheBuffer);
+    if (!success) {
+        MNN_ERROR("Alloc memory for Cache error.\n");
+        return;
+    }
+    
+    // load cache
+    bool valid = loadCache(mInfo, mCache->cacheBuffer.get() + mCache->cacheOffset,
+                           mCache->cacheBuffer.size() - mCache->cacheOffset);
+    if(!valid) {
+        // Reset cache
+        loadCache(mInfo, nullptr, 0);
+        MNN_PRINT("Cache invalid, will be reset\n");
+    }
+    
+    mCache->lastCacheSize = mCache->cacheBuffer.size() - mCache->cacheOffset;
+}
+
+void Executor::RuntimeManager::updateCache() {
+    auto buffer = getCache(mInfo);
+    
+    // Update only when the current cache size is larger than the previous one
+    if (buffer.first != nullptr && buffer.second > mCache->lastCacheSize) {
+        MNN_PRINT("Update cache to %s, size = %zu\n", mCache->cacheFile.c_str(), buffer.second);
+        writeCacheFile(mCache, buffer);
+        mCache->lastCacheSize = buffer.second;
+    }
+    // Reset cache
+    loadCache(mInfo, nullptr, 0);
+}
+
+std::vector<bool> Executor::RuntimeManager::isBackendSupport(const std::vector<MNNForwardType> types) {
+    std::vector<bool> res;
+    for (auto bn : types) {
+        auto rt = MNNGetExtraRuntimeCreator(bn);
+        if (rt != nullptr) {
+            res.push_back(true);
+        } else {
+            res.push_back(false);
+        }
+    }
+    return res;
+}
+
 ErrorCode Executor::computeInfo(Expr* expr) {
     MNN_ASSERT(nullptr != expr);
     MNN_ASSERT(nullptr != expr->get());
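
For reference, a minimal caller-side sketch of the new MNN_FORWARD_AUTO path, assuming the public setGlobalExecutorConfig entry point behaves as implemented above:

#include <MNN/expr/Executor.hpp>

int main() {
    MNN::BackendConfig config;
    config.precision = MNN::BackendConfig::Precision_Normal;
    // MNN_FORWARD_AUTO now resolves to a concrete runtime via Schedule::getApprociateType;
    // the thread count is forced to 4 internally when OpenCL or Metal is selected.
    MNN::Express::Executor::getGlobalExecutor()->setGlobalExecutorConfig(MNN_FORWARD_AUTO, config, 4);
    return 0;
}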

+ 0 - 1
express/Expr.cpp

@@ -138,7 +138,6 @@ EXPRP Expr::create(Tensor* tensor, bool own) {
     auto& dstInfo = expr->mInside->mOutputInfos[0];
     expr->mInside->mInfoDirty = false;
     expr->mInside->mContentDirty = false;
-    expr->mInside->mOwnTensor = false;
     return expr;
 }
 

+ 13 - 4
express/module/Module.cpp

@@ -118,6 +118,15 @@ void Module::clearCache() {
 }
 
 Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const Module::Config* config) {
+    return load(inputs, outputs, fileName, nullptr, config);
+}
+
+Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Module::Config* config) {
+    return load(inputs, outputs, buffer, length, nullptr, config);
+}
+
+
+Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config* config) {
     AutoStorage<uint8_t> buffer;
     {
         FileLoader loader(fileName);
@@ -134,10 +143,10 @@ Module* Module::load(const std::vector<std::string>& inputs, const std::vector<s
             return nullptr;
         }
     }
-    return load(inputs, outputs, buffer.get(), buffer.size(), config);
+    return load(inputs, outputs, buffer.get(), buffer.size(), rtMgr, config);
 }
 
-Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Module::Config* config) {
+Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config* config) {
     // Check Auto Inputs and Outputs
     auto net = GetNet(buffer);
     if (nullptr == net->oplists() || nullptr == net->tensorName()) {
@@ -145,7 +154,7 @@ Module* Module::load(const std::vector<std::string>& inputs, const std::vector<s
         return nullptr;
     }
     if ((!inputs.empty()) && (!outputs.empty())) {
-        return PipelineModule::load(inputs, outputs, buffer, length, config);
+        return PipelineModule::load(inputs, outputs, buffer, length, rtMgr, config);
     }
     std::vector<std::string> newInputs = inputs;
     std::vector<std::string> newOutputs = outputs;
@@ -181,7 +190,7 @@ Module* Module::load(const std::vector<std::string>& inputs, const std::vector<s
             newOutputs.emplace_back(net->tensorName()->GetAsString(index)->str());
         }
     }
-    return PipelineModule::load(newInputs, newOutputs, buffer, length, config);
+    return PipelineModule::load(newInputs, newOutputs, buffer, length, rtMgr, config);
 }
 
 EXPRP Module::CloneContext::getOrClone(EXPRP expr) {

+ 12 - 11
express/module/PipelineModule.cpp

@@ -235,7 +235,7 @@ void PipelineModule::onClearCache() {
     // Do nothing
 }
 
-void PipelineModule::_createSubGraph(const MNN::Net* net, const Module::Config* config, std::map<std::string, SubGraph>& subGraphMap) {
+void PipelineModule::_createSubGraph(const MNN::Net* net, std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config* config, std::map<std::string, SubGraph>& subGraphMap) {
     auto subGraphs = net->subgraphs();
     if (nullptr == subGraphs) {
         return;
@@ -271,7 +271,7 @@ void PipelineModule::_createSubGraph(const MNN::Net* net, const Module::Config*
             flatbuffers::FlatBufferBuilder builder(1024);
             auto offset = Net::Pack(builder, _tempNet.get());
             builder.Finish(offset);
-            submodule.reset(PipelineModule::load(subInputs, subOutputs, (const uint8_t*)builder.GetBufferPointer(), builder.GetSize(), config, subGraphMap, true));
+            submodule.reset(PipelineModule::load(subInputs, subOutputs, (const uint8_t*)builder.GetBufferPointer(), builder.GetSize(), rtMgr, config, subGraphMap, true));
             if (graph->name() != nullptr) {
                 submodule->setName(graph->name()->str());
             }
@@ -539,7 +539,7 @@ static std::vector<SubModuleInfo> _createSubModuleInfo(const MNN::Net* net, cons
     return submodule;
 }
 
-static Module* _createSubModule(const MNN::Net* net, const SubModuleInfo& info, const std::map<std::string, SubGraph>& subs, const Module::Config& config, bool inRecurse, std::shared_ptr<Schedule::ScheduleInfo> sharedConst) {
+static Module* _createSubModule(const MNN::Net* net, const SubModuleInfo& info, const std::map<std::string, SubGraph>& subs, std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config& config, bool inRecurse, std::shared_ptr<Schedule::ScheduleInfo> sharedConst) {
     if (1 == info.opList.size()) {
         auto op = net->oplists()->GetAs<Op>(info.opList[0]);
         if (OpType_If == op->type()) {
@@ -590,10 +590,10 @@ static Module* _createSubModule(const MNN::Net* net, const SubModuleInfo& info,
     auto offset = Net::Pack(builder, _tempNet.get());
     builder.Finish(offset);
     _tempNet.reset();
-    return new StaticModule((const uint8_t*)builder.GetBufferPointer(), builder.GetSize(), inputNames, outputNames, config, inRecurse, sharedConst);
+    return new StaticModule((const uint8_t*)builder.GetBufferPointer(), builder.GetSize(), inputNames, outputNames, rtMgr, config, inRecurse, sharedConst);
 }
 
-Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Module::Config* config) {
+Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config* config) {
     // Create Subgraph
     auto net = GetNet(buffer);
     if (nullptr == net->oplists() || nullptr == net->tensorName()) {
@@ -606,10 +606,11 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
     }
     auto subGraphs = net->subgraphs();
     std::map<std::string, SubGraph> subGraphMap;
-    _createSubGraph(net, config, subGraphMap);
-    return load(inputs, outputs, buffer, length, config, subGraphMap);
+    _createSubGraph(net, rtMgr, config, subGraphMap);
+    return load(inputs, outputs, buffer, length, rtMgr, config, subGraphMap);
 }
-Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Module::Config* config, std::map<std::string, SubGraph>& subGraphMap, bool inRecurce) {
+
+Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config* config, std::map<std::string, SubGraph>& subGraphMap, bool inRecurce) {
     std::shared_ptr<Schedule::ScheduleInfo> sharedConst;
     auto net = GetNet(buffer);
     if (!config->dynamic) {
@@ -623,7 +624,7 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
         }
         if (linear) {
             // Has no control flow and WhereOp, can just use static module
-            return new StaticModule(buffer, length, inputs, outputs, *config, false, sharedConst);
+            return new StaticModule(buffer, length, inputs, outputs, rtMgr, *config, false, sharedConst);
         }
     }
     // Extra Const Tensors
@@ -637,7 +638,7 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
     sharedConst->allTensors.resize(net->tensorName()->size());
     ErrorCode code = NO_ERROR;
     std::set<int> noneedComputeIndexes;
-    initConstTensors(sharedConst->allTensors, net, defaultBackend.get(), false, code);
+    initConstTensors(sharedConst->allTensors, net, defaultBackend.get(), false, code, Backend::DYNAMIC_SEPERATE);
     for (int i=0; i<sharedConst->allTensors.size(); ++i) {
         if (sharedConst->allTensors[i].get() != nullptr) {
             noneedComputeIndexes.insert(i);
@@ -677,7 +678,7 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
     auto subModulesInfo = _createSubModuleInfo(net, inputIndexes, outputIndexes, noneedComputeIndexes, sharedConst, initVars);
     std::vector<std::shared_ptr<Module>> subModules(subModulesInfo.size());
     for (int i=0; i<subModulesInfo.size(); ++i) {
-        subModules[i].reset(_createSubModule(net, subModulesInfo[i], subGraphMap, *config, inRecurce, sharedConst));
+        subModules[i].reset(_createSubModule(net, subModulesInfo[i], subGraphMap, rtMgr, *config, inRecurce, sharedConst));
     }
     auto result = new PipelineModule;
     /**

+ 3 - 3
express/module/PipelineModule.hpp

@@ -40,7 +40,7 @@ private:
 class PipelineModule : public Module {
 public:
     typedef std::function<std::pair<std::vector<int>, std::shared_ptr<Module>>(Express::EXPRP)> Transformer;
-    MNN_PUBLIC static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Module::Config* config = nullptr);
+    MNN_PUBLIC static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config* config = nullptr);
     virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override;
     virtual void onClearCache() override;
     MNN_PUBLIC std::vector<int> countOutputReference(std::vector<int> outputIndices);
@@ -48,8 +48,8 @@ public:
     MNN_PUBLIC PipelineModule(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs,
                    const Transformer& transformFunction = {});
 private:
-    static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Module::Config* config, std::map<std::string, SubGraph>& subGraphMap, bool inRecurce = false);
-    static void _createSubGraph(const MNN::Net* net, const Module::Config* config, std::map<std::string, SubGraph>& subGraphMap);
+    static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config* config, std::map<std::string, SubGraph>& subGraphMap, bool inRecurce = false);
+    static void _createSubGraph(const MNN::Net* net, std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config* config, std::map<std::string, SubGraph>& subGraphMap);
 
     PipelineModule(){}
 

+ 20 - 15
express/module/StaticModule.cpp

@@ -182,7 +182,7 @@ private:
 };
 
 StaticModule::StaticModule(const void* buffer, size_t length, const std::vector<std::string>& inputs,
-                           const std::vector<std::string>& outputs, const Module::Config& moduleconfig, bool copyOutput, std::shared_ptr<Schedule::ScheduleInfo> sharedConst) {
+                           const std::vector<std::string>& outputs, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config& moduleconfig, bool copyOutput, std::shared_ptr<Schedule::ScheduleInfo> sharedConst) {
     setType("StaticModule");
     mResource.reset(new Resource);
     mResource->mInputs = inputs;
@@ -244,13 +244,17 @@ StaticModule::StaticModule(const void* buffer, size_t length, const std::vector<
     }
     
     RuntimeInfo rt;
-    if (moduleconfig.backend == nullptr) {
-        rt = Express::ExecutorScope::Current()->getRuntime();
+    if(rtMgr) {
+        rt = rtMgr->getRuntimeInfo();
     } else {
-        ScheduleConfig sche_config;
-        sche_config.type = moduleconfig.backend->type;
-        sche_config.backendConfig = moduleconfig.backend->config;
-        rt = Interpreter::createRuntime(std::vector<ScheduleConfig>({sche_config}));
+        if (moduleconfig.backend == nullptr) {
+            rt = Express::ExecutorScope::Current()->getRuntime();
+        } else {
+            ScheduleConfig sche_config;
+            sche_config.type = moduleconfig.backend->type;
+            sche_config.backendConfig = moduleconfig.backend->config;
+            rt = Interpreter::createRuntime(std::vector<ScheduleConfig>({sche_config}));
+        }
     }
     // TODO: Add Config
     mResource->mConfig.numThread   = 1;
@@ -390,18 +394,19 @@ std::vector<Express::VARP> StaticModule::onForward(const std::vector<Express::VA
         bool isQuant = (quantAttr && TensorUtils::DataTypeToHalideType(quantAttr->type) == currentTensor->getType());
         // copy the data when reused as input tensor with data;
         if (currentTensor->elementSize() > 0 && (mResource->mReusedTensors.find(mResource->mOutputFromTensor[i]) != mResource->mReusedTensors.end() || mResource->mCopyOutput || isQuant)) {
-            std::shared_ptr<Tensor> tmpTensor(new Tensor(currentTensor, currentTensor->getDimensionType(), true));
-            auto des                 = TensorUtils::getDescribe(mOutputTensors[i]);
+            auto des = TensorUtils::getDescribe(mOutputTensors[i]);
+            auto tmpTensor = new Tensor(currentTensor, currentTensor->getDimensionType(), false);
+            TensorUtils::getDescribe(tmpTensor)->dimensionFormat = des->dimensionFormat;
+            TensorUtils::getDescribe(tmpTensor)->tensorArrayAttr = des->tensorArrayAttr;
+            tmpTensor->buffer().host = (uint8_t*)MNNMemoryAllocAlign(tmpTensor->size(), MNN_MEMORY_ALIGN_DEFAULT);
+            TensorUtils::getDescribe(tmpTensor)->memoryType = Tensor::InsideDescribe::MEMORY_HOST;
             if (nullptr != des->backend) {
-                currentTensor->copyToHostTensor(tmpTensor.get());
+                currentTensor->copyToHostTensor(tmpTensor);
             } else {
-                MNNCPUCopyBuffer(currentTensor, tmpTensor.get());
+                MNNCPUCopyBuffer(currentTensor, tmpTensor);
             }
-            TensorUtils::getDescribe(tmpTensor.get())->dimensionFormat = des->dimensionFormat;
-            TensorUtils::getDescribe(tmpTensor.get())->tensorArrayAttr = des->tensorArrayAttr;
             outputs[mResource->mOutputFromTensor[i]] =
-                Express::Variable::create(Express::Expr::create(tmpTensor.get()), 0);
-            mOutputTensorsWrap[i] = tmpTensor;
+                Express::Variable::create(Express::Expr::create(tmpTensor, true), 0);
         } else {
             outputs[mResource->mOutputFromTensor[i]] = Express::Variable::create(Express::Expr::create(mOutputTensors[i]));
         }
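
The output path above now hands a raw host tensor to Expr::create with own == true instead of keeping it alive through a shared_ptr in mOutputTensorsWrap. A minimal sketch of that ownership transfer, using a hypothetical helper and assuming Tensor::create allocates host memory as usual:

#include <MNN/Tensor.hpp>
#include <MNN/expr/Expr.hpp>

using namespace MNN::Express;

// Hypothetical helper: allocate a small host tensor, fill it, and wrap it in a VARP.
// Expr::create(tensor, true) takes ownership, so no separate holder object is required.
static VARP wrapHostTensor() {
    auto tensor = MNN::Tensor::create<float>({1, 4}); // host-side allocation
    auto host   = tensor->host<float>();
    for (int i = 0; i < 4; ++i) {
        host[i] = static_cast<float>(i);
    }
    return Variable::create(Expr::create(tensor, true), 0);
}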

+ 1 - 1
express/module/StaticModule.hpp

@@ -20,7 +20,7 @@ namespace Express {
 struct BufferStorage;
 class StaticModule : public Module {
 public:
-    StaticModule(const void* buffer, size_t length, const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const Module::Config& config, bool copyOutput, std::shared_ptr<Schedule::ScheduleInfo> sharedConst);
+    StaticModule(const void* buffer, size_t length, const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config& config, bool copyOutput, std::shared_ptr<Schedule::ScheduleInfo> sharedConst);
     virtual ~ StaticModule();
     virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override;
     void setReusedTensors(std::set<int> reused);

+ 38 - 0
include/MNN/expr/Executor.hpp

@@ -56,6 +56,44 @@ public:
     void addOpFlops(const std::string& type, float flops);
     class Profiler;
     static RuntimeInfo getRuntime();
+    
+    struct Cache;
+    class RuntimeManager {
+    public:
+        RuntimeManager(std::vector<ScheduleConfig> &configs);
+        ~RuntimeManager() {};
+        
+        /**
+         * @param configs: schedule configs.
+         * @param cacheName: full path of the cache file; choose a location that is readable and writable.
+         */
+        static RuntimeManager* createRuntimeManager(std::vector<ScheduleConfig> &configs);
+
+        /**
+         * @brief set the cache file. If the file does not exist it is created; if it exists it is loaded.
+         * When to use: when the GPU backend or the AUTO backend is chosen.
+         * Calling position: call after createRuntimeManager.
+         */
+        void setCache(std::string cacheName);
+        
+        /**
+         * @brief update the cache file.
+         * When to use: together with the setCache API, after the first inference and whenever the input shape changes.
+         * Calling position: call after inference is done.
+         */
+        void updateCache();
+        std::vector<bool> isBackendSupport(const std::vector<MNNForwardType> type);
+        RuntimeInfo getRuntimeInfo() {
+            return mRuntime;
+        }
+    private:
+        RuntimeInfo mRuntime;
+        std::shared_ptr<Runtime> mInfo;
+        std::shared_ptr<Cache> mCache;
+        
+    };
+
+    
 private:
     void _makeCache(const std::vector<EXPRP>& outputs, bool forceCPU);
     void _create(const std::vector<EXPRP>& outputs, std::set<std::shared_ptr<Executor::ComputeCache>>&& inputCaches, std::set<std::shared_ptr<Expr::Inside>>&& inputNode, bool forceCPU);
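
For reference, a minimal lifecycle sketch of the RuntimeManager API declared above (create, probe backends, attach a cache, persist it after inference); the inference step itself is elided:

#include <memory>
#include <vector>
#include <MNN/Interpreter.hpp>
#include <MNN/expr/Executor.hpp>

using namespace MNN::Express;

int main() {
    std::vector<MNN::ScheduleConfig> configs(1);
    configs[0].type = MNN_FORWARD_AUTO;
    std::shared_ptr<Executor::RuntimeManager> rtmgr(
        Executor::RuntimeManager::createRuntimeManager(configs));
    if (rtmgr == nullptr) {
        return 1;
    }
    // Optional: check which backends this build can actually create.
    auto supported = rtmgr->isBackendSupport({MNN_FORWARD_OPENCL, MNN_FORWARD_METAL});
    (void)supported;
    // The cache file must be readable and writable; it is created on the first run.
    rtmgr->setCache(".cachefile");
    // ... load modules with rtmgr and run inference here ...
    // Write the (possibly grown) kernel cache back to disk.
    rtmgr->updateCache();
    return 0;
}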

+ 5 - 0
include/MNN/expr/Module.hpp

@@ -13,6 +13,7 @@
 #include <unordered_map>
 
 #include <MNN/expr/Expr.hpp>
+#include <MNN/expr/Executor.hpp>
 #include <MNN/MNNForwardType.h>
 
 namespace MNN {
@@ -68,6 +69,10 @@ public:
     };
     static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Config* config = nullptr);
     static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const Config* config = nullptr);
+    // Shared RuntimeManager
+    static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Config* config = nullptr);
+    static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Config* config = nullptr);
+    
     static Module* extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph = {});
 
     static Module* clone(const Module* module, const bool shareParams = false);
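
The new overloads let several modules share one RuntimeManager, and with it a single runtime and kernel cache. A minimal sketch assuming two placeholder model paths:

#include <memory>
#include <vector>
#include <MNN/Interpreter.hpp>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/Module.hpp>

using namespace MNN::Express;

int main() {
    std::vector<MNN::ScheduleConfig> configs(1);
    configs[0].type = MNN_FORWARD_AUTO;
    std::shared_ptr<Executor::RuntimeManager> rtmgr(
        Executor::RuntimeManager::createRuntimeManager(configs));
    rtmgr->setCache(".cachefile");
    // "detector.mnn" and "classifier.mnn" are hypothetical model files; both modules
    // are built on the same shared runtime.
    std::shared_ptr<Module> detector(Module::load({}, {}, "detector.mnn", rtmgr));
    std::shared_ptr<Module> classifier(Module::load({}, {}, "classifier.mnn", rtmgr));
    // ... run inference with detector / classifier ...
    rtmgr->updateCache();
    return 0;
}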

+ 118 - 0
package_scripts/linux/build_bridge.sh

@@ -0,0 +1,118 @@
+#!/bin/bash
+
+# MNN
+#  |--- Debug
+#  |      |--- libmnnpybridge.a
+#  |      |--- libmnnpybridge.so
+#  |
+#  |--- Release
+#         |--- libmnnpybridge.a
+#         |--- libmnnpybridge.so
+
+set -e
+
+usage() {
+    echo "Usage: $0 -i mnn_path -o path [-t build_type -t lib_type]"
+    echo -e "\t-i MNN library path"
+    echo -e "\t-o package files output directory"
+    echo -e "\t-t build type (debug/release), lib_type (dynamic/static), build all when unspecify"
+    exit 1
+}
+
+build_all=true
+while getopts "i:o:ft:h" opt; do
+  case "$opt" in
+    i ) mnn_path=$OPTARG ;;
+    o ) path=$OPTARG ;;
+    t ) build_all=""
+        case "$OPTARG" in
+            "debug"|"release" ) build_type=$OPTARG ;;
+            "dynamic"|"static" ) lib_type=$OPTARG ;;
+        esac ;;
+    h|? ) usage ;;
+  esac
+done
+
+if [ -z $build_all ] && ([ -z $build_type ] || [ -z $lib_type ]); then
+    echo "build_type(debug/release) and lib_type(dynamic/static) should be set or not-set together"
+    exit 1
+fi
+
+rm -rf $path && mkdir -p $path
+pushd $path
+mkdir -p include wrapper lib/Debug lib/Release
+popd
+PACKAGE_PATH=$(realpath $path)
+MNN_PACKAGE_PATH=$(realpath $mnn_path)
+
+pushd pymnn/3rd_party
+rm -rf MNN && mkdir -p MNN/lib
+cp -r $MNN_PACKAGE_PATH/* MNN/lib
+cp -r ../../include MNN
+popd
+
+cp pymnn/src/MNNPyBridge.h $PACKAGE_PATH/include
+rm -rf /tmp/mnn_py && mkdir -p /tmp/mnn_py
+cp -r pymnn/pip_package/MNN /tmp/mnn_py
+pushd /tmp/mnn_py
+find . -name __pycache__ | xargs rm -rf
+pushd MNN
+rm -rf tools
+cat __init__.py | sed '/from . import tools/d' > __init__.py.tmp
+mv __init__.py.tmp __init__.py
+rm -rf data
+cat __init__.py | sed '/from . import data/d' > __init__.py.tmp
+mv __init__.py.tmp __init__.py
+rm -rf optim
+cat __init__.py | sed '/from . import optim/d' > __init__.py.tmp
+mv __init__.py.tmp __init__.py
+python -c "import compileall; compileall.compile_dir('/tmp/mnn_py/MNN', force=True)"
+find . -name "*.py" | xargs rm -rf
+popd
+cp -r MNN $PACKAGE_PATH/wrapper
+popd
+
+CMAKE_ARGS="-DPYMNN_USE_ALINNPYTHON=ON -DPYMNN_RUNTIME_CHECK_VM=ON -DPYMNN_EXPR_API=ON -DPYMNN_NUMPY_USABLE=ON -DPYMNN_TRAIN_API=OFF"
+
+rm -rf mnnpybridge_build && mkdir mnnpybridge_build
+pushd mnnpybridge_build
+
+log() {
+    echo "==================================="
+    echo "Build mnnpybridge $1"
+    echo "==================================="
+}
+
+# Debug Dynamic
+if [ $build_all ] || [ $build_type = "debug" -a $lib_type = "dynamic" ]; then
+    log "debug + dynamic"
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_BUILD_SHARED_LIBS=ON ../pymnn && make -j8
+    cp libmnnpybridge.so $PACKAGE_PATH/lib/Debug
+fi
+
+# Debug Static
+if [ $build_all ] || [ $build_type = "debug" -a $lib_type = "static" ]; then
+    log "debug + static"
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_BUILD_SHARED_LIBS=OFF ../pymnn && make -j8
+    cp libmnnpybridge.a $PACKAGE_PATH/lib/Debug
+fi
+
+# Release Dynamic
+if [ $build_all ] || [ $build_type = "release" -a $lib_type = "dynamic" ]; then
+    log "release + dynamic"
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_BUILD_SHARED_LIBS=ON ../pymnn && make -j8
+    cp libmnnpybridge.so $PACKAGE_PATH/lib/Release
+fi
+
+# Release Static
+if [ $build_all ] || [ $build_type = "release" -a $lib_type = "static" ]; then
+    log "release + static"
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_BUILD_SHARED_LIBS=OFF ../pymnn && make -j8
+    cp libmnnpybridge.a $PACKAGE_PATH/lib/Release
+fi
+
+popd

+ 76 - 27
package_scripts/linux/build_lib.sh

@@ -1,3 +1,5 @@
+#!/bin/bash
+
 # MNN
 #  |--- Debug
 #  |      |--- libMNN.a
@@ -10,22 +12,40 @@
 set -e
 
 usage() {
-    echo "Usage: $0 -o path [-b]"
+    echo "Usage: $0 -o path [-b backends] [-s] [-c] [-t build_type -t lib_type [-c]]"
     echo -e "\t-o package files output directory"
-    echo -e "\t-b opencl backend"
+    echo -e "\t-b extra backends to support (opencl, opengl, vulkan, onednn, avx512, coreml)"
+    echo -e "\t-s re-generate schema"
+    echo -e "\t-c clean build folder"
+    echo -e "\t-t build type (debug/release), lib_type (dynamic/static), build all when unspecify"
     exit 1
 }
 
-while getopts "o:hb" opt; do
+build_all=true
+while getopts "o:b:sct:h" opt; do
   case "$opt" in
     o ) path=$OPTARG ;;
-    b ) opencl=true ;;
+    b ) IFS="," read -a backends <<< $OPTARG ;;
+    s ) clean_schema=true ;;
+    c ) clean_build=true ;;
+    t ) build_all=""
+        case "$OPTARG" in
+            "debug"|"release" ) build_type=$OPTARG ;;
+            "dynamic"|"static" ) lib_type=$OPTARG ;;
+        esac ;;
     h|? ) usage ;;
   esac
 done
 
+if [ -z $build_all ] && ([ -z $build_type ] || [ -z $lib_type ]); then
+    echo "build_type(debug/release) and lib_type(dynamic/static) should be set or not-set together"
+    exit 1
+fi
+
 # clear and create package directory
-./schema/generate.sh
+if [ $clean_schema ]; then
+    ./schema/generate.sh
+fi
 rm -rf $path && mkdir -p $path
 mkdir -p $path/Debug
 mkdir -p $path/Release
@@ -33,31 +53,60 @@ mkdir -p $path/Release
 PACKAGE_PATH=$(realpath $path)
 
 CMAKE_ARGS="-DMNN_SEP_BUILD=OFF"
-if [ ! -z $opencl ]; then
-  CMAKE_ARGS="$CMAKE_ARGS -DMNN_OPENCL=ON"
+if [ "$backends" ]; then
+    for backend in $backends; do
+        case $backend in
+            "opencl" ) CMAKE_ARGS="$CMAKE_ARGS -DMNN_OPENCL=ON" ;;
+            "opengl" ) CMAKE_ARGS="$CMAKE_ARGS -DMNN_OPENGL=ON" ;;
+            "vulkan" ) CMAKE_ARGS="$CMAKE_ARGS -DMNN_VULKAN=ON" ;;
+            "onednn" ) CMAKE_ARGS="$CMAKE_ARGS -DMNN_ONEDNN=ON" ;;
+            "avx512" ) CMAKE_ARGS="$CMAKE_ARGS -DMNN_AVX512=ON" ;;
+            "coreml" ) CMAKE_ARGS="$CMAKE_ARGS -DMNN_COREML=ON" ;;
+        esac
+    done
 fi
 
-rm -rf build && mkdir build
+if [ $clean_build ]; then
+    rm -rf build && mkdir build
+fi
 pushd build
 
-# Debug Dynamic MNN.framework
-[ -f CMakeCache.txt ] && rm CMakeCache.txt
-cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_BUILD_SHARED_LIBS=ON .. && make -j24
-cp libMNN.so $PACKAGE_PATH/Debug
-
-# Debug Static MNN.framework
-[ -f CMakeCache.txt ] && rm CMakeCache.txt
-cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_BUILD_SHARED_LIBS=OFF .. && make -j24
-cp libMNN.a $PACKAGE_PATH/Debug
-
-# Release Dynamic MNN.framework
-[ -f CMakeCache.txt ] && rm CMakeCache.txt
-cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_BUILD_SHARED_LIBS=ON .. && make -j24
-cp libMNN.so $PACKAGE_PATH/Release
-
-# Release Static MNN.framework
-[ -f CMakeCache.txt ] && rm CMakeCache.txt
-cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_BUILD_SHARED_LIBS=OFF .. && make -j24
-cp libMNN.a $PACKAGE_PATH/Release
+log() {
+    echo "==================================="
+    echo "Build MNN (CPU $backends) $1"
+    echo "==================================="
+}
+
+# Debug Dynamic
+if [ $build_all ] || [ $build_type = "debug" -a $lib_type = "dynamic" ]; then
+    log "debug + dynamic"
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_BUILD_SHARED_LIBS=ON .. && make -j24
+    cp libMNN.so $PACKAGE_PATH/Debug
+fi
+
+# Debug Static
+if [ $build_all ] || [ $build_type = "debug" -a $lib_type = "static" ]; then
+    log "debug + static"
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_BUILD_SHARED_LIBS=OFF .. && make -j24
+    cp libMNN.a $PACKAGE_PATH/Debug
+fi
+
+# Release Dynamic
+if [ $build_all ] || [ $build_type = "release" -a $lib_type = "dynamic" ]; then
+    log "release + dynamic"
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_BUILD_SHARED_LIBS=ON .. && make -j24
+    cp libMNN.so $PACKAGE_PATH/Release
+fi
+
+# Release Static
+if [ $build_all ] || [ $build_type = "release" -a $lib_type = "static" ]; then
+    log "release + static"
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_BUILD_SHARED_LIBS=OFF .. && make -j24
+    cp libMNN.a $PACKAGE_PATH/Release
+fi
 
 popd

+ 141 - 0
package_scripts/mac/build_bridge.sh

@@ -0,0 +1,141 @@
+#!/bin/bash
+
+# MNN
+#  |--- Debug
+#  |      |--- Dynamic
+#  |      |--- Static
+#  |
+#  |--- Release
+#         |--- Dynamic
+#         |--- Static
+
+set -e
+
+usage() {
+    echo "Usage: $0 -i mnn_path -o path [-f] [-t build_type -t lib_type]"
+    echo -e "\t-i MNN library path"
+    echo -e "\t-o package files output directory"
+    echo -e "\t-f mnnpybridge.framework, otherwise .dylib or .a"
+    echo -e "\t-t build type (debug/release), lib_type (dynamic/static), build all when unspecify"
+    exit 1
+}
+
+build_all=true
+while getopts "i:o:ft:h" opt; do
+  case "$opt" in
+    i ) mnn_path=$OPTARG ;;
+    o ) path=$OPTARG ;;
+    f ) fmwk=true ;;
+    t ) build_all=""
+        case "$OPTARG" in
+            "debug"|"release" ) build_type=$OPTARG ;;
+            "dynamic"|"static" ) lib_type=$OPTARG ;;
+        esac ;;
+    h|? ) usage ;;
+  esac
+done
+
+if [ -z $build_all ] && ([ -z $build_type ] || [ -z $lib_type ]); then
+    echo "build_type(debug/release) and lib_type(dynamic/static) should be set or not-set together"
+    exit 1
+fi
+
+rm -rf $path && mkdir -p $path
+pushd $path
+mkdir include wrapper lib
+popd
+PACKAGE_PATH=$(realpath $path)
+MNN_PACKAGE_PATH=$(realpath $mnn_path)
+
+pushd pymnn/3rd_party
+rm -rf MNN && mkdir -p MNN/lib
+cp -r $MNN_PACKAGE_PATH/* MNN/lib
+cp -r ../../include MNN
+popd
+
+cp pymnn/src/MNNPyBridge.h $PACKAGE_PATH/include
+rm -rf /tmp/mnn_py && mkdir -p /tmp/mnn_py
+cp -r pymnn/pip_package/MNN /tmp/mnn_py
+pushd /tmp/mnn_py
+find . -name __pycache__ | xargs rm -rf
+pushd MNN
+rm -rf tools
+cat __init__.py | sed '/from . import tools/d' > __init__.py.tmp
+mv __init__.py.tmp __init__.py
+rm -rf data
+cat __init__.py | sed '/from . import data/d' > __init__.py.tmp
+mv __init__.py.tmp __init__.py
+rm -rf optim
+cat __init__.py | sed '/from . import optim/d' > __init__.py.tmp
+mv __init__.py.tmp __init__.py
+python -c "import compileall; compileall.compile_dir('/tmp/mnn_py/MNN', force=True)"
+find . -name "*.py" | xargs rm -rf
+popd
+cp -r MNN $PACKAGE_PATH/wrapper
+popd
+
+CMAKE_ARGS="-DPYMNN_USE_ALINNPYTHON=ON -DPYMNN_RUNTIME_CHECK_VM=ON -DPYMNN_EXPR_API=ON -DPYMNN_NUMPY_USABLE=ON -DPYMNN_TRAIN_API=OFF"
+if [ $fmwk ]; then
+    CMAKE_ARGS="$CMAKE_ARGS -DDEPEND_AAPL_FMWK=ON"
+fi
+
+rm -rf mnnpybridge_build && mkdir mnnpybridge_build
+pushd mnnpybridge_build
+
+deploy() {
+    _path=$1
+    if [ $fmwk ]; then
+        cp -R mnnpybridge.framework $_path
+        return
+    fi
+    _lib_type=$2
+    if [ $_lib_type = "dynamic" ]; then
+        cp libmnnpybridge.dylib $_path
+    else
+        cp libmnnpybridge.a $_path
+    fi
+}
+
+log() {
+    echo "==================================="
+    echo "Build mnnpybridge $1"
+    echo "==================================="
+}
+
+# Debug Dynamic
+if [ $build_all ] || [ $build_type = "debug" -a $lib_type = "dynamic" ]; then
+    log "debug + dynamic"
+    pushd $PACKAGE_PATH/lib && mkdir -p Debug/Dynamic && popd
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_BUILD_SHARED_LIBS=ON ../pymnn && make -j8
+    deploy $PACKAGE_PATH/lib/Debug/Dynamic "dynamic"
+fi
+
+# Debug Static
+if [ $build_all ] || [ $build_type = "debug" -a $lib_type = "static" ]; then
+    log "debug + static"
+    pushd $PACKAGE_PATH/lib && mkdir -p Debug/Static && popd
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_BUILD_SHARED_LIBS=OFF ../pymnn && make -j8
+    deploy $PACKAGE_PATH/lib/Debug/Static "static"
+fi
+
+# Release Dynamic
+if [ $build_all ] || [ $build_type = "release" -a $lib_type = "dynamic" ]; then
+    log "release + dynamic"
+    pushd $PACKAGE_PATH/lib && mkdir -p Release/Dynamic && popd
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_BUILD_SHARED_LIBS=ON ../pymnn && make -j8
+    deploy $PACKAGE_PATH/lib/Release/Dynamic "dynamic"
+fi
+
+# Release Static
+if [ $build_all ] || [ $build_type = "release" -a $lib_type = "static" ]; then
+    log "release + static"
+    pushd $PACKAGE_PATH/lib && mkdir -p Release/Static && popd
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_BUILD_SHARED_LIBS=OFF ../pymnn && make -j8
+    deploy $PACKAGE_PATH/lib/Release/Static "static"
+fi
+
+popd

+ 100 - 35
package_scripts/mac/build_lib.sh

@@ -1,3 +1,5 @@
+#!/bin/bash
+
 # MNN
 #  |--- Debug
 #  |      |--- Dynamic
@@ -6,63 +8,126 @@
 #  |--- Release
 #         |--- Dynamic
 #         |--- Static
-# Only have MNN.framework
 
 set -e
 
 usage() {
-    echo "Usage: $0 -o path [-b]"
+    echo "Usage: $0 -o path [-b backends] [-f] [-s] [-c] [-t build_type -t lib_type [-c]]"
     echo -e "\t-o package files output directory"
-    echo -e "\t-b opencl backend"
+    echo -e "\t-b extra backends to support (opencl, opengl, vulkan, onednn, avx512, coreml)"
+    echo -e "\t-f MNN.framework, otherwise .dylib or .a"
+    echo -e "\t-s re-generate schema"
+    echo -e "\t-c clean build folder"
+    echo -e "\t-t build type (debug/release), lib_type (dynamic/static), build all when unspecify"
     exit 1
 }
 
-while getopts "o:hb" opt; do
+build_all=true
+while getopts "o:b:fsct:h" opt; do
   case "$opt" in
     o ) path=$OPTARG ;;
-    b ) opencl=true ;;
+    b ) IFS="," read -a backends <<< $OPTARG ;;
+    f ) fmwk=true ;;
+    s ) clean_schema=true ;;
+    c ) clean_build=true ;;
+    t ) build_all=""
+        case "$OPTARG" in
+            "debug"|"release" ) build_type=$OPTARG ;;
+            "dynamic"|"static" ) lib_type=$OPTARG ;;
+        esac ;;
     h|? ) usage ;;
   esac
 done
 
+if [ -z $build_all ] && ([ -z $build_type ] || [ -z $lib_type ]); then
+    echo "build_type(debug/release) and lib_type(dynamic/static) should be set or not-set together"
+    exit 1
+fi
+
 # clear and create package directory
-./schema/generate.sh
+if [ $clean_schema ]; then
+    ./schema/generate.sh
+fi
 rm -rf $path && mkdir -p $path
-pushd $path
-mkdir -p Debug/Dynamic
-mkdir -p Debug/Static
-mkdir -p Release/Dynamic
-mkdir -p Release/Static
-popd
 
 PACKAGE_PATH=$(realpath $path)
 
-CMAKE_ARGS="-DMNN_SEP_BUILD=OFF -DMNN_AAPL_FMWK=ON"
-if [ ! -z $opencl ]; then
-  CMAKE_ARGS="$CMAKE_ARGS -DMNN_OPENCL=ON"
+CMAKE_ARGS="-DMNN_SEP_BUILD=OFF"
+if [ $fmwk ]; then
+    CMAKE_ARGS="$CMAKE_ARGS -DMNN_AAPL_FMWK=ON"
+fi
+if [ "$backends" ]; then
+    for backend in $backends; do
+        case $backend in
+            "opencl" ) CMAKE_ARGS="$CMAKE_ARGS -DMNN_OPENCL=ON" ;;
+            "opengl" ) CMAKE_ARGS="$CMAKE_ARGS -DMNN_OPENGL=ON" ;;
+            "vulkan" ) CMAKE_ARGS="$CMAKE_ARGS -DMNN_VULKAN=ON" ;;
+            "onednn" ) CMAKE_ARGS="$CMAKE_ARGS -DMNN_ONEDNN=ON" ;;
+            "avx512" ) CMAKE_ARGS="$CMAKE_ARGS -DMNN_AVX512=ON" ;;
+            "coreml" ) CMAKE_ARGS="$CMAKE_ARGS -DMNN_COREML=ON" ;;
+        esac
+    done
 fi
 
-rm -rf build && mkdir build
+if [ $clean_build ]; then
+    rm -rf build && mkdir build
+fi
 pushd build
 
-# Debug Dynamic MNN.framework
-[ -f CMakeCache.txt ] && rm CMakeCache.txt
-cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_BUILD_SHARED_LIBS=ON .. && make -j8
-cp -R MNN.framework $PACKAGE_PATH/Debug/Dynamic
-
-# Debug Static MNN.framework
-[ -f CMakeCache.txt ] && rm CMakeCache.txt
-cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_BUILD_SHARED_LIBS=OFF .. && make -j8
-cp -R MNN.framework $PACKAGE_PATH/Debug/Static
-
-# Release Dynamic MNN.framework
-[ -f CMakeCache.txt ] && rm CMakeCache.txt
-cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_BUILD_SHARED_LIBS=ON .. && make -j8
-cp -R MNN.framework $PACKAGE_PATH/Release/Dynamic
-
-# Release Static MNN.framework
-[ -f CMakeCache.txt ] && rm CMakeCache.txt
-cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_BUILD_SHARED_LIBS=OFF .. && make -j8
-cp -R MNN.framework $PACKAGE_PATH/Release/Static
+deploy() {
+    _path=$1
+    if [ $fmwk ]; then
+        cp -R MNN.framework $_path
+        return
+    fi
+    _lib_type=$2
+    if [ $_lib_type = "dynamic" ]; then
+        cp libMNN.dylib $_path
+    else
+        cp libMNN.a $_path
+    fi
+}
+
+log() {
+    echo "==================================="
+    echo "Build MNN (CPU $backends) $1"
+    echo "==================================="
+}
+
+# Debug Dynamic
+if [ $build_all ] || [ $build_type = "debug" -a $lib_type = "dynamic" ]; then
+    log "debug + dynamic"
+    pushd $PACKAGE_PATH && mkdir -p Debug/Dynamic && popd
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_BUILD_SHARED_LIBS=ON .. && make -j8
+    deploy $PACKAGE_PATH/Debug/Dynamic "dynamic"
+fi
+
+# Debug Static
+if [ $build_all ] || [ $build_type = "debug" -a $lib_type = "static" ]; then
+    log "debug + static"
+    pushd $PACKAGE_PATH && mkdir -p Debug/Static && popd
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_BUILD_SHARED_LIBS=OFF .. && make -j8
+    deploy $PACKAGE_PATH/Debug/Static "static"
+fi
+
+# Release Dynamic
+if [ $build_all ] || [ $build_type = "release" -a $lib_type = "dynamic" ]; then
+    log "release + dynamic"
+    pushd $PACKAGE_PATH && mkdir -p Release/Dynamic && popd
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_BUILD_SHARED_LIBS=ON .. && make -j8
+    deploy $PACKAGE_PATH/Release/Dynamic "dynamic"
+fi
+
+# Release Static
+if [ $build_all ] || [ $build_type = "release" -a $lib_type = "static" ]; then
+    log "release + static"
+    pushd $PACKAGE_PATH && mkdir -p Release/Static && popd
+    [ -f CMakeCache.txt ] && rm CMakeCache.txt
+    cmake $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_BUILD_SHARED_LIBS=OFF .. && make -j8
+    deploy $PACKAGE_PATH/Release/Static "static"
+fi
 
 popd

+ 53 - 5
project/ios/MNN.xcodeproj/project.pbxproj

@@ -11,6 +11,7 @@
 		11A01A07258785EA00745FA7 /* MNNVectorTop1Float.S in Sources */ = {isa = PBXBuildFile; fileRef = 11A01A05258785EA00745FA7 /* MNNVectorTop1Float.S */; };
 		11A01A0C258785FB00745FA7 /* MNNVectorTop1Float.S in Sources */ = {isa = PBXBuildFile; fileRef = 11A01A0A258785FB00745FA7 /* MNNVectorTop1Float.S */; };
 		11A01A0D258785FB00745FA7 /* MNNVectorTop1Int32.S in Sources */ = {isa = PBXBuildFile; fileRef = 11A01A0B258785FB00745FA7 /* MNNVectorTop1Int32.S */; };
+		19BCFD2226B10015001FCE93 /* MetalCache_generated.h in Headers */ = {isa = PBXBuildFile; fileRef = 19BCFD2126B10015001FCE93 /* MetalCache_generated.h */; };
 		1F501F7F2397BA5B004E8721 /* HalideRuntime.h in Headers */ = {isa = PBXBuildFile; fileRef = 1F501F722397BA5A004E8721 /* HalideRuntime.h */; settings = {ATTRIBUTES = (Public, ); }; };
 		1F501F802397BA5B004E8721 /* MNNDefine.h in Headers */ = {isa = PBXBuildFile; fileRef = 1F501F732397BA5A004E8721 /* MNNDefine.h */; settings = {ATTRIBUTES = (Public, ); }; };
 		1F501F812397BA5B004E8721 /* AutoTime.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 1F501F742397BA5A004E8721 /* AutoTime.hpp */; settings = {ATTRIBUTES = (Public, ); }; };
@@ -280,6 +281,15 @@
 		48FD034A246AA40300456AF5 /* GeometryConvert.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 48FD0349246AA40300456AF5 /* GeometryConvert.cpp */; };
 		48FD12BE2466A88D009E9102 /* GeometryImageOp.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 48FD12BC2466A88C009E9102 /* GeometryImageOp.cpp */; };
 		48FD12BF2466A88D009E9102 /* GeometryConv2DBackPropFilter.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 48FD12BD2466A88D009E9102 /* GeometryConv2DBackPropFilter.cpp */; };
+		4A5BEC6026AAB3B30032F6BD /* CommonCompute.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 4A5BEC5E26AAB3B20032F6BD /* CommonCompute.hpp */; };
+		4A5BEC6126AAB3B30032F6BD /* MemoryFormater.h in Headers */ = {isa = PBXBuildFile; fileRef = 4A5BEC5F26AAB3B20032F6BD /* MemoryFormater.h */; };
+		4A5BEC6426AAB4B30032F6BD /* ModuleTest.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4A5BEC6326AAB4B30032F6BD /* ModuleTest.cpp */; };
+		4AF4FB24269ED235005BA97B /* SparseConvInt8TiledExecutor.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4AF4FB20269ED234005BA97B /* SparseConvInt8TiledExecutor.cpp */; };
+		4AF4FB26269ED235005BA97B /* SparseConvInt8TiledExecutor.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 4AF4FB22269ED234005BA97B /* SparseConvInt8TiledExecutor.hpp */; };
+		4AF4FB29269ED244005BA97B /* MNNPackedSparseQuantMatMulEpx1.S in Sources */ = {isa = PBXBuildFile; fileRef = 4AF4FB27269ED243005BA97B /* MNNPackedSparseQuantMatMulEpx1.S */; };
+		4AF4FB2A269ED244005BA97B /* MNNPackedSparseQuantMatMulEpx4.S in Sources */ = {isa = PBXBuildFile; fileRef = 4AF4FB28269ED244005BA97B /* MNNPackedSparseQuantMatMulEpx4.S */; };
+		4AF4FB2D269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx1.S in Sources */ = {isa = PBXBuildFile; fileRef = 4AF4FB2B269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx1.S */; };
+		4AF4FB2E269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx4.S in Sources */ = {isa = PBXBuildFile; fileRef = 4AF4FB2C269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx4.S */; };
 		4D4DAE68263905390060D37E /* CoreMLDefine.h in Headers */ = {isa = PBXBuildFile; fileRef = 4D4DAE67263905390060D37E /* CoreMLDefine.h */; };
 		4D6D7FC7265688E200F80814 /* MNNPackC4ForMatMul_A_BF16.S in Sources */ = {isa = PBXBuildFile; fileRef = 4D6D7FC6265688E200F80814 /* MNNPackC4ForMatMul_A_BF16.S */; };
 		4D6D7FC9265688EA00F80814 /* MNNPackedSparseMatMulEpx1.S in Sources */ = {isa = PBXBuildFile; fileRef = 4D6D7FC8265688EA00F80814 /* MNNPackedSparseMatMulEpx1.S */; };
@@ -760,6 +770,7 @@
 		11A01A05258785EA00745FA7 /* MNNVectorTop1Float.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNVectorTop1Float.S; sourceTree = "<group>"; };
 		11A01A0A258785FB00745FA7 /* MNNVectorTop1Float.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNVectorTop1Float.S; sourceTree = "<group>"; };
 		11A01A0B258785FB00745FA7 /* MNNVectorTop1Int32.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNVectorTop1Int32.S; sourceTree = "<group>"; };
+		19BCFD2126B10015001FCE93 /* MetalCache_generated.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = MetalCache_generated.h; path = schema/current/MetalCache_generated.h; sourceTree = "<group>"; };
 		1F501F722397BA5A004E8721 /* HalideRuntime.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = HalideRuntime.h; path = MNN/HalideRuntime.h; sourceTree = "<group>"; };
 		1F501F732397BA5A004E8721 /* MNNDefine.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = MNNDefine.h; path = MNN/MNNDefine.h; sourceTree = "<group>"; };
 		1F501F742397BA5A004E8721 /* AutoTime.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = AutoTime.hpp; path = MNN/AutoTime.hpp; sourceTree = "<group>"; };
@@ -1027,6 +1038,15 @@
 		48FD0349246AA40300456AF5 /* GeometryConvert.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GeometryConvert.cpp; sourceTree = "<group>"; };
 		48FD12BC2466A88C009E9102 /* GeometryImageOp.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GeometryImageOp.cpp; sourceTree = "<group>"; };
 		48FD12BD2466A88D009E9102 /* GeometryConv2DBackPropFilter.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GeometryConv2DBackPropFilter.cpp; sourceTree = "<group>"; };
+		4A5BEC5E26AAB3B20032F6BD /* CommonCompute.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = CommonCompute.hpp; sourceTree = "<group>"; };
+		4A5BEC5F26AAB3B20032F6BD /* MemoryFormater.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = MemoryFormater.h; sourceTree = "<group>"; };
+		4A5BEC6326AAB4B30032F6BD /* ModuleTest.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ModuleTest.cpp; sourceTree = "<group>"; };
+		4AF4FB20269ED234005BA97B /* SparseConvInt8TiledExecutor.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = SparseConvInt8TiledExecutor.cpp; sourceTree = "<group>"; };
+		4AF4FB22269ED234005BA97B /* SparseConvInt8TiledExecutor.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = SparseConvInt8TiledExecutor.hpp; sourceTree = "<group>"; };
+		4AF4FB27269ED243005BA97B /* MNNPackedSparseQuantMatMulEpx1.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackedSparseQuantMatMulEpx1.S; sourceTree = "<group>"; };
+		4AF4FB28269ED244005BA97B /* MNNPackedSparseQuantMatMulEpx4.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackedSparseQuantMatMulEpx4.S; sourceTree = "<group>"; };
+		4AF4FB2B269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx1.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackedSparseQuantMatMulEpx1.S; sourceTree = "<group>"; };
+		4AF4FB2C269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx4.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackedSparseQuantMatMulEpx4.S; sourceTree = "<group>"; };
 		4D4DAE67263905390060D37E /* CoreMLDefine.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CoreMLDefine.h; sourceTree = "<group>"; };
 		4D6D7FC6265688E200F80814 /* MNNPackC4ForMatMul_A_BF16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackC4ForMatMul_A_BF16.S; sourceTree = "<group>"; };
 		4D6D7FC8265688EA00F80814 /* MNNPackedSparseMatMulEpx1.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackedSparseMatMulEpx1.S; sourceTree = "<group>"; };
@@ -1585,6 +1605,7 @@
 		4829A2CA23CC26AD00623BF5 /* expr */ = {
 			isa = PBXGroup;
 			children = (
+				4A5BEC6326AAB4B30032F6BD /* ModuleTest.cpp */,
 				4829A2CB23CC26AD00623BF5 /* MatMulTest.cpp */,
 				4829A2CC23CC26AD00623BF5 /* GatherTest.cpp */,
 				4829A2CD23CC26AD00623BF5 /* MatrixBandTest.cpp */,
@@ -1671,6 +1692,7 @@
 		488873A8215B639D0079B12E /* source */ = {
 			isa = PBXGroup;
 			children = (
+				4A5BEC6226AAB3D70032F6BD /* common */,
 				4D9A931B26255BDA00F9B43C /* coreml */,
 				6A131E3C2582331C002EC3D6 /* plugin */,
 				489D7A152550FDC800AD896A /* metal */,
@@ -1938,6 +1960,7 @@
 				489D7A2C2550FDC800AD896A /* MetalConvolution.mm */,
 				489D7A2D2550FDC800AD896A /* MNNMetalContext.mm */,
 				489D7A2E2550FDC800AD896A /* MetalReLU.hpp */,
+				19BCFD2126B10015001FCE93 /* MetalCache_generated.h */,
 				489D7A2F2550FDC800AD896A /* MetalEltwise.hpp */,
 				489D7A302550FDC800AD896A /* MetalPooling.hpp */,
 				489D7A312550FDC800AD896A /* MetalPReLU.hpp */,
@@ -2047,6 +2070,15 @@
 			path = ../../../test/speed;
 			sourceTree = "<group>";
 		};
+		4A5BEC6226AAB3D70032F6BD /* common */ = {
+			isa = PBXGroup;
+			children = (
+				4A5BEC5F26AAB3B20032F6BD /* MemoryFormater.h */,
+				4A5BEC5E26AAB3B20032F6BD /* CommonCompute.hpp */,
+			);
+			path = common;
+			sourceTree = "<group>";
+		};
 		4D9A931B26255BDA00F9B43C /* coreml */ = {
 			isa = PBXGroup;
 			children = (
@@ -2310,6 +2342,8 @@
 		92FF013A23AA0B4E00AC97F6 /* arm32 */ = {
 			isa = PBXGroup;
 			children = (
+				4AF4FB2B269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx1.S */,
+				4AF4FB2C269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx4.S */,
 				4DD179392694076700B0098F /* MNNSoftmax.S */,
 				4D6D7FCA265688F600F80814 /* MNNPackedSparseMatMulEpx4.S */,
 				4D6D7FC8265688EA00F80814 /* MNNPackedSparseMatMulEpx1.S */,
@@ -2375,6 +2409,8 @@
 		92FF017C23AA0B4E00AC97F6 /* arm64 */ = {
 			isa = PBXGroup;
 			children = (
+				4AF4FB27269ED243005BA97B /* MNNPackedSparseQuantMatMulEpx1.S */,
+				4AF4FB28269ED244005BA97B /* MNNPackedSparseQuantMatMulEpx4.S */,
 				4DD1793B2694078000B0098F /* MNNSoftmax.S */,
 				48CA2F542681844C003A1796 /* MNNPackC8FP16.S */,
 				48CA2F552681844C003A1796 /* MNNUnpackC8FP16.S */,
@@ -2451,6 +2487,8 @@
 		92FF021B23AA0B5600AC97F6 /* compute */ = {
 			isa = PBXGroup;
 			children = (
+				4AF4FB20269ED234005BA97B /* SparseConvInt8TiledExecutor.cpp */,
+				4AF4FB22269ED234005BA97B /* SparseConvInt8TiledExecutor.hpp */,
 				C4EF5FB92657A9F00094235C /* WinogradInt8Helper.cpp */,
 				C4EF5FB82657A9EF0094235C /* WinogradInt8Helper.hpp */,
 				C4EF5FB22657A9E70094235C /* ConvInt8TiledExecutor.cpp */,
@@ -2666,6 +2704,7 @@
 				92FF03B323AA0B5A00AC97F6 /* ConvolutionDepthwise3x3.hpp in Headers */,
 				4D9A937226255BDA00F9B43C /* CoreMLConvolution.hpp in Headers */,
 				92FF038B23AA0B5A00AC97F6 /* CPUUnravelIndex.hpp in Headers */,
+				4AF4FB26269ED235005BA97B /* SparseConvInt8TiledExecutor.hpp in Headers */,
 				92FF03BC23AA0B5A00AC97F6 /* OptimizedComputer.hpp in Headers */,
 				48C84BA0250F725600EE7666 /* InitNet.hpp in Headers */,
 				92FF03C623AA0B5A00AC97F6 /* CPUNonMaxSuppressionV2.hpp in Headers */,
@@ -2721,6 +2760,7 @@
 				4D9A935926255BDA00F9B43C /* DataStructures.pb-c.h in Headers */,
 				C43C82302518951800A0FF84 /* ImageFloatBlitter.hpp in Headers */,
 				489D7A972550FDC900AD896A /* MetalConvolutionDepthwise.hpp in Headers */,
+				4A5BEC6126AAB3B30032F6BD /* MemoryFormater.h in Headers */,
 				489D7AB42550FDC900AD896A /* MetalBinary.hpp in Headers */,
 				92FF04AF23AA0BFB00AC97F6 /* Macro.h in Headers */,
 				4D9A936C26255BDA00F9B43C /* CoreMLRaster.hpp in Headers */,
@@ -2736,6 +2776,7 @@
 				EBECA39624643D320062C7A3 /* Arm82Eltwise.hpp in Headers */,
 				92FF033F23AA0B5A00AC97F6 /* CPUArgMax.hpp in Headers */,
 				92FF034C23AA0B5A00AC97F6 /* CPUSetDiff1D.hpp in Headers */,
+				19BCFD2226B10015001FCE93 /* MetalCache_generated.h in Headers */,
 				92FF02A123AA0B5A00AC97F6 /* CPUDepthwiseConvInt8.hpp in Headers */,
 				92FF036723AA0B5A00AC97F6 /* CPURuntime.hpp in Headers */,
 				92FF026623AA0B5A00AC97F6 /* CPUProposal.hpp in Headers */,
@@ -2777,6 +2818,7 @@
 				92FF038C23AA0B5A00AC97F6 /* CPUEltwise.hpp in Headers */,
 				92FF028823AA0B5A00AC97F6 /* CPUDequantize.hpp in Headers */,
 				481C2DF125FE2CD6001ED6DF /* Arm82OptFunc.hpp in Headers */,
+				4A5BEC6026AAB3B30032F6BD /* CommonCompute.hpp in Headers */,
 				C43C8225251894F400A0FF84 /* WingoradGenerater.hpp in Headers */,
 				489D7A6A2550FDC800AD896A /* MetalConvolutionGEMM.hpp in Headers */,
 			);
@@ -3047,6 +3089,7 @@
 				48FB9DC924A848D0008E1A2D /* MNNPackedMatMulRemain.S in Sources */,
 				92FF044023AA0B7100AC97F6 /* ShapeSlice.cpp in Sources */,
 				92FF044723AA0B7100AC97F6 /* ShapeSqueeze.cpp in Sources */,
+				4AF4FB2A269ED244005BA97B /* MNNPackedSparseQuantMatMulEpx4.S in Sources */,
 				92FF033923AA0B5A00AC97F6 /* MNNGemmint8to32_8x4_Unit.S in Sources */,
 				4896D36925FE2A3D00717702 /* Arm82Unary.cpp in Sources */,
 				92FF043423AA0B7100AC97F6 /* ShapeStridedSlice.cpp in Sources */,
@@ -3142,9 +3185,11 @@
 				48747D64245D9E33000B9709 /* GeometryTile.cpp in Sources */,
 				92FF043723AA0B7100AC97F6 /* ShapeDetectionOutput.cpp in Sources */,
 				92FF042623AA0B7100AC97F6 /* ShapeCosineSimilarity.cpp in Sources */,
+				4AF4FB2D269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx1.S in Sources */,
 				92FF02DC23AA0B5A00AC97F6 /* MNNReluInt8.S in Sources */,
 				92FF041A23AA0B7100AC97F6 /* ShapeFill.cpp in Sources */,
 				EB45C776244D7C6600E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S in Sources */,
+				4AF4FB29269ED244005BA97B /* MNNPackedSparseQuantMatMulEpx1.S in Sources */,
 				4D759B2C25FF89EE0037B0B6 /* GeometryShape.cpp in Sources */,
 				11A01A07258785EA00745FA7 /* MNNVectorTop1Float.S in Sources */,
 				92FF035323AA0B5A00AC97F6 /* CPUScatterNd.cpp in Sources */,
@@ -3171,6 +3216,7 @@
 				4D6D7FC7265688E200F80814 /* MNNPackC4ForMatMul_A_BF16.S in Sources */,
 				92FF044923AA0B7100AC97F6 /* ShapeGatherND.cpp in Sources */,
 				489D7AB32550FDC900AD896A /* MetalPReLU.mm in Sources */,
+				4AF4FB24269ED235005BA97B /* SparseConvInt8TiledExecutor.cpp in Sources */,
 				489D7AB12550FDC900AD896A /* MetalDefine.metal in Sources */,
 				48FB9DCE24AB080C008E1A2D /* MNNPackC8.S in Sources */,
 				4D9A937A26255BDA00F9B43C /* CoreMLActivation.cpp in Sources */,
@@ -3219,6 +3265,7 @@
 				92FF039623AA0B5A00AC97F6 /* CPUDepthwiseConvInt8.cpp in Sources */,
 				92FF04AA23AA0BFB00AC97F6 /* BufferAllocator.cpp in Sources */,
 				92FF030F23AA0B5A00AC97F6 /* MNNPackC4.S in Sources */,
+				4AF4FB2E269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx4.S in Sources */,
 				92FF031D23AA0B5A00AC97F6 /* MNNConvRunForLineDepthWiseUint8.S in Sources */,
 				C43C81FA251894A600A0FF84 /* CommonOptFunctionNeon.cpp in Sources */,
 				92FF030123AA0B5A00AC97F6 /* MNNAddC4WithStride.S in Sources */,
@@ -3358,6 +3405,7 @@
 			buildActionMask = 2147483647;
 			files = (
 				92A4E0FC21F05A4F000B0919 /* MemoryUtilsTest.cpp in Sources */,
+				4A5BEC6426AAB4B30032F6BD /* ModuleTest.cpp in Sources */,
 				48FD03462467C64700456AF5 /* MatMulSpeed.cpp in Sources */,
 				4882C8F1241A24D900DAC168 /* PadTest.cpp in Sources */,
 				920004B521EDBDF600BCE892 /* BinaryOPTest.cpp in Sources */,
@@ -3655,7 +3703,7 @@
 				CODE_SIGN_STYLE = Manual;
 				DEAD_CODE_STRIPPING = YES;
 				DEFINES_MODULE = YES;
-				DEVELOPMENT_TEAM = 6G7464HHUS;
+				DEVELOPMENT_TEAM = YFX4Y9599X;
 				DYLIB_COMPATIBILITY_VERSION = 1;
 				DYLIB_CURRENT_VERSION = 1;
 				DYLIB_INSTALL_NAME_BASE = "@rpath";
@@ -3715,7 +3763,7 @@
 				ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
 				ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
 				CODE_SIGN_STYLE = Automatic;
-				DEVELOPMENT_TEAM = 6G7464HHUS;
+				DEVELOPMENT_TEAM = YFX4Y9599X;
 				GCC_ENABLE_CPP_EXCEPTIONS = NO;
 				GCC_ENABLE_CPP_RTTI = NO;
 				HEADER_SEARCH_PATHS = (
@@ -3728,7 +3776,7 @@
 				IPHONEOS_DEPLOYMENT_TARGET = 9.0;
 				LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
 				OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
-				PRODUCT_BUNDLE_IDENTIFIER = com.cat.MNN.playgroundvvs33;
+				PRODUCT_BUNDLE_IDENTIFIER = com.tianbu.MNN.playgroundvvs33;
 				PRODUCT_NAME = "$(TARGET_NAME)";
 				TARGETED_DEVICE_FAMILY = "1,2";
 			};
@@ -3740,7 +3788,7 @@
 				ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
 				ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
 				CODE_SIGN_STYLE = Automatic;
-				DEVELOPMENT_TEAM = 6G7464HHUS;
+				DEVELOPMENT_TEAM = YFX4Y9599X;
 				GCC_ENABLE_CPP_EXCEPTIONS = NO;
 				GCC_ENABLE_CPP_RTTI = NO;
 				HEADER_SEARCH_PATHS = (
@@ -3753,7 +3801,7 @@
 				IPHONEOS_DEPLOYMENT_TARGET = 9.0;
 				LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
 				OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
-				PRODUCT_BUNDLE_IDENTIFIER = com.cat.MNN.playgroundvv;
+				PRODUCT_BUNDLE_IDENTIFIER = com.tianbu.MNN.playgroundvv;
 				PRODUCT_NAME = "$(TARGET_NAME)";
 				TARGETED_DEVICE_FAMILY = "1,2";
 			};

+ 19 - 3
pymnn/CMakeLists.txt

@@ -3,6 +3,7 @@
 cmake_minimum_required(VERSION 3.4.1)
 project(mnnpybridge)
 
+option(DEPEND_AAPL_FMWK "use dependency library .framework instead of traditional .a/.dylib" OFF)
 option(MNN_BUILD_SHARED_LIBS "MNN build shared or static lib" ON)
 option(MNN_WIN_RUNTIME_MT "MNN use /MT on Windows dll" OFF)
 option(PYMNN_USE_ALINNPYTHON "based on AliNNPython" ON)
@@ -109,17 +110,32 @@ if(WIN32 OR APPLE OR CMAKE_SYSTEM_NAME MATCHES "^Linux")
 
     target_include_directories(mnnpybridge PRIVATE ${CMAKE_CURRENT_LIST_DIR}/src ${DEPEND_PATH}/MNN/include)
     target_link_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/MNN/lib/${LIB_SUBPATH})
-    target_link_libraries(mnnpybridge PRIVATE MNN)
+    if(APPLE AND DEPEND_AAPL_FMWK)
+        target_link_libraries(mnnpybridge PRIVATE "-framework MNN")
+        set_target_properties(mnnpybridge PROPERTIES LINK_FLAGS "-Wl,-F${DEPEND_PATH}/MNN/lib/${LIB_SUBPATH}")
+    else()
+        target_link_libraries(mnnpybridge PRIVATE MNN)
+    endif()
 
     if(PYMNN_USE_ALINNPYTHON)
         target_include_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/AliNNPython/include)
         target_link_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/AliNNPython/lib/${LIB_SUBPATH})
-        target_link_libraries(mnnpybridge PRIVATE python)
+        if(APPLE AND DEPEND_AAPL_FMWK)
+            target_link_libraries(mnnpybridge PRIVATE "-framework python")
+            set_target_properties(mnnpybridge PROPERTIES LINK_FLAGS "-Wl,-F${DEPEND_PATH}/AliNNPython/lib/${LIB_SUBPATH}")
+        else()
+            target_link_libraries(mnnpybridge PRIVATE python)
+        endif()
     endif()
     if(PYMNN_NUMPY_USABLE)
         target_include_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/numpy/include)
         target_link_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/numpy/lib/${LIB_SUBPATH})
-        target_link_libraries(mnnpybridge PRIVATE numpy_python)
+        if(APPLE AND DEPEND_AAPL_FMWK)
+            target_link_libraries(mnnpybridge PRIVATE "-framework numpy_python")
+            set_target_properties(mnnpybridge PROPERTIES LINK_FLAGS "-Wl,-F${DEPEND_PATH}/numpy/lib/${LIB_SUBPATH}")
+        else()
+            target_link_libraries(mnnpybridge PRIVATE numpy_python)
+        endif()
     endif()
 else()
     target_include_directories(mnnpybridge PRIVATE ${MNN_DIR}/pymnn/src ${MNN_DIR}/pymnn/android/src/main/c/include)

+ 22 - 6
schema/current/CaffeOp_generated.h

@@ -694,12 +694,14 @@ struct Convolution3DCommonT : public flatbuffers::NativeTable {
   int32_t outputCount;
   bool relu;
   bool relu6;
+  int32_t group;
   Convolution3DCommonT()
       : padMode(PadMode_CAFFE),
         inputCount(0),
         outputCount(0),
         relu(false),
-        relu6(false) {
+        relu6(false),
+        group(1) {
   }
 };
 
@@ -735,6 +737,9 @@ struct Convolution3DCommon FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table
   bool relu6() const {
     return GetField<uint8_t>(20, 0) != 0;
   }
+  int32_t group() const {
+    return GetField<int32_t>(22, 1);
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, 4) &&
@@ -750,6 +755,7 @@ struct Convolution3DCommon FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table
            VerifyField<int32_t>(verifier, 16) &&
            VerifyField<uint8_t>(verifier, 18) &&
            VerifyField<uint8_t>(verifier, 20) &&
+           VerifyField<int32_t>(verifier, 22) &&
            verifier.EndTable();
   }
   Convolution3DCommonT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -787,6 +793,9 @@ struct Convolution3DCommonBuilder {
   void add_relu6(bool relu6) {
     fbb_.AddElement<uint8_t>(20, static_cast<uint8_t>(relu6), 0);
   }
+  void add_group(int32_t group) {
+    fbb_.AddElement<int32_t>(22, group, 1);
+  }
   explicit Convolution3DCommonBuilder(flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -809,8 +818,10 @@ inline flatbuffers::Offset<Convolution3DCommon> CreateConvolution3DCommon(
     int32_t inputCount = 0,
     int32_t outputCount = 0,
     bool relu = false,
-    bool relu6 = false) {
+    bool relu6 = false,
+    int32_t group = 1) {
   Convolution3DCommonBuilder builder_(_fbb);
+  builder_.add_group(group);
   builder_.add_outputCount(outputCount);
   builder_.add_inputCount(inputCount);
   builder_.add_pads(pads);
@@ -4065,6 +4076,7 @@ inline void Convolution3DCommon::UnPackTo(Convolution3DCommonT *_o, const flatbu
   { auto _e = outputCount(); _o->outputCount = _e; };
   { auto _e = relu(); _o->relu = _e; };
   { auto _e = relu6(); _o->relu6 = _e; };
+  { auto _e = group(); _o->group = _e; };
 }
 
 inline flatbuffers::Offset<Convolution3DCommon> Convolution3DCommon::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Convolution3DCommonT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -4084,6 +4096,7 @@ inline flatbuffers::Offset<Convolution3DCommon> CreateConvolution3DCommon(flatbu
   auto _outputCount = _o->outputCount;
   auto _relu = _o->relu;
   auto _relu6 = _o->relu6;
+  auto _group = _o->group;
   return MNN::CreateConvolution3DCommon(
       _fbb,
       _dilates,
@@ -4094,7 +4107,8 @@ inline flatbuffers::Offset<Convolution3DCommon> CreateConvolution3DCommon(flatbu
       _inputCount,
       _outputCount,
       _relu,
-      _relu6);
+      _relu6,
+      _group);
 }
 
 inline SparseCommonT *SparseCommon::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -5540,7 +5554,8 @@ inline const flatbuffers::TypeTable *Convolution3DCommonTypeTable() {
     { flatbuffers::ET_INT, 0, -1 },
     { flatbuffers::ET_INT, 0, -1 },
     { flatbuffers::ET_BOOL, 0, -1 },
-    { flatbuffers::ET_BOOL, 0, -1 }
+    { flatbuffers::ET_BOOL, 0, -1 },
+    { flatbuffers::ET_INT, 0, -1 }
   };
   static const flatbuffers::TypeFunction type_refs[] = {
     PadModeTypeTable
@@ -5554,10 +5569,11 @@ inline const flatbuffers::TypeTable *Convolution3DCommonTypeTable() {
     "inputCount",
     "outputCount",
     "relu",
-    "relu6"
+    "relu6",
+    "group"
   };
   static const flatbuffers::TypeTable tt = {
-    flatbuffers::ST_TABLE, 9, type_codes, type_refs, nullptr, names
+    flatbuffers::ST_TABLE, 10, type_codes, type_refs, nullptr, names
   };
   return &tt;
 }

+ 9 - 0
schema/current/MNN_generated.h

@@ -3096,6 +3096,15 @@ struct LoopParam FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   static const flatbuffers::TypeTable *MiniReflectTypeTable() {
     return LoopParamTypeTable();
   }
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_TENSORNUMBER = 4,
+    VT_OUTPUTINDEXES = 6,
+    VT_INPUTINDEXES = 8,
+    VT_MIDTENSORS = 10,
+    VT_PARALLEL = 12,
+    VT_LOOPNUMBER = 14,
+    VT_COMMANDS = 16
+  };
   int32_t tensorNumber() const {
     return GetField<int32_t>(4, 0);
   }

+ 33 - 12
schema/current/TensorflowOp_generated.h

@@ -376,11 +376,12 @@ enum UnaryOpOperation {
   UnaryOpOperation_TANH = 30,
   UnaryOpOperation_HARDSWISH = 31,
   UnaryOpOperation_GELU = 32,
+  UnaryOpOperation_GELU_STANDARD = 33,
   UnaryOpOperation_MIN = UnaryOpOperation_ABS,
-  UnaryOpOperation_MAX = UnaryOpOperation_GELU
+  UnaryOpOperation_MAX = UnaryOpOperation_GELU_STANDARD
 };
 
-inline const UnaryOpOperation (&EnumValuesUnaryOpOperation())[33] {
+inline const UnaryOpOperation (&EnumValuesUnaryOpOperation())[34] {
   static const UnaryOpOperation values[] = {
     UnaryOpOperation_ABS,
     UnaryOpOperation_NEG,
@@ -414,7 +415,8 @@ inline const UnaryOpOperation (&EnumValuesUnaryOpOperation())[33] {
     UnaryOpOperation_SIGMOID,
     UnaryOpOperation_TANH,
     UnaryOpOperation_HARDSWISH,
-    UnaryOpOperation_GELU
+    UnaryOpOperation_GELU,
+    UnaryOpOperation_GELU_STANDARD
   };
   return values;
 }
@@ -454,13 +456,14 @@ inline const char * const *EnumNamesUnaryOpOperation() {
     "TANH",
     "HARDSWISH",
     "GELU",
+    "GELU_STANDARD",
     nullptr
   };
   return names;
 }
 
 inline const char *EnumNameUnaryOpOperation(UnaryOpOperation e) {
-  if (e < UnaryOpOperation_ABS || e > UnaryOpOperation_GELU) return "";
+  if (e < UnaryOpOperation_ABS || e > UnaryOpOperation_GELU_STANDARD) return "";
   const size_t index = static_cast<int>(e);
   return EnumNamesUnaryOpOperation()[index];
 }
@@ -3049,8 +3052,10 @@ struct LayerNormT : public flatbuffers::NativeTable {
   float epsilon;
   std::vector<float> gamma;
   std::vector<float> beta;
+  int32_t group;
   LayerNormT()
-      : epsilon(0.0f) {
+      : epsilon(0.0f),
+        group(1) {
   }
 };
 
@@ -3071,6 +3076,9 @@ struct LayerNorm FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   const flatbuffers::Vector<float> *beta() const {
     return GetPointer<const flatbuffers::Vector<float> *>(10);
   }
+  int32_t group() const {
+    return GetField<int32_t>(12, 1);
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, 4) &&
@@ -3080,6 +3088,7 @@ struct LayerNorm FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
            verifier.VerifyVector(gamma()) &&
            VerifyOffset(verifier, 10) &&
            verifier.VerifyVector(beta()) &&
+           VerifyField<int32_t>(verifier, 12) &&
            verifier.EndTable();
   }
   LayerNormT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -3102,6 +3111,9 @@ struct LayerNormBuilder {
   void add_beta(flatbuffers::Offset<flatbuffers::Vector<float>> beta) {
     fbb_.AddOffset(10, beta);
   }
+  void add_group(int32_t group) {
+    fbb_.AddElement<int32_t>(12, group, 1);
+  }
   explicit LayerNormBuilder(flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -3119,8 +3131,10 @@ inline flatbuffers::Offset<LayerNorm> CreateLayerNorm(
     flatbuffers::Offset<flatbuffers::Vector<int32_t>> axis = 0,
     float epsilon = 0.0f,
     flatbuffers::Offset<flatbuffers::Vector<float>> gamma = 0,
-    flatbuffers::Offset<flatbuffers::Vector<float>> beta = 0) {
+    flatbuffers::Offset<flatbuffers::Vector<float>> beta = 0,
+    int32_t group = 1) {
   LayerNormBuilder builder_(_fbb);
+  builder_.add_group(group);
   builder_.add_beta(beta);
   builder_.add_gamma(gamma);
   builder_.add_epsilon(epsilon);
@@ -4465,6 +4479,7 @@ inline void LayerNorm::UnPackTo(LayerNormT *_o, const flatbuffers::resolver_func
   { auto _e = epsilon(); _o->epsilon = _e; };
   { auto _e = gamma(); if (_e) { _o->gamma.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->gamma[_i] = _e->Get(_i); } } };
   { auto _e = beta(); if (_e) { _o->beta.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->beta[_i] = _e->Get(_i); } } };
+  { auto _e = group(); _o->group = _e; };
 }
 
 inline flatbuffers::Offset<LayerNorm> LayerNorm::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LayerNormT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -4479,12 +4494,14 @@ inline flatbuffers::Offset<LayerNorm> CreateLayerNorm(flatbuffers::FlatBufferBui
   auto _epsilon = _o->epsilon;
   auto _gamma = _o->gamma.size() ? _fbb.CreateVector(_o->gamma) : 0;
   auto _beta = _o->beta.size() ? _fbb.CreateVector(_o->beta) : 0;
+  auto _group = _o->group;
   return MNN::CreateLayerNorm(
       _fbb,
       _axis,
       _epsilon,
       _gamma,
-      _beta);
+      _beta,
+      _group);
 }
 
 inline RandomUniformT *RandomUniform::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -4717,6 +4734,7 @@ inline const flatbuffers::TypeTable *UnaryOpOperationTypeTable() {
     { flatbuffers::ET_INT, 0, 0 },
     { flatbuffers::ET_INT, 0, 0 },
     { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
     { flatbuffers::ET_INT, 0, 0 }
   };
   static const flatbuffers::TypeFunction type_refs[] = {
@@ -4755,10 +4773,11 @@ inline const flatbuffers::TypeTable *UnaryOpOperationTypeTable() {
     "SIGMOID",
     "TANH",
     "HARDSWISH",
-    "GELU"
+    "GELU",
+    "GELU_STANDARD"
   };
   static const flatbuffers::TypeTable tt = {
-    flatbuffers::ST_ENUM, 33, type_codes, type_refs, nullptr, names
+    flatbuffers::ST_ENUM, 34, type_codes, type_refs, nullptr, names
   };
   return &tt;
 }
@@ -5446,16 +5465,18 @@ inline const flatbuffers::TypeTable *LayerNormTypeTable() {
     { flatbuffers::ET_INT, 1, -1 },
     { flatbuffers::ET_FLOAT, 0, -1 },
     { flatbuffers::ET_FLOAT, 1, -1 },
-    { flatbuffers::ET_FLOAT, 1, -1 }
+    { flatbuffers::ET_FLOAT, 1, -1 },
+    { flatbuffers::ET_INT, 0, -1 }
   };
   static const char * const names[] = {
     "axis",
     "epsilon",
     "gamma",
-    "beta"
+    "beta",
+    "group"
   };
   static const flatbuffers::TypeTable tt = {
-    flatbuffers::ST_TABLE, 4, type_codes, nullptr, nullptr, names
+    flatbuffers::ST_TABLE, 5, type_codes, nullptr, nullptr, names
   };
   return &tt;
 }
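
Note: the extended CreateLayerNorm overload above takes the new group argument last and defaults it to 1, so previously serialized models keep their behaviour. A minimal usage sketch, assuming the generated header is included under its usual path and using made-up field values:

    // Hedged sketch: build a LayerNorm table that sets the new `group` field.
    // Only the CreateLayerNorm overload shown in the generated code above is used.
    #include <cstdint>
    #include <vector>
    #include "TensorflowOp_generated.h"   // assumed include path

    flatbuffers::Offset<MNN::LayerNorm> makeGroupNorm(flatbuffers::FlatBufferBuilder& fbb) {
        auto axis = fbb.CreateVector(std::vector<int32_t>{-1});
        // gamma/beta are left empty (offset 0); group = 2 requests group-wise normalization.
        return MNN::CreateLayerNorm(fbb, axis, /*epsilon*/1e-6f,
                                    /*gamma*/0, /*beta*/0, /*group*/2);
    }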

+ 1 - 0
schema/default/CaffeOp.fbs

@@ -36,6 +36,7 @@ table Convolution3DCommon {
     outputCount:int = 0;
     relu:bool = false;
     relu6:bool = false;
+    group:int = 1;
 }
 
 enum SparseAlgo : byte {

+ 2 - 0
schema/default/TensorflowOp.fbs

@@ -141,6 +141,7 @@ enum UnaryOpOperation : int {
     TANH = 30,
     HARDSWISH = 31,
     GELU = 32,
+    GELU_STANDARD = 33,
 }
 
 table UnaryOp {
@@ -295,6 +296,7 @@ table LayerNorm {
     epsilon: float;
     gamma: [float];
     beta: [float];
+    group: int = 1;
 }
 table RandomUniform {
     seed:int = 0;

+ 5 - 3
source/backend/arm82/Arm82Functions.cpp

@@ -591,11 +591,11 @@ bool Arm82Functions::init() {
     using Vec = MNN::Math::Vec<FLOAT16, 8>;
 #define FUNC_PTR_ASSIGN(dst, src) dst = (decltype(dst))(src)
     gInstance = new CoreFunctions;
-    
+
     FUNC_PTR_ASSIGN(gInstance->MNNFp32ToLowp, MNNQuantizeFP16);
     FUNC_PTR_ASSIGN(gInstance->MNNLowpToFp32, MNNDequantizeFP16);
     gInstance->bytes = 2;
-    
+
     // Packed
     gInstance->pack = 8;
     FUNC_PTR_ASSIGN(gInstance->MNNPackCUnit, MNNPackC8FP16);
@@ -617,7 +617,7 @@ bool Arm82Functions::init() {
     FUNC_PTR_ASSIGN(gInstance->MNNGridSampleInterp, MNNGridSampleInterpFP16);
     FUNC_PTR_ASSIGN(gInstance->MNNCopyC4WithStride, MNNCopyC8WithStrideFP16);
     FUNC_PTR_ASSIGN(gInstance->MNNAddC4WithStride, MNNAddC8WithStrideFP16);
-    
+
     // MatMul
     FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16);
     FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16);
@@ -629,6 +629,8 @@ bool Arm82Functions::init() {
 
     FUNC_PTR_ASSIGN(gInstance->chooseWinoSourceTransform, Arm82WinogradFunction::chooseSourceTransform);
     FUNC_PTR_ASSIGN(gInstance->chooseWinoDestTransform, Arm82WinogradFunction::chooseDestTransform);
+    FUNC_PTR_ASSIGN(gInstance->chooseWinoSourceTransformPack, Arm82WinogradFunction::chooseWinoSourceTransformPack);
+
 
     gInstance->MNNDeconvRunForLineDepthwise = (decltype(gInstance->MNNDeconvRunForLineDepthwise))_MNNDeconvRunForLineDepthwise;
     gInstance->MNNDeconvRunForUnitDepthWise = (decltype(gInstance->MNNDeconvRunForUnitDepthWise))_MNNDeconvRunForUnitDepthWise;

+ 10 - 7
source/backend/arm82/Arm82Unary.cpp

@@ -110,19 +110,22 @@ struct _Exp {
     void operator()(void* outRaw, const void* inpRaw, int realSize) const {
         auto out = (float*)outRaw;
         auto inp = (const float*)inpRaw;
-        MNNScaleAndAddBiasScalar(out, inp, 0.0f, -1.0f, realSize);
-        MNNExp(out, out, realSize);
+        float offset[2] = {
+            1.0f,
+            0.0f
+        };
+        MNNExp(out, inp, offset, realSize);
     }
 };
 struct _ExpM1 {
     void operator()(void* outRaw, const void* inpRaw, int realSize) const {
         auto out = (float*)outRaw;
         auto inp = (const float*)inpRaw;
-        MNNScaleAndAddBiasScalar(out, inp, 0.0f, -1.0f, realSize);
-        MNNExp(out, out, realSize);
-        for (int i=0; i<realSize; ++i) {
-            out[i] = out[i] - 1.0f;
-        }
+        float offset[2] = {
+            1.0f,
+            -1.0f
+        };
+        MNNExp(out, inp, offset, realSize);
     }
 };
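
Note: the two rewritten call sites suggest the updated MNNExp fuses an input scale and an output bias through the offset pair, so {1, 0} reproduces exp(x) for _Exp and {1, -1} reproduces expm1(x) for _ExpM1, replacing the old MNNScaleAndAddBiasScalar + MNNExp + fix-up loop. A scalar sketch of that assumed contract (the real kernel is the vectorized MNNExp in the CPU core functions; treat this as illustration only):

    // Assumed semantics only: dst[i] = exp(src[i] * offset[0]) + offset[1].
    #include <cmath>
    #include <cstddef>

    static void MNNExpReference(float* dst, const float* src,
                                const float offset[2], size_t count) {
        for (size_t i = 0; i < count; ++i) {
            dst[i] = std::exp(src[i] * offset[0]) + offset[1];
        }
    }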
 

+ 167 - 0
source/backend/arm82/Arm82Vec.hpp

@@ -108,6 +108,173 @@ struct Vec<FLOAT16, 8> {
         VecType dst = { vnegq_f16(value) };
         return dst;
     }
+
+    static inline void transpose12(VecType& vec0, VecType& vec1, VecType& vec2, VecType& vec3, VecType& vec4,
+                                   VecType& vec5, VecType& vec6, VecType& vec7, VecType& vec8, VecType& vec9,
+                                   VecType& vec10, VecType& vec11) {
+
+#ifdef __aarch64__
+        auto tmp1 = vzipq_s16(vec0.value, vec1.value); // tmp1 would disappear after compile
+        auto v21 = tmp1.val[0];
+        auto v22 = tmp1.val[1];
+        auto tmp2 = vzipq_s16(vec2.value, vec3.value);
+        auto v24 = tmp2.val[0];
+        auto v25 = tmp2.val[1];
+        auto tmp3 = vzipq_s16(vec4.value, vec5.value);
+        auto v27 = tmp3.val[0];
+        auto v28 = tmp3.val[1];
+        auto tmp4 = vzipq_s16(vec6.value, vec7.value);
+        auto v30 = tmp4.val[0];
+        auto v31 = tmp4.val[1];
+
+        auto tmp5 = vzipq_s32(v21, v24);
+        vec0.value = tmp5.val[0];
+        vec1.value = tmp5.val[1];
+        auto tmp6 = vzipq_s32(v22, v25);
+        vec2.value = tmp6.val[0];
+        vec3.value = tmp6.val[1];
+        auto tmp7 = vzipq_s32(v27, v30);
+        vec4.value = tmp7.val[0];
+        vec5.value = tmp7.val[1];
+        auto tmp8 = vzipq_s32(v28, v31);
+        vec6.value = tmp8.val[0];
+        vec7.value = tmp8.val[1];
+        auto v20 = vtrn1q_s64(vec0.value, vec4.value);
+        auto v12 = vtrn2q_s64(vec0.value, vec4.value);
+        auto v23 = vtrn1q_s64(vec1.value, vec5.value);
+        auto v13 = vtrn2q_s64(vec1.value, vec5.value);
+        auto v26 = vtrn1q_s64(vec2.value, vec6.value);
+        auto v14 = vtrn2q_s64(vec2.value, vec6.value);
+        auto v29 = vtrn1q_s64(vec3.value, vec7.value);
+        auto v15 = vtrn2q_s64(vec3.value, vec7.value);
+
+        auto tmp9 = vzipq_s16(vec8.value, vec9.value); // tmp9 would disappear after compile
+        vec0.value = tmp9.val[0];
+        vec1.value = tmp9.val[1];
+        auto tmp10 = vzipq_s16(vec10.value, vec11.value);
+        vec2.value = tmp10.val[0];
+        vec3.value = tmp10.val[1];
+        auto tmp11 = vzipq_s32(vec0.value, vec2.value);
+        auto v16 = tmp11.val[0];
+        auto v17 = tmp11.val[1];
+        auto tmp12 = vzipq_s32(vec1.value, vec3.value);
+        auto v18 = tmp12.val[0];
+        auto v19 = tmp12.val[1];
+
+        v21 = vtrn1q_s64(v16, v12);
+        v22 = vtrn2q_s64(v12, v16);
+        v24 = vtrn1q_s64(v17, v13);
+        v25 = vtrn2q_s64(v13, v17);
+        v27 = vtrn1q_s64(v18, v14);
+        v28 = vtrn2q_s64(v14, v18);
+        v30 = vtrn1q_s64(v19, v15);
+        v31 = vtrn2q_s64(v15, v19);
+
+        vec0.value  = v20;
+        vec1.value  = v21;
+        vec2.value  = v22;
+        vec3.value  = v23;
+        vec4.value  = v24;
+        vec5.value  = v25;
+        vec6.value  = v26;
+        vec7.value  = v27;
+        vec8.value  = v28;
+        vec9.value  = v29;
+        vec10.value = v30;
+        vec11.value = v31;
+#else
+
+        auto tmp1 = vzipq_s16(vec0.value, vec1.value); // tmp1 would disappear after compile
+        auto v21 = tmp1.val[0];
+        auto v22 = tmp1.val[1];
+        auto tmp2 = vzipq_s16(vec2.value, vec3.value);
+        auto v24 = tmp2.val[0];
+        auto v25 = tmp2.val[1];
+        auto tmp3 = vzipq_s16(vec4.value, vec5.value);
+        auto v27 = tmp3.val[0];
+        auto v28 = tmp3.val[1];
+        auto tmp4 = vzipq_s16(vec6.value, vec7.value);
+        auto v30 = tmp4.val[0];
+        auto v31 = tmp4.val[1];
+
+        auto tmp5 = vzipq_s32(v21, v24);
+        vec0.value = tmp5.val[0];
+        vec1.value = tmp5.val[1];
+        auto tmp6 = vzipq_s32(v22, v25);
+        vec2.value = tmp6.val[0];
+        vec3.value = tmp6.val[1];
+        auto tmp7 = vzipq_s32(v27, v30);
+        vec4.value = tmp7.val[0];
+        vec5.value = tmp7.val[1];
+        auto tmp8 = vzipq_s32(v28, v31);
+        vec6.value = tmp8.val[0];
+        vec7.value = tmp8.val[1];
+
+
+        auto v20 = vec0.value;
+        auto v12 = vec4.value;
+        v20 = vsetq_lane_s64(vgetq_lane_s64(vec4.value, 0), v20, 1);
+        v12 = vsetq_lane_s64(vgetq_lane_s64(vec0.value, 1), v12, 0);
+        auto v23 = vec1.value;
+        auto v13 = vec5.value;
+        v23 = vsetq_lane_s64(vgetq_lane_s64(vec5.value, 0), v23, 1);
+        v13 = vsetq_lane_s64(vgetq_lane_s64(vec1.value, 1), v13, 0);
+        auto v26 = vec2.value;
+        auto v14 = vec6.value;
+        v26 = vsetq_lane_s64(vgetq_lane_s64(vec6.value, 0), v26, 1);
+        v14 = vsetq_lane_s64(vgetq_lane_s64(vec2.value, 1), v14, 0);
+        auto v29 = vec3.value;
+        auto v15 = vec7.value;
+        v29 = vsetq_lane_s64(vgetq_lane_s64(vec7.value, 0), v29, 1);
+        v15 = vsetq_lane_s64(vgetq_lane_s64(vec3.value, 1), v15, 0);
+
+
+        auto tmp9 = vzipq_s16(vec8.value, vec9.value); // tmp9 would disappear after compile
+        vec0.value = tmp9.val[0];
+        vec1.value = tmp9.val[1];
+        auto tmp10 = vzipq_s16(vec10.value, vec11.value);
+        vec2.value = tmp10.val[0];
+        vec3.value = tmp10.val[1];
+        auto tmp11 = vzipq_s32(vec0.value, vec2.value);
+        auto v16 = tmp11.val[0];
+        auto v17 = tmp11.val[1];
+        auto tmp12 = vzipq_s32(vec1.value, vec3.value);
+        auto v18 = tmp12.val[0];
+        auto v19 = tmp12.val[1];
+
+        v21 = v16;
+        v22 = v16;
+        v21 = vsetq_lane_s64(vgetq_lane_s64(v12, 0), v21, 1);
+        v22 = vsetq_lane_s64(vgetq_lane_s64(v12, 1), v22, 0);
+        v24 = v17;
+        v25 = v17;
+        v24 = vsetq_lane_s64(vgetq_lane_s64(v13, 0), v24, 1);
+        v25 = vsetq_lane_s64(vgetq_lane_s64(v13, 1), v25, 0);
+        v27 = v18;
+        v28 = v18;
+        v27 = vsetq_lane_s64(vgetq_lane_s64(v14, 0), v27, 1);
+        v28 = vsetq_lane_s64(vgetq_lane_s64(v14, 1), v28, 0);
+        v30 = v19;
+        v31 = v19;
+        v30 = vsetq_lane_s64(vgetq_lane_s64(v15, 0), v30, 1);
+        v31 = vsetq_lane_s64(vgetq_lane_s64(v15, 1), v31, 0);
+
+        vec0.value  = v20;
+        vec1.value  = v21;
+        vec2.value  = v22;
+        vec3.value  = v23;
+        vec4.value  = v24;
+        vec5.value  = v25;
+        vec6.value  = v26;
+        vec7.value  = v27;
+        vec8.value  = v28;
+        vec9.value  = v29;
+        vec10.value = v30;
+        vec11.value = v31;
+
+#endif
+
+    }
 };
 } // namespace Math
 } // namespace MNN
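
Note: transpose12 is the in-register building block for the Pack12 Winograd transforms added below; it treats twelve 8-lane FP16 registers as a 12x8 tile and leaves the 8x12 transpose spread contiguously across the same twelve registers. A scalar reference of that intended effect (a conceptual sketch only; the lane layout is my reading of the intrinsics, and the real code never spills to memory like this):

    // Conceptual scalar equivalent: 12x8 row-major tile -> 8x12 row-major tile
    // over the same 96 contiguous elements.
    #include <cstddef>

    template <typename T>
    static void transpose12x8Reference(T* data /* 12 * 8 elements */) {
        T tmp[8][12];
        for (size_t r = 0; r < 12; ++r) {      // 12 source rows (ePack)
            for (size_t c = 0; c < 8; ++c) {   // 8 lanes per row (packCUnit)
                tmp[c][r] = data[r * 8 + c];
            }
        }
        for (size_t r = 0; r < 8; ++r) {
            for (size_t c = 0; c < 12; ++c) {
                data[r * 12 + c] = tmp[r][c];
            }
        }
    }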

+ 351 - 0
source/backend/arm82/Arm82WinogradOptFunc.cpp

@@ -12,12 +12,339 @@
 #include "Arm82OptFunc.hpp"
 #include <cstring>
 #include <memory>
+#include <map>
 #include "core/Macro.h"
 #include "math/Vec.hpp"
 using Vec = MNN::Math::Vec<FLOAT16, 8>;
+using VecType = Vec;
+using ElementType = FLOAT16;
+
+#define TRANSPOSE_12X8_SAVE()                                               \
+    VecType v0  = VecType::load(srcPtr + 0 * packCUnit);                    \
+    VecType v1  = VecType::load(srcPtr + 1 * packCUnit);                    \
+    VecType v2  = VecType::load(srcPtr + 2 * packCUnit);                    \
+    VecType v3  = VecType::load(srcPtr + 3 * packCUnit);                    \
+    VecType v4  = VecType::load(srcPtr + 4 * packCUnit);                    \
+    VecType v5  = VecType::load(srcPtr + 5 * packCUnit);                    \
+    VecType v6  = VecType::load(srcPtr + 6 * packCUnit);                    \
+    VecType v7  = VecType::load(srcPtr + 7 * packCUnit);                    \
+    VecType v8  = VecType::load(srcPtr + 8 * packCUnit);                    \
+    VecType v9  = VecType::load(srcPtr + 9 * packCUnit);                    \
+    VecType v10 = VecType::load(srcPtr + 10 * packCUnit);                   \
+    VecType v11 = VecType::load(srcPtr + 11 * packCUnit);                   \
+    VecType::transpose12(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11); \
+    VecType::save(srcPtr + 0 * packCUnit, v0);                              \
+    VecType::save(srcPtr + 1 * packCUnit, v1);                              \
+    VecType::save(srcPtr + 2 * packCUnit, v2);                              \
+    VecType::save(srcPtr + 3 * packCUnit, v3);                              \
+    VecType::save(srcPtr + 4 * packCUnit, v4);                              \
+    VecType::save(srcPtr + 5 * packCUnit, v5);                              \
+    VecType::save(srcPtr + 6 * packCUnit, v6);                              \
+    VecType::save(srcPtr + 7 * packCUnit, v7);                              \
+    VecType::save(srcPtr + 8 * packCUnit, v8);                              \
+    VecType::save(srcPtr + 9 * packCUnit, v9);                              \
+    VecType::save(srcPtr + 10 * packCUnit, v10);                            \
+    VecType::save(srcPtr + 11 * packCUnit, v11);
 
 namespace MNN {
 
+static void _sourceTransformUnit4x4Pack12(ElementType* srcBlock, ElementType* dstStart, size_t dstStep) {
+    // register number: (srcUnit + 1) * EPack/packCUnit
+    constexpr int Nh = 4; // srcUnit
+    constexpr int ePack = 12;
+    constexpr size_t packCUnit = 8;
+    const size_t loadTransposeStride = packCUnit * ePack;
+    ElementType* srcPtr = srcBlock;
+    for (int iNh = 0; iNh < Nh; ++iNh)
+    {
+        // register number : ePack
+        TRANSPOSE_12X8_SAVE();
+        srcPtr += loadTransposeStride;
+    }
+
+    srcPtr = srcBlock;
+    ElementType* dstPtr = dstStart;
+    for (int i4c = 0; i4c < packCUnit >> 1; ++i4c) // process 2 of the 8 packCUnit lines at once
+    {
+        // source transform D * B. register number : srcUnit * (EPack/4 + 1)
+        VecType s00 = VecType::load(srcPtr + 0 * loadTransposeStride + 0 * packCUnit);
+        VecType s01 = VecType::load(srcPtr + 0 * loadTransposeStride + 1 * packCUnit);
+        VecType s02 = VecType::load(srcPtr + 0 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s10 = VecType::load(srcPtr + 1 * loadTransposeStride + 0 * packCUnit);
+        VecType s11 = VecType::load(srcPtr + 1 * loadTransposeStride + 1 * packCUnit);
+        VecType s12 = VecType::load(srcPtr + 1 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s20 = VecType::load(srcPtr + 2 * loadTransposeStride + 0 * packCUnit);
+        VecType s21 = VecType::load(srcPtr + 2 * loadTransposeStride + 1 * packCUnit);
+        VecType s22 = VecType::load(srcPtr + 2 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s30 = VecType::load(srcPtr + 3 * loadTransposeStride + 0 * packCUnit);
+        VecType s31 = VecType::load(srcPtr + 3 * loadTransposeStride + 1 * packCUnit);
+        VecType s32 = VecType::load(srcPtr + 3 * loadTransposeStride + 2 * packCUnit);
+
+        // dstStep =  ePack * pack * ic_4
+        auto ep0 = s00 - s20;
+        auto ep1 = s01 - s21;
+        auto ep2 = s02 - s22;
+        VecType::save(dstPtr + 0 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 0 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 0 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 + s20;
+        ep1 = s11 + s21;
+        ep2 = s12 + s22;
+        VecType::save(dstPtr + 1 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 1 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 1 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s20 - s10;
+        ep1 = s21 - s11;
+        ep2 = s22 - s12;
+        VecType::save(dstPtr + 2 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 2 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 2 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s30 - s10;
+        ep1 = s31 - s11;
+        ep2 = s32 - s12;
+        VecType::save(dstPtr + 3 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 3 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 3 * dstStep + 2 * packCUnit, ep2);
+
+        // VecType::save(dstPtr + 0 * dstStep + 0 * packCUnit, s00);
+        // VecType::save(dstPtr + 0 * dstStep + 1 * packCUnit, s01);
+        // VecType::save(dstPtr + 0 * dstStep + 2 * packCUnit, s02);
+
+        // VecType::save(dstPtr + 1 * dstStep + 0 * packCUnit, s10);
+        // VecType::save(dstPtr + 1 * dstStep + 1 * packCUnit, s11);
+        // VecType::save(dstPtr + 1 * dstStep + 2 * packCUnit, s12);
+
+        // VecType::save(dstPtr + 2 * dstStep + 0 * packCUnit, s20);
+        // VecType::save(dstPtr + 2 * dstStep + 1 * packCUnit, s21);
+        // VecType::save(dstPtr + 2 * dstStep + 2 * packCUnit, s22);
+
+        // VecType::save(dstPtr + 3 * dstStep + 0 * packCUnit, s30);
+        // VecType::save(dstPtr + 3 * dstStep + 1 * packCUnit, s31);
+        // VecType::save(dstPtr + 3 * dstStep + 2 * packCUnit, s32);
+
+        // MNN_PRINT("\nwinograd in BT*D*B, iNh:0-3, i4c:%d\n", i4c);
+        // formatMatrix(dstPtr + 0 * dstStep , {ePack});
+        // formatMatrix(dstPtr + 1 * dstStep , {ePack});
+        // formatMatrix(dstPtr + 2 * dstStep , {ePack});
+        // formatMatrix(dstPtr + 3 * dstStep , {ePack});
+
+        srcPtr += ePack << 1;
+        dstPtr += ePack << 1;
+    }
+}
+
+static void _sourceTransformUnit8x8Pack12(ElementType* srcBlock, ElementType* dstStart, size_t dstStep) {
+
+    // source transform D * B. register number : (srcUnit + 1) * EPack/packCUnit = 27
+    // todo: implement
+    constexpr int Nh = 8; // srcUnit
+    constexpr int ePack = 12;
+    constexpr size_t packCUnit = 8;
+    const size_t loadTransposeStride = packCUnit * ePack;
+    ElementType* srcPtr = srcBlock;
+    for (int iNh = 0; iNh < Nh; ++iNh)
+    {
+        // register number : ePack
+        TRANSPOSE_12X8_SAVE();
+        srcPtr += loadTransposeStride;
+    }
+
+    srcPtr = srcBlock;
+    ElementType* dstPtr = dstStart;
+    for (int i4c = 0; i4c < packCUnit >> 1; ++i4c)
+    {
+        VecType s00 = VecType::load(srcPtr + 0 * loadTransposeStride + 0 * packCUnit);
+        VecType s01 = VecType::load(srcPtr + 0 * loadTransposeStride + 1 * packCUnit);
+        VecType s02 = VecType::load(srcPtr + 0 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s10 = VecType::load(srcPtr + 1 * loadTransposeStride + 0 * packCUnit);
+        VecType s11 = VecType::load(srcPtr + 1 * loadTransposeStride + 1 * packCUnit);
+        VecType s12 = VecType::load(srcPtr + 1 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s20 = VecType::load(srcPtr + 2 * loadTransposeStride + 0 * packCUnit);
+        VecType s21 = VecType::load(srcPtr + 2 * loadTransposeStride + 1 * packCUnit);
+        VecType s22 = VecType::load(srcPtr + 2 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s30 = VecType::load(srcPtr + 3 * loadTransposeStride + 0 * packCUnit);
+        VecType s31 = VecType::load(srcPtr + 3 * loadTransposeStride + 1 * packCUnit);
+        VecType s32 = VecType::load(srcPtr + 3 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s40 = VecType::load(srcPtr + 4 * loadTransposeStride + 0 * packCUnit);
+        VecType s41 = VecType::load(srcPtr + 4 * loadTransposeStride + 1 * packCUnit);
+        VecType s42 = VecType::load(srcPtr + 4 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s50 = VecType::load(srcPtr + 5 * loadTransposeStride + 0 * packCUnit);
+        VecType s51 = VecType::load(srcPtr + 5 * loadTransposeStride + 1 * packCUnit);
+        VecType s52 = VecType::load(srcPtr + 5 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s60 = VecType::load(srcPtr + 6 * loadTransposeStride + 0 * packCUnit);
+        VecType s61 = VecType::load(srcPtr + 6 * loadTransposeStride + 1 * packCUnit);
+        VecType s62 = VecType::load(srcPtr + 6 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s70 = VecType::load(srcPtr + 7 * loadTransposeStride + 0 * packCUnit);
+        VecType s71 = VecType::load(srcPtr + 7 * loadTransposeStride + 1 * packCUnit);
+        VecType s72 = VecType::load(srcPtr + 7 * loadTransposeStride + 2 * packCUnit);
+
+
+        // to-try: reorder the more complicated 8x8 compute
+        auto ep0 = s00 * 36.f - s20 * 49.f + s40 * 14.f - s60;
+        auto ep1 = s01 * 36.f - s21 * 49.f + s41 * 14.f - s61;
+        auto ep2 = s02 * 36.f - s22 * 49.f + s42 * 14.f - s62;
+        VecType::save(dstPtr + 0 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 0 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 0 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = (s10 + s20) * 36.f - (s30 + s40) * 13.f + (s50 + s60);
+        ep1 = (s11 + s21) * 36.f - (s31 + s41) * 13.f + (s51 + s61);
+        ep2 = (s12 + s22) * 36.f - (s32 + s42) * 13.f + (s52 + s62);
+        VecType::save(dstPtr + 1 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 1 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 1 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = (s20 - s10) * 36.f + (s30 - s40) * 13.f + (s60 - s50);
+        ep1 = (s21 - s11) * 36.f + (s31 - s41) * 13.f + (s61 - s51);
+        ep2 = (s22 - s12) * 36.f + (s32 - s42) * 13.f + (s62 - s52);
+        VecType::save(dstPtr + 2 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 2 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 2 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * 18.f + s20 * 9.f - s30 * 20.f - s40 * 10.f + s50 * 2.f + s60;
+        ep1 = s11 * 18.f + s21 * 9.f - s31 * 20.f - s41 * 10.f + s51 * 2.f + s61;
+        ep2 = s12 * 18.f + s22 * 9.f - s32 * 20.f - s42 * 10.f + s52 * 2.f + s62;
+        VecType::save(dstPtr + 3 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 3 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 3 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s20 * 9.f - s10 * 18.f + s30 * 20.f - s40 * 10.f - s50 * 2.f + s60;
+        ep1 = s21 * 9.f - s11 * 18.f + s31 * 20.f - s41 * 10.f - s51 * 2.f + s61;
+        ep2 = s22 * 9.f - s12 * 18.f + s32 * 20.f - s42 * 10.f - s52 * 2.f + s62;
+        VecType::save(dstPtr + 4 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 4 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 4 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * 12.f + s20 * 4.f - s30 * 15.f - s40 * 5.f + s50 * 3.f + s60;
+        ep1 = s11 * 12.f + s21 * 4.f - s31 * 15.f - s41 * 5.f + s51 * 3.f + s61;
+        ep2 = s12 * 12.f + s22 * 4.f - s32 * 15.f - s42 * 5.f + s52 * 3.f + s62;
+        VecType::save(dstPtr + 5 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 5 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 5 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s20 * 4.f - s10 * 12.f + s30 * 15.f - s40 * 5.f - s50 * 3.f + s60;
+        ep1 = s21 * 4.f - s11 * 12.f + s31 * 15.f - s41 * 5.f - s51 * 3.f + s61;
+        ep2 = s22 * 4.f - s12 * 12.f + s32 * 15.f - s42 * 5.f - s52 * 3.f + s62;
+        VecType::save(dstPtr + 6 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 6 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 6 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s30 * 49.f - s10 * 36.f - s50 * 14.f + s70;
+        ep1 = s31 * 49.f - s11 * 36.f - s51 * 14.f + s71;
+        ep2 = s32 * 49.f - s12 * 36.f - s52 * 14.f + s72;
+        VecType::save(dstPtr + 7 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 7 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 7 * dstStep + 2 * packCUnit, ep2);
+        srcPtr += ePack << 1;
+        dstPtr += ePack << 1;
+    }
+}
+
+static void _sourceTransformUnit6x6Pack12(ElementType* srcBlock, ElementType* dstStart, size_t dstStep) {
+
+    // source transform D * B. register number : (srcUnit + 1) * EPack/packCUnit
+    constexpr int Nh = 6; // srcUnit
+    constexpr int ePack = 12;
+    constexpr size_t packCUnit = 8;
+    const size_t loadTransposeStride = packCUnit * ePack;
+    ElementType* srcPtr = srcBlock;
+    for (int iNh = 0; iNh < Nh; ++iNh)
+    {
+        // register number : ePack
+        TRANSPOSE_12X8_SAVE();
+        srcPtr += loadTransposeStride;
+    }
+
+    srcPtr = srcBlock;
+    ElementType* dstPtr = dstStart;
+    for (int i4c = 0; i4c < packCUnit >> 1; ++i4c)
+    {
+        VecType s00 = VecType::load(srcPtr + 0 * loadTransposeStride + 0 * packCUnit);
+        VecType s01 = VecType::load(srcPtr + 0 * loadTransposeStride + 1 * packCUnit);
+        VecType s02 = VecType::load(srcPtr + 0 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s10 = VecType::load(srcPtr + 1 * loadTransposeStride + 0 * packCUnit);
+        VecType s11 = VecType::load(srcPtr + 1 * loadTransposeStride + 1 * packCUnit);
+        VecType s12 = VecType::load(srcPtr + 1 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s20 = VecType::load(srcPtr + 2 * loadTransposeStride + 0 * packCUnit);
+        VecType s21 = VecType::load(srcPtr + 2 * loadTransposeStride + 1 * packCUnit);
+        VecType s22 = VecType::load(srcPtr + 2 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s30 = VecType::load(srcPtr + 3 * loadTransposeStride + 0 * packCUnit);
+        VecType s31 = VecType::load(srcPtr + 3 * loadTransposeStride + 1 * packCUnit);
+        VecType s32 = VecType::load(srcPtr + 3 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s40 = VecType::load(srcPtr + 4 * loadTransposeStride + 0 * packCUnit);
+        VecType s41 = VecType::load(srcPtr + 4 * loadTransposeStride + 1 * packCUnit);
+        VecType s42 = VecType::load(srcPtr + 4 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s50 = VecType::load(srcPtr + 5 * loadTransposeStride + 0 * packCUnit);
+        VecType s51 = VecType::load(srcPtr + 5 * loadTransposeStride + 1 * packCUnit);
+        VecType s52 = VecType::load(srcPtr + 5 * loadTransposeStride + 2 * packCUnit);
+
+        // to-try: reorder
+        auto ep0 = s00 * 4.f - s20 * 5.f + s40;
+        auto ep1 = s01 * 4.f - s21 * 5.f + s41;
+        auto ep2 = s02 * 4.f - s22 * 5.f + s42;
+        VecType::save(dstPtr + 0 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 0 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 0 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = (s10 + s20) * (-4.f) + s30 + s40;
+        ep1 = (s11 + s21) * (-4.f) + s31 + s41;
+        ep2 = (s12 + s22) * (-4.f) + s32 + s42;
+        VecType::save(dstPtr + 1 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 1 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 1 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = (s10 - s20) * (4.f) + s40 - s30;
+        ep1 = (s11 - s21) * (4.f) + s41 - s31;
+        ep2 = (s12 - s22) * (4.f) + s42 - s32;
+        VecType::save(dstPtr + 2 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 2 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 2 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * (-2.f) - s20 + s30 * 2.f + s40;
+        ep1 = s11 * (-2.f) - s21 + s31 * 2.f + s41;
+        ep2 = s12 * (-2.f) - s22 + s32 * 2.f + s42;
+        VecType::save(dstPtr + 3 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 3 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 3 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * 2.f - s20 - s30 * 2.f + s40;
+        ep1 = s11 * 2.f - s21 - s31 * 2.f + s41;
+        ep2 = s12 * 2.f - s22 - s32 * 2.f + s42;
+        VecType::save(dstPtr + 4 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 4 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 4 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * 4.f - s30 * 5.f + s50;
+        ep1 = s11 * 4.f - s31 * 5.f + s51;
+        ep2 = s12 * 4.f - s32 * 5.f + s52;
+        VecType::save(dstPtr + 5 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 5 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 5 * dstStep + 2 * packCUnit, ep2);
+
+        srcPtr += ePack << 1;
+        dstPtr += ePack << 1;
+    }
+}
+
+
 static void _sourceTransformUnit4x4(const FLOAT16* srcBlock, FLOAT16* dstStart, size_t srcStep, size_t dstStep) {
     Vec s0 = Vec::load(srcBlock + 0 * srcStep);
     Vec s1 = Vec::load(srcBlock + 1 * srcStep);
@@ -172,6 +499,25 @@ static Arm82WinogradFunction::TransformFunc gProcUnit6[] = {
 };
 
 
+Arm82WinogradFunction::TransformPackFunc Arm82WinogradFunction::chooseWinoSourceTransformPack(int k, int w, int ePack, int lPack, int packCUnit) {
+    if (ePack == 12 && lPack == 1 && packCUnit == 8) {
+        if (k == 4 && w == 4) {
+            return _sourceTransformUnit4x4Pack12;
+        }
+        if (k == 6 && w == 6) {
+            return _sourceTransformUnit6x6Pack12;
+        }
+        if (k == 8 && w == 8) {
+            return _sourceTransformUnit8x8Pack12;
+        }
+        // other packing size
+    }
+    MNN_ERROR("Arm82WinogradFunction Can not find function for ePack:%d, packCUnit:%d\n", ePack, packCUnit);
+    MNN_ASSERT(false);
+    return nullptr;
+}
+
+
 Arm82WinogradFunction::TransformFunc Arm82WinogradFunction::chooseSourceTransform(int k, int w) {
     if (6 == k && 6 == w) {
         return _sourceTransformUnit6x6;
@@ -206,4 +552,9 @@ int Arm82MNNGetConvTileNumber() {
 }
 
 } // namespace MNN
+
+#undef TRANSPOSE_12X8_SAVE
+
 #endif
+
+
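Reviewer note on the hunk above: chooseWinoSourceTransformPack only provides kernels for the ePack=12 / lPack=1 / packCUnit=8 layout used by the FP16 backend and reports an error otherwise. A hedged sketch of the intended call pattern follows; the buffer names (srcBlock, dstStart) and the stride value are illustrative, not taken from this commit.

    // Select the packed 6x6 source transform for the FP16 path (sketch, assumes MNN headers).
    auto trans = Arm82WinogradFunction::chooseWinoSourceTransformPack(
        /*k*/6, /*w*/6, /*ePack*/12, /*lPack*/1, /*packCUnit*/8);
    if (trans != nullptr) {
        // srcBlock holds 6x6 source points for 12 tiles, packed 8 fp16 lanes at a time;
        // dstStep is the element stride between transformed rows in the destination.
        trans(srcBlock, dstStart, dstStep);
    }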

+ 3 - 2
source/backend/arm82/Arm82WinogradOptFunc.hpp

@@ -17,11 +17,12 @@ class Arm82WinogradFunction {
 public:
     typedef void (*TransformFunc)(const FLOAT16* srcBlock, FLOAT16* dstStart, size_t srcStep, size_t dstStep);
     typedef void (*Int8TransFunc)(const int8_t* srcBlock, int8_t* dstStart, size_t srcStep, size_t dstStep);
-
+    typedef void (*TransformPackFunc)(FLOAT16* srcBlock, FLOAT16* dstStart, size_t dstStep);
     /*Use the generator with interp 0.5*/
     static TransformFunc chooseSourceTransform(int k, int w);
     static TransformFunc chooseDestTransform(int k, int h);
-    
+    static TransformPackFunc chooseWinoSourceTransformPack(int k, int h, int ePack, int lPack, int packCUnit);
+
     static Int8TransFunc chooseInt8SourceTransform(int k, int w);
     static TransformFunc chooseInt8DestTransform(int k, int h);
 };

+ 1 - 1
source/backend/cpu/BinaryUtils.hpp

@@ -48,7 +48,7 @@ struct BinaryRealDiv {
 template <typename _Arg1, typename _Arg2, typename _ErrorCode>
 struct BinaryMod {
     _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
-        return x - x / y;
+        return x - (x / y) * y;
     }
 };
 

+ 3 - 4
source/backend/cpu/CPUBackend.cpp

@@ -178,7 +178,7 @@ bool CPUBackend::allocBuffer(int size, Tensor* dest, StorageType storageType) {
     // MNN_PRINT("Acquire size = %d\n", size);
     if (size <= 0) {
         MNN_PRINT("Acquire buffer size = %d\n", size);
-        MNN_ASSERT(false);
+//        MNN_ASSERT(false);
         return false;
     }
     // if (size > LARGE_MEMORY) {
@@ -457,10 +457,10 @@ void CPUBackend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor)
                     break;
             }
             wrapTensor.reset(Tensor::create(srcTensor->shape(), dstTensor->getType(), nullptr, dimType));
-            code = CPUCastCreator::cast(srcTensor, wrapTensor.get());
+            code = CPUCastCreator::cast(srcTensor, wrapTensor.get(), this);
             CPUTensorConverter::convert(wrapTensor.get(), dstTensor);
         } else {
-            code = CPUCastCreator::cast(srcTensor, dstTensor);
+            code = CPUCastCreator::cast(srcTensor, dstTensor, this);
         }
         if (NO_ERROR != code) {
             MNN_ERROR("Error in CPUBackend::onCopyBuffer:cast\n");
@@ -492,7 +492,6 @@ void registerCPURuntimeCreator() {
 #endif
     // TODO: Merge _initCoreFunction MNNFunctionInit and cpuinfo_arm_init
     MNNCoreFunctionInit();
-    MNNCoreInt8FunctionInit();
     MNNInsertExtraRuntimeCreator(MNN_FORWARD_CPU, new CPURuntimeCreator);
 };
 } // namespace MNN

+ 25 - 16
source/backend/cpu/CPUCast.cpp

@@ -15,24 +15,32 @@
 
 namespace MNN {
 ErrorCode CPUCastCreator::cast(void* const inputRaw, void* outputRaw, halide_type_t inputType, halide_type_t outputType,
-                               int number, float scale, float zero, float min, float max) {
-    int c4Size = number / 4;
-    int remain = c4Size * 4;
+                               int number, float scale, float zero, float min, float max, const CPUBackend* bn) {
+    auto pack = bn->functions()->pack;
+    int c4Size = number / pack;
+    int remain = number % pack;
     if (inputType == halide_type_of<float>() && outputType == halide_type_of<int8_t>()) {
         scale = (scale == 0.f ? 0.f : 1.f / scale);
-        std::vector<float> scales(4, scale);
-        MNNFloat2Int8(static_cast<float*>(inputRaw), static_cast<int8_t*>(outputRaw), c4Size, scales.data(), min, max, zero);
-        for (int i = remain; i < number; i++) {
-            float x = std::round(static_cast<float* const>(inputRaw)[i] * scale + zero);
-            static_cast<int8_t*>(outputRaw)[i] = static_cast<int8_t>(std::max(std::min(x, max), min));
+        std::vector<float> scales(pack, scale);
+        bn->int8Functions()->MNNFloat2Int8(static_cast<float*>(inputRaw), static_cast<int8_t*>(outputRaw), c4Size, scales.data(), min, max, zero);
+        if (remain > 0) {
+            std::vector<float> tempSrc(pack);
+            std::vector<int8_t> tempDst(pack);
+            ::memcpy(tempSrc.data(), static_cast<float* const>(inputRaw) + c4Size * pack, remain * sizeof(float));
+            bn->int8Functions()->MNNFloat2Int8(tempSrc.data(), tempDst.data(), 1, scales.data(), min, max, zero);
+            ::memcpy(static_cast<int8_t*>(outputRaw) + c4Size * pack, tempDst.data(), remain * sizeof(int8_t));
         }
         return NO_ERROR;
     }
     if (inputType == halide_type_of<int8_t>() && outputType == halide_type_of<float>()) {
-        std::vector<float> scales(4, scale);
-        MNNInt8ScaleToFloat(static_cast<float*>(outputRaw), static_cast<int8_t*>(inputRaw), scales.data(), c4Size, zero);
-        for (int i = remain; i < number; i++) {
-            static_cast<float*>(outputRaw)[i] = (static_cast<int8_t* const>(inputRaw)[i] - zero) * scale;
+        std::vector<float> scales(pack, scale);
+        bn->int8Functions()->MNNInt8ScaleToFloat(static_cast<float*>(outputRaw), static_cast<int8_t*>(inputRaw), scales.data(), c4Size, zero);
+        if (remain > 0) {
+            std::vector<float> tempDst(pack);
+            std::vector<int8_t> tempSrc(pack);
+            ::memcpy(tempSrc.data(), static_cast<int8_t* const>(inputRaw) + c4Size * pack, remain * sizeof(int8_t));
+            bn->int8Functions()->MNNInt8ScaleToFloat(tempDst.data(), tempSrc.data(), scales.data(), 1, zero);
+            ::memcpy(static_cast<float*>(outputRaw) + c4Size * pack, tempDst.data(), remain * sizeof(float));
         }
         return NO_ERROR;
     }
@@ -40,13 +48,15 @@ ErrorCode CPUCastCreator::cast(void* const inputRaw, void* outputRaw, halide_typ
     return NOT_SUPPORT;
 }
 
-ErrorCode CPUCastCreator::cast(const Tensor* input, const Tensor* output, int size) {
+ErrorCode CPUCastCreator::cast(const Tensor* input, const Tensor* output, const CPUBackend* bn) {
     auto srcT = input->getType();
     auto dstT = output->getType();
     auto ib     = input->buffer();
     auto ob     = output->buffer();
+    int totalSize = bn->getTensorSize(input);
+    auto bytes = ib.type.bytes();
     if (srcT == dstT) {
-        ::memcpy(ib.host, ob.host, input->size());
+        ::memcpy(ib.host, ob.host, totalSize * bytes);
         return NO_ERROR;
     }
     auto& quantAttr = TensorUtils::getDescribe(input)->quantAttr;
@@ -54,8 +64,7 @@ ErrorCode CPUCastCreator::cast(const Tensor* input, const Tensor* output, int si
         MNN_ERROR("No quant info for Cast\n");
         return INVALID_VALUE;
     }
-    int totalSize = size ? size : input->elementSize();
-    auto code = cast(ib.host, ob.host, srcT, dstT, totalSize, quantAttr->scale, quantAttr->zero, quantAttr->min, quantAttr->max);
+    auto code = cast(ib.host, ob.host, srcT, dstT, totalSize, quantAttr->scale, quantAttr->zero, quantAttr->min, quantAttr->max, bn);
     if (NO_ERROR != code) {
         MNN_ERROR("Error in CPUCast\n");
         return code;
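The rewritten cast path derives its vector width from the backend's pack value and stages the tail elements through a pack-sized scratch buffer, so the packed MNNFloat2Int8 / MNNInt8ScaleToFloat kernels never touch memory past the end of the tensor. A self-contained sketch of the same tail-handling pattern, with a scalar stand-in for the quantization kernel (names here are illustrative):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Stand-in for a SIMD kernel that can only process whole packs of `pack` floats.
    static void quantPacks(const float* src, int8_t* dst, int packCount, int pack, float scale) {
        for (int i = 0; i < packCount * pack; ++i) {
            float x = std::round(src[i] * scale);
            dst[i] = static_cast<int8_t>(std::max(-127.0f, std::min(127.0f, x)));
        }
    }

    // Tail handling as in the new CPUCastCreator::cast: run the packed kernel on the body,
    // then copy the remainder into a pack-sized scratch buffer and run one more pack.
    static void castFloatToInt8(const float* src, int8_t* dst, int number, int pack, float scale) {
        int packCount = number / pack;
        int remain    = number % pack;
        quantPacks(src, dst, packCount, pack, scale);
        if (remain > 0) {
            std::vector<float>  tmpSrc(pack, 0.0f);
            std::vector<int8_t> tmpDst(pack, 0);
            ::memcpy(tmpSrc.data(), src + packCount * pack, remain * sizeof(float));
            quantPacks(tmpSrc.data(), tmpDst.data(), 1, pack, scale);
            ::memcpy(dst + packCount * pack, tmpDst.data(), remain * sizeof(int8_t));
        }
    }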

+ 2 - 2
source/backend/cpu/CPUCast.hpp

@@ -16,8 +16,8 @@ class CPUCastCreator : public CPUBackend::Creator {
 public:
     virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                 const MNN::Op* op, Backend* backend) const override;
-    static ErrorCode cast(const Tensor* input, const Tensor* output, int size = 0);
-    static ErrorCode cast(void* const inputRaw, void* outputRaw, halide_type_t inputType, halide_type_t outputType, int number, float scale, float zero, float min, float max);
+    static ErrorCode cast(const Tensor* input, const Tensor* output, const CPUBackend* bn);
+    static ErrorCode cast(void* const inputRaw, void* outputRaw, halide_type_t inputType, halide_type_t outputType, int number, float scale, float zero, float min, float max, const CPUBackend* bn);
 };
 } // namespace MNN
 #endif /* CPUCast_hpp */

+ 10 - 2
source/backend/cpu/CPUConvolution.cpp

@@ -19,6 +19,7 @@
 
 #include "backend/cpu/compute/ConvInt8Winograd.hpp"
 #include "backend/cpu/compute/ConvInt8TiledExecutor.hpp"
+#include "backend/cpu/compute/SparseConvInt8TiledExecutor.hpp"
 #ifdef MNN_USE_ONEDNN
 #include "backend/cpu/OneDNNConvInt8.hpp"
 #endif
@@ -191,7 +192,6 @@ std::shared_ptr<CPUConvolution::ResourceInt8> CPUConvolution::makeResourceInt8(B
 #endif
     auto weightDst = resource->mWeightInt8->host<int8_t>();
     memcpy(weightDst, weightSrc, resource->mWeightInt8->size());
-
     resource->mInputZeroPoint = convParam->symmetricQuan()->zeroPoint();
     resource->mOutputZeroPoint = convParam->symmetricQuan()->outputZeroPoint();
     resource->mClampMin = convParam->symmetricQuan()->clampMin();
@@ -289,13 +289,21 @@ public:
         quantAttr->max = (1<<(nbit-1))-1;
         TensorUtils::getDescribe(inputs[0])->quantAttr.reset(quantAttr);*/
         auto res = CPUConvolution::makeResourceInt8(backend, convOp, inputQuantInfo, outputQuantInfo);
+
+#ifdef MNN_USE_SPARSE_COMPUTE
+        auto core = static_cast<CPUBackend*>(backend)->int8Functions();
+        if (static_cast<CPUBackend*>(backend)->functions()->pack == 4 && convOp->sparseParameter() && SparseConvInt8TiledExecutor::shouldUseSparse(convOp)) {
+            return new SparseConvInt8TiledExecutor(backend, convOp, res);
+        }
+#endif
+
         if (!inputs.empty()) {
             std::vector<ConvInt8Winograd::UnitAttr> unitAttrs;
             if (ConvInt8Winograd::bestWinogradUnit(convOp, inputs[0], res->mWeightInt8.get(), outputs[0], backend, unitAttrs)) {
                 return new ConvInt8Winograd(backend, convOp, res, unitAttrs);
             }
         }
-        return new ConvInt8TiledExecutor(backend, convOp, res);
+        return new DenseConvInt8TiledExecutor(backend, convOp, res);
     }
 };
 

+ 2 - 2
source/backend/cpu/CPUConvolution.hpp

@@ -38,8 +38,8 @@ public:
         bool mRelu;
         int mActBits;
 
-        int8_t mInputZeroPoint;
-        int8_t mOutputZeroPoint;
+        int32_t mInputZeroPoint;
+        int32_t mOutputZeroPoint;
         int8_t mClampMin;
         int8_t mClampMax;
         Backend* backend;

+ 2 - 10
source/backend/cpu/CPUDepthwiseConvInt8.cpp

@@ -51,15 +51,6 @@ CPUDepthwiseConvInt8::CPUDepthwiseConvInt8(Backend* backend, const Convolution2D
     }
     mResource->mWeightInt8.swap(weight);
     backend->onReleaseBuffer(weight.get(), Backend::STATIC);
-    
-#ifdef MNN_USE_SSE
-    if (!mResource->offsets.empty()) {
-        for (int i = 0; i < outputCount; ++i) {
-            mResource->mBiasInt32->host<int32_t>()[i] -= mResource->offsets[i];
-        }
-    }
-    mResource->offsets.clear();
-#endif
 }
 
 CPUDepthwiseConvInt8::CPUDepthwiseConvInt8(Backend* backend, const Convolution2DCommon* common, const CPUDepthwiseConvInt8& exe) : CPUConvolution(common, backend), mResource(exe.mResource) {
@@ -165,10 +156,11 @@ ErrorCode CPUDepthwiseConvInt8::onExecute(const std::vector<Tensor*>& inputs, co
             auto dstOrigin       = outputPtr + index * dst_z_step;
 #ifdef MNN_USE_SSE
             auto inputPadPtrCopy = (int8_t*)inputPadPtr + mInputPad->stride(0);
+            ::memset(inputPadPtrCopy, mResource->mInputZeroPoint + 128, mInputPad->stride(0) * sizeof(int8_t));
 #else
             auto inputPadPtrCopy = inputPadPtr;
-#endif
             ::memset(inputPadPtrCopy, mResource->mInputZeroPoint, mInputPad->stride(0) * sizeof(int8_t));
+#endif
             // Pad inputs
             for (int y = 0; y < src_height; ++y) {
                 auto src = srcOrigin + y * src_width * UNIT;
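The padding memset above differs between the SSE and NEON paths because the x86 int8 kernels in MNN store activations biased by +128 (effectively as uint8), so the pad value must carry the same bias. A tiny standalone illustration of that round trip (the storage assumption is stated in the comment):

    #include <cassert>
    #include <cstdint>

    int main() {
        // Assumption: the SSE int8 path stores signed activations shifted by +128.
        int8_t  zeroPoint = -3;
        uint8_t sseFill   = static_cast<uint8_t>(zeroPoint + 128);  // value used for the memset
        assert(static_cast<int8_t>(sseFill - 128) == zeroPoint);    // shifting back recovers it
        return 0;
    }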

+ 12 - 9
source/backend/cpu/CPUFloatToInt8.cpp

@@ -12,6 +12,7 @@
 #include "backend/cpu/compute/Int8FunctionsOpt.h"
 #include "core/Macro.h"
 #include "core/TensorUtils.hpp"
+#include "compute/CommonOptFunction.h"
 
 namespace MNN {
 
@@ -19,18 +20,19 @@ CPUFloatToInt8::CPUFloatToInt8(Backend* backend, const MNN::Op* param) : Executi
     auto scale         = param->main_as_QuantizedFloatParam();
     const int scaleLen = scale->tensorScale()->size();
     mClipBits = scale->nbits();
-    mScales.reset(Tensor::createDevice<float>({ALIGN_UP4(scaleLen)}));
+    auto pack = static_cast<CPUBackend*>(backend)->functions()->pack;
+    mScales.reset(Tensor::createDevice<float>({UP_DIV(scaleLen, pack) * pack}));
     mValid = backend->onAcquireBuffer(mScales.get(), Backend::STATIC);
     if (!mValid) {
         return;
     }
     if (1 == scaleLen) {
         mSingle = true;
-        for (int i = 0; i < 4; ++i) {
+        for (int i = 0; i < pack; ++i) {
             mScales->host<float>()[i] = scale->tensorScale()->data()[0];
         }
     } else {
-        memset(mScales->host<float>(), 0, ALIGN_UP4(scaleLen) * sizeof(float));
+        memset(mScales->host<float>(), 0, UP_DIV(scaleLen, pack) * pack * sizeof(float));
         memcpy(mScales->host<float>(), scale->tensorScale()->data(), scaleLen * sizeof(float));
     }
 
@@ -49,13 +51,14 @@ ErrorCode CPUFloatToInt8::onResize(const std::vector<Tensor*>& inputs, const std
 ErrorCode CPUFloatToInt8::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
     const auto input = inputs[0];
     auto output      = outputs[0];
-    MNN_ASSERT(MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(inputs[0])->dimensionFormat);
+    auto pack = static_cast<CPUBackend*>(backend())->functions()->pack;
+    auto int8F = static_cast<CPUBackend*>(backend())->int8Functions();
 
     const auto inputDataPtr = input->host<float>();
     auto outputDataPtr      = output->host<int8_t>();
     const auto scaleDataPtr = mScales->host<float>();
     const int channels      = input->channel();
-    int icDiv4        = UP_DIV(channels, 4);
+    int icDiv4        = UP_DIV(channels, pack);
     const int batch         = input->batch();
     const int batchStride   = input->stride(0);
     int oc4Stride           = 1;
@@ -72,10 +75,10 @@ ErrorCode CPUFloatToInt8::onExecute(const std::vector<Tensor*>& inputs, const st
     MNN_CONCURRENCY_BEGIN(tId, total) {
         int bIndex = tId / icDiv4;
         int z = tId % icDiv4;
-        const auto srcChannelPtr   = inputDataPtr + tId * oc4Stride * 4;
-        const auto scaleChannelPtr = scaleDataPtr + z * 4;
-        auto dstChannlePtr         = outputDataPtr + tId * oc4Stride * 4;
-        MNNFloat2Int8(srcChannelPtr, dstChannlePtr, oc4Stride, scaleChannelPtr, mClampMin, mClampMax, mZeroPoint);
+        const auto srcChannelPtr   = inputDataPtr + tId * oc4Stride * pack;
+        const auto scaleChannelPtr = scaleDataPtr + z * pack;
+        auto dstChannlePtr         = outputDataPtr + tId * oc4Stride * pack;
+        int8F->MNNFloat2Int8(srcChannelPtr, dstChannlePtr, oc4Stride, scaleChannelPtr, mClampMin, mClampMax, mZeroPoint);
     }
     MNN_CONCURRENCY_END();
     return NO_ERROR;

+ 12 - 9
source/backend/cpu/CPUInt8ToFloat.cpp

@@ -11,6 +11,7 @@
 #include "core/Concurrency.h"
 #include "core/Macro.h"
 #include "compute/Int8FunctionsOpt.h"
+#include "compute/CommonOptFunction.h"
 #include "core/TensorUtils.hpp"
 
 namespace MNN {
@@ -18,18 +19,19 @@ namespace MNN {
 CPUInt8ToFloat::CPUInt8ToFloat(Backend* backend, const MNN::Op* param) : Execution(backend) {
     auto scale         = param->main_as_QuantizedFloatParam();
     const int scaleLen = scale->tensorScale()->size();
-    mScales.reset(Tensor::createDevice<float>({ALIGN_UP4(scaleLen)}));
+    auto pack = static_cast<CPUBackend*>(backend)->functions()->pack;
+    mScales.reset(Tensor::createDevice<float>({UP_DIV(scaleLen, pack) * pack}));
     mValid = backend->onAcquireBuffer(mScales.get(), Backend::STATIC);
     if (!mValid) {
         return;
     }
     if (1 == scaleLen) {
         mSingle = true;
-        for (int i = 0; i < 4; ++i) {
+        for (int i = 0; i < pack; ++i) {
             mScales->host<float>()[i] = scale->tensorScale()->data()[0];
         }
     } else {
-        memset(mScales->host<float>(), 0, ALIGN_UP4(scaleLen) * sizeof(float));
+        memset(mScales->host<float>(), 0, UP_DIV(scaleLen, pack) * pack * sizeof(float));
         memcpy(mScales->host<float>(), scale->tensorScale()->data(), scaleLen * sizeof(float));
     }
     mZeroPoint = scale->zeroPoint();
@@ -40,13 +42,14 @@ CPUInt8ToFloat::~CPUInt8ToFloat() {
 ErrorCode CPUInt8ToFloat::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
     const auto input = inputs[0];
     auto output      = outputs[0];
-    MNN_ASSERT(MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(inputs[0])->dimensionFormat);
+    auto pack = static_cast<CPUBackend*>(backend())->functions()->pack;
+    auto int8F = static_cast<CPUBackend*>(backend())->int8Functions();
 
     const auto inputDataPtr = input->host<int8_t>();
     auto outputDataPtr      = output->host<float>();
     const auto scaleDataPtr = mScales->host<float>();
     const int channels      = input->channel();
-    int icDiv4        = UP_DIV(channels, 4);
+    int icDiv4        = UP_DIV(channels, pack);
     const int batch         = input->batch();
     const int batchStride   = input->stride(0);
     int oc4Stride           = 1;
@@ -62,10 +65,10 @@ ErrorCode CPUInt8ToFloat::onExecute(const std::vector<Tensor*>& inputs, const st
     MNN_CONCURRENCY_BEGIN(tId, total) {
         int bIndex = tId / icDiv4;
         int z = tId % icDiv4;
-        const auto srcChannelPtr   = inputDataPtr + tId * oc4Stride * 4;
-        const auto scaleChannelPtr = scaleDataPtr + z * 4;
-        auto dstChannlePtr         = outputDataPtr + tId * oc4Stride * 4;
-        MNNInt8ScaleToFloat(dstChannlePtr, srcChannelPtr, scaleChannelPtr, oc4Stride, mZeroPoint);
+        const auto srcChannelPtr   = inputDataPtr + tId * oc4Stride * pack;
+        const auto scaleChannelPtr = scaleDataPtr + z * pack;
+        auto dstChannlePtr         = outputDataPtr + tId * oc4Stride * pack;
+        int8F->MNNInt8ScaleToFloat(dstChannlePtr, srcChannelPtr, scaleChannelPtr, oc4Stride, mZeroPoint);
     }
     MNN_CONCURRENCY_END();
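Both quantize/dequantize ops above now size their per-channel scale buffer with UP_DIV(scaleLen, pack) * pack instead of ALIGN_UP4, so the allocation follows the backend's pack width (4 on the default fp32 path, 8 on fp16/bf16 paths) while staying identical where pack is 4. A quick arithmetic check with illustrative values:

    #include <cassert>

    static inline int upDiv(int x, int y) { return (x + y - 1) / y; }

    int main() {
        int scaleLen = 10;
        assert(upDiv(scaleLen, 4) * 4 == 12);  // same as the old ALIGN_UP4(10)
        assert(upDiv(scaleLen, 8) * 8 == 16);  // pack == 8, e.g. an fp16 backend
        return 0;
    }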
 

+ 11 - 2
source/backend/cpu/CPULayerNorm.cpp

@@ -32,7 +32,7 @@ private:
     std::vector<int> axis_;
     int inner_size_ = 1;
     int outter_size_ = 1;
-
+    int group_ = 1;
     float epsilon_ = 0.001;
 
     std::unique_ptr<Tensor> gamma_;
@@ -48,7 +48,7 @@ CPULayerNorm::CPULayerNorm(const MNN::Op* op, Backend* backend)
     for (int i = 0; i < axis_size; ++i) {
         axis_[i] = layer_norm_param->axis()->Get(i);
     }
-
+    group_ = layer_norm_param->group();
     epsilon_ = layer_norm_param->epsilon();
 
     if (layer_norm_param->gamma() && layer_norm_param->beta()) {
@@ -96,6 +96,14 @@ ErrorCode CPULayerNorm::onResize(const std::vector<Tensor*> &inputs,
     outter_size_ = 1;
     inner_size_ = 1;
     int rank = inputs.at(0)->dimensions();
+    if (group_ > 1) {
+        outter_size_ = inputs.at(0)->length(0) * group_;
+        for (int i = 1; i < rank; i++) {
+            inner_size_ *= inputs.at(0)->length(i);
+        }
+        inner_size_ /= group_;
+        return NO_ERROR;
+    }
     std::vector<int> axis(axis_.size());
     for (int i = 0; i < axis_.size(); ++i) {
         if (axis_[i] < 0) {
@@ -103,6 +111,7 @@ ErrorCode CPULayerNorm::onResize(const std::vector<Tensor*> &inputs,
         }
     }
     std::sort(axis.begin(), axis.end());
+
     for (int i = 0; i < rank - axis.size(); ++i) {
         outter_size_ *= inputs.at(0)->length(i);
     }
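When group_ > 1 the resize above switches to a group-norm style split: the batch dimension times the group count becomes the outer loop, and the remaining elements divided by the group count are reduced together. A worked example with illustrative shapes:

    #include <cassert>

    int main() {
        // Illustrative NCHW input [2, 8, 4, 4] with group_ = 4:
        int batch = 2, channel = 8, height = 4, width = 4, group = 4;
        int outterSize = batch * group;                      // 2 * 4 = 8 groups to normalize
        int innerSize  = (channel * height * width) / group; // 128 / 4 = 32 elements per group
        assert(outterSize == 8 && innerSize == 32);
        return 0;
    }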

+ 14 - 18
source/backend/cpu/CPUOPRegister.cpp

@@ -17,9 +17,13 @@ extern void ___CPUMatMulCreator__OpType_MatMul__();
 extern void ___CPUMomentsCreator__OpType_Moments__();
 extern void ___CPUSegmentMeanCreator__OpType_Segment__();
 extern void ___CPUInstanceNormCreator__OpType_InstanceNorm__();
+extern void ___CPUQuantizedLogisticCreator__OpType_QuantizedLogistic__();
 extern void ___CPUWhereCreator__OpType_Where__();
+extern void ___CPUQuantizedMaxPoolCreator__OpType_QuantizedMaxPool__();
 extern void ___CPUDeconvolutionCreator__OpType_Deconvolution__();
 extern void ___CPUBinaryCreator__OpType_BinaryOp__();
+extern void ___CPUDepthwiseCreator__OpType_QuantizedDepthwiseConv2D__();
+extern void ___CPUQuantizedSoftmaxCreator__OpType_QuantizedSoftmax__();
 extern void ___CPUPoolCreator__OpType_Pooling__();
 extern void ___CPUScatterNdCreator__OpType_ScatterNd__();
 extern void ___CPUPluginCreator__OpType_Plugin__();
@@ -36,6 +40,7 @@ extern void ___CPUDepthwiseConvInt8Creator__OpType_DepthwiseConvInt8__();
 extern void ___CPUOneHotCreator__OpType_OneHot__();
 extern void ___CPUPoolInt8Creator__OpType_PoolInt8__();
 extern void ___CPUMatrixBandPartCreator__OpType_MatrixBandPart__();
+extern void ___CPUQuantizedAddCreator__OpType_QuantizedAdd__();
 extern void ___CPUDeconvolutionDepthwiseCreator__OpType_DeconvolutionDepthwise__();
 extern void ___CPUFloatToInt8Creator__OpType_FloatToInt8__();
 extern void ___CPULinSpaceCreator__OpType_LinSpace__();
@@ -45,6 +50,8 @@ extern void ___CPURasterFactory__OpType_Raster__();
 extern void ___CPURasterFactory__OpType_While__();
 extern void ___CPUConvolutionDepthwiseCreator__OpType_ConvolutionDepthwise__();
 extern void ___CPURangeCreator__OpType_Range__();
+extern void ___CPUTFQuantizedConv2DCreator__OpType_TfQuantizedConv2D__();
+extern void ___CPUQuantizedAvgPoolCreator__OpType_QuantizedAvgPool__();
 extern void ___ConvolutionFactory__OpType_Convolution__();
 extern void ___CPUConvInt8Creator__OpType_ConvInt8__();
 extern void ___CPURNNSequenceGRUCreator__OpType_RNNSequenceGRU__();
@@ -55,15 +62,6 @@ extern void ___CPUSetDiff1DCreator__OpType_SetDiff1D__();
 extern void ___CPUReduceJoinCreator__OpType_ReduceJoin__();
 extern void ___CPUEltwiseInt8Creator__OpType_EltwiseInt8__();
 extern void ___CPULayerNormCreator__OpType_LayerNorm__();
-#ifdef MNN_SUPPORT_TFLITE_QUAN
-extern void ___CPUDepthwiseCreator__OpType_QuantizedDepthwiseConv2D__();
-extern void ___CPUQuantizedAddCreator__OpType_QuantizedAdd__();
-extern void ___CPUQuantizedAvgPoolCreator__OpType_QuantizedAvgPool__();
-extern void ___CPUQuantizedLogisticCreator__OpType_QuantizedLogistic__();
-extern void ___CPUQuantizedMaxPoolCreator__OpType_QuantizedMaxPool__();
-extern void ___CPUQuantizedSoftmaxCreator__OpType_QuantizedSoftmax__();
-extern void ___CPUTFQuantizedConv2DCreator__OpType_TfQuantizedConv2D__();
-#endif  // MNN_SUPPORT_TFLITE_QUAN
 
 void registerCPUOps() {
 ___CPUCropAndResizeCreator__OpType_CropAndResize__();
@@ -83,9 +81,13 @@ ___CPUMatMulCreator__OpType_MatMul__();
 ___CPUMomentsCreator__OpType_Moments__();
 ___CPUSegmentMeanCreator__OpType_Segment__();
 ___CPUInstanceNormCreator__OpType_InstanceNorm__();
+___CPUQuantizedLogisticCreator__OpType_QuantizedLogistic__();
 ___CPUWhereCreator__OpType_Where__();
+___CPUQuantizedMaxPoolCreator__OpType_QuantizedMaxPool__();
 ___CPUDeconvolutionCreator__OpType_Deconvolution__();
 ___CPUBinaryCreator__OpType_BinaryOp__();
+___CPUDepthwiseCreator__OpType_QuantizedDepthwiseConv2D__();
+___CPUQuantizedSoftmaxCreator__OpType_QuantizedSoftmax__();
 ___CPUPoolCreator__OpType_Pooling__();
 ___CPUScatterNdCreator__OpType_ScatterNd__();
 ___CPUPluginCreator__OpType_Plugin__();
@@ -102,6 +104,7 @@ ___CPUDepthwiseConvInt8Creator__OpType_DepthwiseConvInt8__();
 ___CPUOneHotCreator__OpType_OneHot__();
 ___CPUPoolInt8Creator__OpType_PoolInt8__();
 ___CPUMatrixBandPartCreator__OpType_MatrixBandPart__();
+___CPUQuantizedAddCreator__OpType_QuantizedAdd__();
 ___CPUDeconvolutionDepthwiseCreator__OpType_DeconvolutionDepthwise__();
 ___CPUFloatToInt8Creator__OpType_FloatToInt8__();
 ___CPULinSpaceCreator__OpType_LinSpace__();
@@ -111,6 +114,8 @@ ___CPURasterFactory__OpType_Raster__();
 ___CPURasterFactory__OpType_While__();
 ___CPUConvolutionDepthwiseCreator__OpType_ConvolutionDepthwise__();
 ___CPURangeCreator__OpType_Range__();
+___CPUTFQuantizedConv2DCreator__OpType_TfQuantizedConv2D__();
+___CPUQuantizedAvgPoolCreator__OpType_QuantizedAvgPool__();
 ___ConvolutionFactory__OpType_Convolution__();
 ___CPUConvInt8Creator__OpType_ConvInt8__();
 ___CPURNNSequenceGRUCreator__OpType_RNNSequenceGRU__();
@@ -121,14 +126,5 @@ ___CPUSetDiff1DCreator__OpType_SetDiff1D__();
 ___CPUReduceJoinCreator__OpType_ReduceJoin__();
 ___CPUEltwiseInt8Creator__OpType_EltwiseInt8__();
 ___CPULayerNormCreator__OpType_LayerNorm__();
-#ifdef MNN_SUPPORT_TFLITE_QUAN
-___CPUDepthwiseCreator__OpType_QuantizedDepthwiseConv2D__();
-___CPUQuantizedAddCreator__OpType_QuantizedAdd__();
-___CPUQuantizedAvgPoolCreator__OpType_QuantizedAvgPool__();
-___CPUQuantizedLogisticCreator__OpType_QuantizedLogistic__();
-___CPUQuantizedMaxPoolCreator__OpType_QuantizedMaxPool__();
-___CPUQuantizedSoftmaxCreator__OpType_QuantizedSoftmax__();
-___CPUTFQuantizedConv2DCreator__OpType_TfQuantizedConv2D__();
-#endif  // MNN_SUPPORT_TFLITE_QUAN
 }
 }

+ 13 - 17
source/backend/cpu/CPURaster.cpp

@@ -176,7 +176,16 @@ ErrorCode CPURaster::onResize(const std::vector<Tensor *> &inputs, const std::ve
     auto des = TensorUtils::getDescribe(input);
     MNN_ASSERT(des->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL);
     auto outputDes = TensorUtils::getDescribe(output);
+    auto bytes = output->getType().bytes();
     mNeedZero = !TensorUtils::regionIsFull(input);
+    mZeroPoint = 0;
+    if (bytes == 1 && TensorUtils::getDescribe(output)->quantAttr != nullptr) {
+#ifdef MNN_USE_SSE
+        mZeroPoint = (int)TensorUtils::getDescribe(output)->quantAttr->zero + 128;
+#else
+        mZeroPoint = (int)TensorUtils::getDescribe(output)->quantAttr->zero;
+#endif
+    }
     mTempInput.clear();
     mFastBlit.clear();
     mCacheRegions.clear();
@@ -412,16 +421,6 @@ static void _1BitcopyWithStride(uint8_t* dstO, const uint8_t* srcO, int size, in
         dst+=ds;
     }
 }
-static void _8BitcopyWithStrideC4(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds) {
-    auto src = (float*)srcO;
-    auto dst = (float*)dstO;
-    for (int i=0; i<size; ++i) {
-        Vec4::save(dst, Vec4::load(src));
-        Vec4::save(dst + 4, Vec4::load(src + 4));
-        src+= (8 * stride);
-        dst+= (8 * ds);
-    }
-}
 static void _4BitcopyWithStrideC4(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds) {
     auto src = (float*)srcO;
     auto dst = (float*)dstO;
@@ -457,14 +456,11 @@ void CPURaster::executeFaster(const std::vector<Tensor *> &inputs, const std::ve
     auto core = static_cast<const CPUBackend*>(backend())->functions();
     auto threadNum = static_cast<CPUBackend*>(backend())->threadNumber();
     if (mNeedZero) {
-        ::memset(output->host<void>(), 0, static_cast<CPUBackend*>(backend())->getTensorSize(output) * bytes);
+        ::memset(output->host<void>(), mZeroPoint, static_cast<CPUBackend*>(backend())->getTensorSize(output) * bytes);
     }
     auto byteC4 = bytes * core->pack;
     auto C4proc = _4BitcopyWithStride;
     switch (byteC4) {
-        case 32:
-            C4proc = _8BitcopyWithStrideC4;
-            break;
         case 16:
             C4proc = _4BitcopyWithStrideC4;
             break;
@@ -475,7 +471,7 @@ void CPURaster::executeFaster(const std::vector<Tensor *> &inputs, const std::ve
             C4proc = _4BitcopyWithStride;
             break;
         default:
-            MNN_ASSERT(false);
+            C4proc = core->MNNSelectBlitFunction(byteC4);
             break;
     }
     MNN_CONCURRENCY_BEGIN(tId, threadNum) {
@@ -644,9 +640,9 @@ ErrorCode CPURaster::onExecute(const std::vector<Tensor *> &inputs, const std::v
     }
     if (mNeedZero) {
         if (mTempOutput == nullptr) {
-            ::memset(output->host<void>(), 0, outputEleSize * bytes);
+            ::memset(output->host<void>(), mZeroPoint, outputEleSize * bytes);
         } else {
-            ::memset(mTempOutput->host<void>(), 0, mTempOutput->elementSize() * bytes);
+            ::memset(mTempOutput->host<void>(), mZeroPoint, mTempOutput->elementSize() * bytes);
         }
     }
     for (auto& iter : mTempInput) {

+ 1 - 0
source/backend/cpu/CPURaster.hpp

@@ -35,6 +35,7 @@ private:
     bool mFast = false;
     int mSingleConvert = 0;
     std::vector<Tensor::InsideDescribe::Region> mCacheRegions;
+    int32_t mZeroPoint = 0;
 };
 }
 #endif
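With the mZeroPoint member added above, the raster op now pre-fills partially covered int8 outputs with the tensor's quantized zero point (shifted by +128 on SSE builds, matching the storage convention noted for CPUDepthwiseConvInt8 earlier) instead of raw zero bytes, so uncovered elements decode to real-value 0. A scalar sketch of the fill-value selection, assuming that storage convention:

    #include <cstdint>

    // Sketch: pick the byte value used to memset the output before blitting regions.
    static int rasterFillValue(bool int8WithQuant, float quantZero, bool sseBuild) {
        if (!int8WithQuant) {
            return 0;                              // non-quantized tensors keep the zero fill
        }
        int zero = static_cast<int>(quantZero);
        return sseBuild ? zero + 128 : zero;       // assumes the +128-biased int8 storage on x86
    }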

+ 3 - 2
source/backend/cpu/CPUScale.cpp

@@ -36,10 +36,11 @@ CPUScale::CPUScale(const Op* op, Backend* bn) : MNN::Execution(bn) {
         ::memcpy(mScaleBias->host<float>(), scale->scaleData()->data(), outputCount * sizeof(float));
     }
     if (nullptr != scale->biasData() && nullptr != scale->biasData()->data()) {
+        auto biasPtr = mScaleBias->host<uint8_t>() + mScaleBias->length(1);
         if (core->bytes < 4) {
-            core->MNNFp32ToLowp(scale->biasData()->data(), (int16_t*)(mScaleBias->host<uint8_t>() + 1 * mScaleBias->length(1)), outputCount);
+            core->MNNFp32ToLowp(scale->biasData()->data(), reinterpret_cast<int16_t*>(biasPtr), outputCount);
         } else {
-            ::memcpy(mScaleBias->host<float>() + ALIGN_UP4(outputCount), scale->biasData()->data(), outputCount * sizeof(float));
+            ::memcpy(biasPtr, scale->biasData()->data(), outputCount * sizeof(float));
         }
     }
 }
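The bias copy above now derives its destination from the byte offset of the second row of mScaleBias, so the fp32 and low-precision branches write to the same place. A standalone sketch of the assumed two-row layout (scales first, then biases), with illustrative sizes:

    #include <cstdint>
    #include <vector>

    int main() {
        // Assumed layout: one buffer of two rows, each rowBytes wide; row 0 = scales,
        // row 1 = biases. The bias pointer is base + rowBytes whatever the element size.
        int outputCount = 10, pack = 8, bytes = 4;                    // illustrative values
        int rowBytes = ((outputCount + pack - 1) / pack) * pack * bytes;
        std::vector<uint8_t> scaleBias(2 * rowBytes, 0);
        uint8_t* scalePtr = scaleBias.data();
        uint8_t* biasPtr  = scaleBias.data() + rowBytes;
        (void)scalePtr;
        (void)biasPtr;
        return 0;
    }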

+ 5 - 1
source/backend/cpu/CPUSoftmax.cpp

@@ -77,7 +77,11 @@ int CPUSoftmax::_softmaxCommon(const float *srcData, float *dstData, int inside,
             realSize = totalSize - start;
         }
         if (realSize > 0) {
-            MNNExp(dstData + start, dstData + start, realSize);
+            float ab[2] = {
+                -1.0f,
+                0.0f
+            };
+            MNNExp(dstData + start, dstData + start, ab, realSize);
         }
     }
     MNN_CONCURRENCY_END();

+ 8 - 8
source/backend/cpu/CPUTensorConvert.cpp

@@ -121,9 +121,9 @@ ErrorCode CPUTensorConverter::convert(const void* inputRaw, void* outputRaw, MNN
             if (core->bytes == bitLength) {
                 proc = decltype(proc)(core->MNNUnpackCUnitTranspose);
             } else if (bitLength == 1) {
-                proc = decltype(proc)(MNNPackTransposeUint8);
+                proc = decltype(proc)(core->MNNUnpackCUnitTransposeInt8);
             } else if (bitLength == 2) {
-                proc = decltype(proc)(MNNPackTransposeInt16);
+                proc = decltype(proc)(core->MNNUnpackCUnitTransposeInt16);
             }
             if (nullptr == proc) {
                 return NOT_SUPPORT;
@@ -133,9 +133,9 @@ ErrorCode CPUTensorConverter::convert(const void* inputRaw, void* outputRaw, MNN
             if (core->bytes == bitLength) {
                 proc = decltype(proc)(core->MNNUnpackCUnit);
             } else if (bitLength == 1) {
-                proc = decltype(proc)(MNNUnpackC4Uint8);
+                proc = decltype(proc)(core->MNNUnpackCUnitInt8);
             } else if (bitLength == 2) {
-                proc = decltype(proc)(MNNUnpackC4Int16);
+                proc = decltype(proc)(core->MNNUnpackCUnitInt16);
             }
             if (nullptr == proc) {
                 return NOT_SUPPORT;
@@ -191,9 +191,9 @@ ErrorCode CPUTensorConverter::convert(const void* inputRaw, void* outputRaw, MNN
             if (core->bytes == bitLength) {
                 proc = decltype(proc)(core->MNNPackCUnitTranspose);
             } else if (bitLength == 1) {
-                proc = decltype(proc)(MNNUnpackTransposeUint8);
+                proc = decltype(proc)(core->MNNPackCUnitTransposeInt8);
             } else if (bitLength == 2) {
-                proc = decltype(proc)(MNNUnpackTransposeInt16);
+                proc = decltype(proc)(core->MNNPackCUnitTransposeInt16);
             }
             if (nullptr == proc) {
                 return NOT_SUPPORT;
@@ -205,9 +205,9 @@ ErrorCode CPUTensorConverter::convert(const void* inputRaw, void* outputRaw, MNN
             if (core->bytes == bitLength) {
                 proc = decltype(proc)(core->MNNPackCUnit);
             } else if (bitLength == 1) {
-                proc = decltype(proc)(MNNPackC4Uint8);
+                proc = decltype(proc)(core->MNNPackCUnitInt8);
             } else if (bitLength == 2) {
-                proc = decltype(proc)(MNNPackC4Int16);
+                proc = decltype(proc)(core->MNNPackCUnitInt16);
             }
             if (nullptr == proc) {
                 return NOT_SUPPORT;

+ 12 - 7
source/backend/cpu/CPUUnary.cpp

@@ -40,17 +40,20 @@ static void _Square(void* out, const void* inp, int realSize) {
 static void _EXP(void* outRaw, const void* inpRaw, int realSize) {
     auto out = (float*)outRaw;
     auto inp = (const float*)inpRaw;
-    MNNScaleAndAddBiasScalar(out, inp, 0.0f, -1.0f, realSize);
-    MNNExp(out, out, realSize);
+    float offset[2] = {
+        1.0f,
+        0.0f
+    };
+    MNNExp(out, inp, offset, realSize);
 }
 static void _EXPM1(void* outRaw, const void* inpRaw, int realSize) {
     auto out = (float*)outRaw;
     auto inp = (const float*)inpRaw;
-    MNNScaleAndAddBiasScalar(out, inp, 0.0f, -1.0f, realSize);
-    MNNExp(out, out, realSize);
-    for (int i=0; i<realSize; ++i) {
-        out[i] = out[i] - 1.0f;
-    }
+    float offset[2] = {
+        1.0f,
+        -1.0f
+    };
+    MNNExp(out, inp, offset, realSize);
 }
 
 MNNUnaryExecute CPUUnary::selectForFloat(int type, int precision) {
@@ -126,6 +129,8 @@ MNNUnaryExecute CPUUnary::selectForFloat(int type, int precision) {
             return (MNNUnaryExecute)MNNHardSwishCommon;
         case UnaryOpOperation_GELU:
             return (MNNUnaryExecute)MNNGeluCommon;
+        case UnaryOpOperation_GELU_STANDARD:
+            return (MNNUnaryExecute)MNNGeluStandardCommon;
         default:
             MNN_ASSERT(false);
             break;
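Taken together with the softmax change and the new MNNExpC8 assembly earlier in this diff, the extra offset argument appears to fold a scale and a bias into the exp kernel: dest[i] = exp(source[i] * offset[0]) + offset[1]. That reading matches every caller here ({1, 0} for exp, {1, -1} for expm1, {-1, 0} in softmax). A scalar reference model under that assumption:

    #include <cassert>
    #include <cmath>

    // Scalar model of the assumed MNNExp contract: dst[i] = exp(src[i] * alpha) + beta.
    static void expRef(float* dst, const float* src, const float offset[2], int count) {
        for (int i = 0; i < count; ++i) {
            dst[i] = std::exp(src[i] * offset[0]) + offset[1];
        }
    }

    int main() {
        float in[2] = {0.0f, 1.0f}, out[2];
        const float expm1Offset[2] = {1.0f, -1.0f};                      // the pair _EXPM1 passes above
        expRef(out, in, expm1Offset, 2);
        assert(std::fabs(out[0]) < 1e-6f);                               // expm1(0) == 0
        assert(std::fabs(out[1] - (std::exp(1.0f) - 1.0f)) < 1e-5f);     // expm1(1)
        return 0;
    }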

+ 1 - 1
source/backend/cpu/arm/CommonOptFunctionNeon.cpp

@@ -3,7 +3,7 @@
 #ifdef MNN_USE_NEON
 #include <arm_neon.h>
 #include "./FunctionSummary.hpp"
-#include "core/MemoryFormater.h"
+#include "common/MemoryFormater.h"
 
 extern "C" {
 void MNNTranspose32Bit4x4(int32_t* dstO, const int32_t* srcO, int32_t* dim);

+ 20 - 14
source/backend/cpu/arm/arm32/MNNExpC8.S

@@ -12,14 +12,19 @@
 #include "MNNAsmGlobal.h"
 .text
 .align 5
-//void MNNExpC8(float* dest, const float* source, const float* parameters, size_t countC8)
+//void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8)
 asm_function MNNExpC8
 
-//r0: dest, r1:source, r2:parameters, r3:countC8
-push {r4, lr}
-vpush {q5, q6}
+//r0: dest, r1:source, r2: offset, r3:parameters, r4:countC8
+push {r4, r5, lr}
+ldr r4, [sp, #12]
+vpush {q4, q5}
+ldr r5, [r2, #0]
+vdup.32 q4, r5 // Alpha
+ldr r5, [r2, #4]
+vdup.32 q5, r5 // Beta
 
-vld1.32 {q0, q1}, [r2]
+vld1.32 {q0, q1}, [r3]
 
 vmov.i32 q2, #87
 vcvt.f32.s32 q2, q2
@@ -28,14 +33,13 @@ vneg.f32 q3, q2
 Loop:
 
 vld1.32 {q8, q9}, [r1]!
+vmul.f32 q8, q8, q4
+vmul.f32 q9, q9, q4
 
 vmin.f32 q8, q8, q2
 vmin.f32 q9, q9, q2
-vmax.f32 q8, q8, q3
-vmax.f32 q9, q9, q3
-
-vneg.f32 q10, q8
-vneg.f32 q11, q9
+vmax.f32 q10, q8, q3
+vmax.f32 q11, q9, q3
 
 vmul.f32 q8, q10, d0[1]
 vmul.f32 q9, q11, d0[1]
@@ -72,14 +76,16 @@ vshl.i32 q9, q9, #23
 vadd.i32 q12, q12, q8
 vadd.i32 q13, q13, q9
 
-vst1.32 {q12, q13}, [r0]!
+vadd.f32 q12, q12, q5
+vadd.f32 q13, q13, q5
 
+vst1.32 {q12, q13}, [r0]!
 
-subs r3, r3, #1
+subs r4, r4, #1
 bne Loop
 
-vpop {q5, q6}
-pop {r4, pc}
+vpop {q4, q5}
+pop {r4, r5, pc}
 
 
 #endif

+ 6 - 1
source/backend/cpu/arm/arm32/MNNPackC4ForMatMul_A.S

@@ -43,7 +43,12 @@ mul r8, r12, r8
 add r0, r0, r7
 add r0, r0, r8
 
+mov r2, #12         // the fast-pack-eSize
+mul r2, r12, r2     // fast-pack-eSize * sizeof(dataType)
+cmp r2, r11         // check eP==fast-pack-eSize
+
 ldr r2, [r3, #0] // e
+bne Right
 
 Body:
 cmp r2, #12
@@ -220,4 +225,4 @@ bne LoopNumber
 pop {r4-r8, r10, r11, pc}
 
 #endif
-#endif
+#endif

+ 0 - 4
source/backend/cpu/arm/arm32/MNNPackedSparseMatMulEpx4.S

@@ -175,7 +175,6 @@ loop_e8:
 
         lsr r0, r5, #2 // NC4HW4
         mul r0, r0, r12
-        vmov r6, r10, d15 // r10 useless
         add r3, r3, r0 // blockC += (h >> 2) * cStride
 
     loop_e8h1:
@@ -319,7 +318,6 @@ beq loop_e2
 
         lsr r0, r5, #2 // NC4HW4
         mul r0, r0, r12
-        vmov r6, r10, d15 // r10 useless
         add r3, r3, r0 // blockC += (h >> 2) * cStride
 
     loop_e4h1:
@@ -437,7 +435,6 @@ beq loop_e1
 
         lsr r0, r5, #2 // NC4HW4
         mul r0, r0, r12
-        vmov r6, r10, d15 // r10 useless
         add r3, r3, r0 // blockC += (h >> 2) * cStride
 
     loop_e2h1:
@@ -547,7 +544,6 @@ beq loop_end
 
         lsr r0, r5, #2 // NC4HW4
         mul r0, r0, r12
-        vmov r6, r10, d15 // r10 useless
         add r3, r3, r0 // blockC += (h >> 2) * cStride
 
     loop_e1h1:

+ 319 - 0
source/backend/cpu/arm/arm32/MNNPackedSparseQuantMatMulEpx1.S

@@ -0,0 +1,319 @@
+//
+//  MNNPackedSparseQuantMatMulEpx1.S
+//  MNN
+//
+//  Created by MNN on 2021/05/10.
+//  Copyright © 2018-2021 Alibaba Group Holding Limited
+//
+//
+
+#ifdef __arm__
+#ifndef __aarch64__
+
+#include "MNNAsmGlobal.h"
+#define sizeof_value 4
+#define sizeof_value_lg2 2
+#define sparse_blockoc 4
+
+.text
+.align 5
+// caution!!! this is 8 * 1 Sparse MatMul
+asm_function MNNPackedSparseQuantMatMulEpx1
+
+// void MNNPackedSparseQuantMatMulEpx1(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam,
+// const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap) {
+//Auto load: r0: C, r1:A, r2:B, r3:sparseQuantParam,
+//load from stack r4:QuanPostTreatParameters, r5:NNZMap, r6:dataOffsetMap
+
+push {r4-r8, r10, r11, lr}
+vpush {q4-q7}
+#define push_registers_bytes (8 * 4 + 4 * 16)
+ldr r4, [sp, #push_registers_bytes]
+ldr r7, [r4, #8]
+ldr r8, [r4, #12]
+vmov.f32 q4, #0.5
+vmov.f32 q5, #-0.5
+vdup.32 q6, r7 // max
+vdup.32 q7, r8 // min
+
+// r0: C
+// r1: A
+// r2: B
+// r3: sparseQuantParam mem(6*4byte) [eSize, eP, aStride, l, h, cStride]
+// r4: QuanPostTreatParameters mem(4*4byte) [scale, bias, max, min]
+// r5: NNZMap
+// r6: dataOffsetMap
+// r7: scale
+// r8: bias
+// r10: loop_counter (loop_e8 / loop_e4 / loop_e2 / loop_e1), h
+// r11: loop_counter (loop_e8h1 / loop_e4h1 / loop_e2h1 / loop_e1h1)
+// r12: loop_counter (loop_e8h1l1 / loop_e4h1l1 / loop_e2h1l1 / loop_e1h1l1)
+// lr: temp var
+
+ldr r10, [r3]
+loop_e8:
+    cmp r10, #8
+    blt loop_e4
+    sub r10, r10, #8
+    ldr r5, [sp, #(push_registers_bytes + 4)]
+    ldr r6, [sp, #(push_registers_bytes + 8)]
+    ldr r7, [r4]
+    ldr r8, [r4, #4]
+    push {r0-r2, r10}
+    ldr lr, [r6], #4 // dataOffset
+    add r1, r1, lr
+    ldr r10, [r3, #16] // h
+    mov r11, #0
+    loop_e8h1:
+        vld1.32 d16[0], [r8]!
+        vdup.32 q8, d16[0]
+        vdup.32 q9, d16[0]
+        ldr r12, [r5], #4
+        cmp r12, #0
+        beq loop_e8h1_end
+        loop_e8h1l1:
+            vld1.8 d0[0], [r2]!
+            vld1.8 d2, [r1]
+            vmovl.s8 q0, d0
+            vmovl.s8 q1, d2
+            ldr lr, [r6], #4
+            add r1, r1, lr
+            subs r12, r12, #1
+
+            vmlal.s16 q8, d2, d0[0]
+            vmlal.s16 q9, d3, d0[0]
+
+            bne loop_e8h1l1
+        loop_e8h1_end:
+            vld1.32 d0[0], [r7]!
+            vcvt.f32.s32 q8, q8
+            vcvt.f32.s32 q9, q9
+            vmul.f32 q8, q8, d0[0]
+            vmul.f32 q9, q9, d0[0]
+            vcgt.f32 q0, q8, #0
+            vcgt.f32 q1, q9, #0
+            vbsl.f32 q0, q4, q5
+            vbsl.f32 q1, q4, q5
+            vadd.f32 q8, q8, q0
+            vadd.f32 q9, q9, q1
+            vcvt.s32.f32 q8, q8
+            vcvt.s32.f32 q9, q9
+            vmin.s32 q8, q8, q6
+            vmin.s32 q9, q9, q6
+            vmax.s32 q8, q8, q7
+            vmax.s32 q9, q9, q7
+            vqmovn.s32 d0, q8
+            vqmovn.s32 d1, q9
+            vqmovn.s16 d0, q0
+            mov lr, #4
+            vst1.8 d0[0], [r0], lr
+            vst1.8 d0[1], [r0], lr
+            vst1.8 d0[2], [r0], lr
+            vst1.8 d0[3], [r0], lr
+            vst1.8 d0[4], [r0], lr
+            vst1.8 d0[5], [r0], lr
+            vst1.8 d0[6], [r0], lr
+            vst1.8 d0[7], [r0], lr
+            sub r0, r0, lr, lsl #3
+            add r11, r11, #1
+            ands lr, r11, #0x03
+            addne r0, r0, #1
+            ldr lr, [r3, #20] // cStride
+            subeq lr, lr, #3
+            addeq r0, r0, lr
+            cmp r11, r10
+        blt loop_e8h1
+        pop {r0-r2, r10}
+        add r0, r0, #32
+        add r1, r1, #8
+    b loop_e8
+
+loop_e4:
+    cmp r10, #4
+    blt loop_e2
+    sub r10, r10, #4
+    ldr r5, [sp, #(push_registers_bytes + 4)]
+    ldr r6, [sp, #(push_registers_bytes + 8)]
+    ldr r7, [r4]
+    ldr r8, [r4, #4]
+    push {r0-r2, r10}
+    ldr lr, [r6], #4 // dataOffset
+    add r1, r1, lr
+    ldr r10, [r3, #16] // h
+    mov r11, #0
+    loop_e4h1:
+        vld1.32 d16[0], [r8]!
+        vdup.32 q8, d16[0]
+        ldr r12, [r5], #4
+        cmp r12, #0
+        beq loop_e4h1_end
+        loop_e4h1l1:
+            vld1.8 d0[0], [r2]!
+            vld1.32 d2[0], [r1]
+            vmovl.s8 q0, d0
+            vmovl.s8 q1, d2
+            ldr lr, [r6], #4
+            add r1, r1, lr
+            subs r12, r12, #1
+
+            vmlal.s16 q8, d2, d0[0]
+            bne loop_e4h1l1
+        loop_e4h1_end:
+            vld1.32 d0[0], [r7]!
+            vcvt.f32.s32 q8, q8
+            vmul.f32 q8, q8, d0[0]
+            vcgt.f32 q0, q8, #0
+            vbsl.f32 q0, q4, q5
+            vadd.f32 q8, q8, q0
+            vcvt.s32.f32 q8, q8
+            vmin.s32 q8, q8, q6
+            vmax.s32 q8, q8, q7
+            vqmovn.s32 d0, q8
+            vqmovn.s16 d0, q0
+            mov lr, #4
+            vst1.8 d0[0], [r0], lr
+            vst1.8 d0[1], [r0], lr
+            vst1.8 d0[2], [r0], lr
+            vst1.8 d0[3], [r0], lr
+            sub r0, r0, lr, lsl #2
+            add r11, r11, #1
+            ands lr, r11, #0x03
+            addne r0, r0, #1
+            ldr lr, [r3, #20] // cStride
+            subeq lr, lr, #3
+            addeq r0, r0, lr
+            cmp r11, r10
+        blt loop_e4h1
+        pop {r0-r2, r10}
+        add r0, r0, #16
+        add r1, r1, #4
+    b loop_e4
+
+loop_e2:
+    cmp r10, #2
+    blt loop_e1
+    sub r10, r10, #2
+    ldr r5, [sp, #(push_registers_bytes + 4)]
+    ldr r6, [sp, #(push_registers_bytes + 8)]
+    ldr r7, [r4]
+    ldr r8, [r4, #4]
+    push {r0-r2, r10}
+    ldr lr, [r6], #4 // dataOffset
+    add r1, r1, lr
+    ldr r10, [r3, #16] // h
+    mov r11, #0
+    loop_e2h1:
+        vld1.32 d16[0], [r8]!
+        vdup.32 d16, d16[0]
+        ldr r12, [r5], #4
+        cmp r12, #0
+        beq loop_e2h1_end
+        loop_e2h1l1:
+            vld1.8 d0[0], [r2]!
+            vld1.16 d2[0], [r1]
+            vmovl.s8 q0, d0
+            vmovl.s8 q1, d2
+            ldr lr, [r6], #4
+            add r1, r1, lr
+            subs r12, r12, #1
+
+            vmlal.s16 q8, d2, d0[0]
+            bne loop_e2h1l1
+        loop_e2h1_end:
+            vld1.32 d0[0], [r7]!
+            vcvt.f32.s32 d16, d16
+            vmul.f32 d16, d16, d0[0]
+            vcgt.f32 d0, d16, #0
+            vbsl.f32 d0, d8, d10
+            vadd.f32 d16, d16, d0
+            vcvt.s32.f32 d16, d16
+            vmin.s32 d16, d16, d12
+            vmax.s32 d16, d16, d14
+            vqmovn.s32 d0, q8
+            vqmovn.s16 d0, q0
+            mov lr, #4
+            vst1.8 d0[0], [r0], lr
+            vst1.8 d0[1], [r0], lr
+            sub r0, r0, lr, lsl #1
+            add r11, r11, #1
+            ands lr, r11, #0x03
+            addne r0, r0, #1
+            ldr lr, [r3, #20] // cStride
+            subeq lr, lr, #3
+            addeq r0, r0, lr
+            cmp r11, r10
+        blt loop_e2h1
+    pop {r0-r2, r10}
+    add r0, r0, #8
+    add r1, r1, #2
+    b loop_e2
+
+loop_e1:
+    cmp r10, #1
+    blt End
+    sub r10, r10, #1
+    ldr r5, [sp, #(push_registers_bytes + 4)]
+    ldr r6, [sp, #(push_registers_bytes + 8)]
+    ldr r7, [r4]
+    ldr r8, [r4, #4]
+
+    push {r0-r2, r10}
+    ldr lr, [r6], #4 // dataOffset
+    add r1, r1, lr
+    ldr r10, [r3, #16] // h
+    mov r11, #0
+    loop_e1h1:
+        vld1.32 d16[0], [r8]!
+        ldr r12, [r5], #4
+        cmp r12, #0
+        beq loop_e1h1_end
+        loop_e1h1l1:
+            vld1.8 d0[0], [r2]!
+            vld1.8 d2[0], [r1]
+            vmovl.s8 q0, d0
+            vmovl.s8 q1, d2
+            ldr lr, [r6], #4
+            add r1, r1, lr
+            subs r12, r12, #1
+
+            vmlal.s16 q8, d2, d0[0]
+            bne loop_e1h1l1
+        loop_e1h1_end:
+            vld1.32 d0[0], [r7]!
+            vcvt.f32.s32 d16, d16
+            vmul.f32 d16, d16, d0[0]
+            vcgt.f32 d0, d16, #0
+            vbsl.f32 d0, d8, d10
+            vadd.f32 d16, d16, d0
+            vcvt.s32.f32 d16, d16
+            vmin.s32 d16, d16, d12
+            vmax.s32 d16, d16, d14
+            vqmovn.s32 d0, q8
+            vqmovn.s16 d0, q0
+            mov lr, #4
+            vst1.8 d0[0], [r0]
+            add r11, r11, #1
+            ands lr, r11, #0x03
+            addne r0, r0, #1
+            ldr lr, [r3, #20] // cStride
+            subeq lr, lr, #3
+            addeq r0, r0, lr
+            cmp r11, r10
+        blt loop_e1h1
+    pop {r0-r2, r10}
+    add r0, r0, #4
+    add r1, r1, #1
+    b loop_e1
+
+End:
+vpop {q4-q7}
+pop {r4-r8, r10, r11, pc}
+
+#undef push_registers_bytes
+#undef sizeof_value
+#undef sizeof_value_lg2
+#undef sparse_blockoc
+
+#endif
+#endif
+
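The requantization tail shared by both new sparse kernels (the inline version above and the ROUND_MODE / CLAMP macros in the Epx4 file that follows) converts the int32 accumulator to float, scales it, rounds half away from zero by adding plus or minus 0.5 before truncation, clamps to the post-treat min/max, and narrows to int8. A scalar model of that sequence, assuming the accumulator already contains the bias:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    // Scalar model of the sparse int8 requantization tail (accumulator already biased).
    static int8_t requant(int32_t acc, float scale, int32_t maxValue, int32_t minValue) {
        float v = acc * scale;
        v += (v > 0.0f) ? 0.5f : -0.5f;               // round half away from zero
        int32_t r = static_cast<int32_t>(v);          // vcvt.s32.f32 truncates toward zero
        r = std::min(maxValue, std::max(minValue, r));
        return static_cast<int8_t>(r);
    }

    int main() {
        assert(requant(10, 0.25f, 127, -127) == 3);    // 2.5 rounds away from zero
        assert(requant(-10, 0.25f, 127, -127) == -3);
        assert(requant(1000, 1.0f, 127, -127) == 127); // clamped to the post max
        return 0;
    }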

+ 352 - 0
source/backend/cpu/arm/arm32/MNNPackedSparseQuantMatMulEpx4.S

@@ -0,0 +1,352 @@
+//
+//  MNNPackedSparseQuantMatMulEpx4.S
+//  MNN
+//
+//  Created by MNN on 2021/06/23.
+//  Copyright © 2018-2021 Alibaba Group Holding Limited
+//
+//
+
+#ifdef __arm__
+#ifndef __aarch64__
+
+#include "MNNAsmGlobal.h"
+#define sizeof_value 4
+#define sizeof_value_lg2 2
+#define sparse_blockoc 4
+
+.macro TYPE_CVT op, z0, z1, z2, z3
+    \op \z0, \z0
+    \op \z1, \z1
+    \op \z2, \z2
+    \op \z3, \z3
+.endm
+
+.macro CLAMP op, z0, z1, z2, z3, m0
+    \op \z0, \z0, \m0
+    \op \z1, \z1, \m0
+    \op \z2, \z2, \m0
+    \op \z3, \z3, \m0
+.endm
+
+.macro SCALE z0, z1, z2, z3, scale
+    vmul.f32 \z0, \z0, \scale
+    vmul.f32 \z1, \z1, \scale
+    vmul.f32 \z2, \z2, \scale
+    vmul.f32 \z3, \z3, \scale
+.endm
+
+.macro ROUND_MODE z0, z1, z2, z3
+    vcgt.f32 q0, \z0, #0
+    vcgt.f32 q1, \z1, #0
+    vcgt.f32 q2, \z2, #0
+    vcgt.f32 q3, \z3, #0
+    vbsl.f32 q0, q4, q5
+    vbsl.f32 q1, q4, q5
+    vbsl.f32 q2, q4, q5
+    vbsl.f32 q3, q4, q5
+    vadd.f32 \z0, \z0, q0
+    vadd.f32 \z1, \z1, q1
+    vadd.f32 \z2, \z2, q2
+    vadd.f32 \z3, \z3, q3
+.endm
+
+.text
+.align 5
+// caution!!! this is 8 * 4 Sparse MatMul
+asm_function MNNPackedSparseQuantMatMulEpx4
+
+// void MNNPackedSparseQuantMatMulEpx4(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam,
+// const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap) {
+//Auto load: r0: C, r1:A, r2:B, r3:sparseQuantParam,
+//load from stack r4:QuanPostTreatParameters, r5:NNZMap, r6:dataOffsetMap
+
+// var not defined: bias,
+
+push {r4-r8, r10, r11, lr}
+vpush {q4-q7}
+#define push_registers_bytes (8 * 4 + 4 * 16)
+ldr r4, [sp, #push_registers_bytes]
+ldr r7, [r4, #8]
+ldr r8, [r4, #12]
+vmov.f32 q4, #0.5
+vmov.f32 q5, #-0.5
+vdup.32 q6, r7 // max
+vdup.32 q7, r8 // min
+
+// r0: C
+// r1: A
+// r2: B
+// r3: sparseQuantParam mem(6*4byte) [eSize, eP, aStride, l, h, cStride]
+// r4: QuanPostTreatParameters mem(4*4byte) [scale, bias, max, min]
+// r5: NNZMap
+// r6: dataOffsetMap
+// r7: scale
+// r8: bias
+// r10: loop_counter (loop_e8 / loop_e4 / loop_e2 / loop_e1), cStride
+// r11: loop_counter (loop_e8h4 / loop_e4h4 / loop_e2h4 / loop_e1h4)
+// r12: loop_counter (loop_e8h4l1 / loop_e4h4l1 / loop_e2h4l1 / loop_e1h4l1)
+// lr: temp var
+
+ldr r10, [r3]
+loop_e8:
+    cmp r10, #8
+    blt loop_e4
+    sub r10, r10, #8
+    ldr r5, [sp, #(push_registers_bytes + 4)]
+    ldr r6, [sp, #(push_registers_bytes + 8)]
+    ldr r7, [r4]
+    ldr r8, [r4, #4]
+    push {r0-r2, r10}
+    ldr r10, [r3, #20] // cStride
+    ldr lr, [r6], #4 // dataOffset
+    add r1, r1, lr
+    ldr r11, [r3, #16] // h
+    lsr r11, r11, #2 // hDiv4 (C4)
+    loop_e8h4:
+        vld1.32 q8, [r8]!
+        vmov q9, q8
+        vmov q10, q8
+        vmov q11, q8
+        vmov q12, q8
+        vmov q13, q8
+        vmov q14, q8
+        vmov q15, q8
+        ldr r12, [r5], #4
+        cmp r12, #0
+        beq loop_e8h4_end
+        loop_e8h4l1:
+            vld1.32 d0[0], [r2]!
+            vld1.8 d2, [r1]
+            vmovl.s8 q0, d0
+            vmovl.s8 q1, d2
+            ldr lr, [r6], #4
+            add r1, r1, lr
+            subs r12, r12, #1
+
+            vmlal.s16 q8, d0, d2[0]
+            vmlal.s16 q9, d0, d2[1]
+            vmlal.s16 q10, d0, d2[2]
+            vmlal.s16 q11, d0, d2[3]
+            vmlal.s16 q12, d0, d3[0]
+            vmlal.s16 q13, d0, d3[1]
+            vmlal.s16 q14, d0, d3[2]
+            vmlal.s16 q15, d0, d3[3]
+
+            bne loop_e8h4l1
+        loop_e8h4_end:
+            vld1.32 q0, [r7]!
+            TYPE_CVT vcvt.f32.s32, q8, q9, q10, q11
+            TYPE_CVT vcvt.f32.s32, q12, q13, q14, q15
+            SCALE q8, q9, q10, q11, q0
+            SCALE q12, q13, q14, q15, q0
+            ROUND_MODE q8, q9, q10, q11
+            ROUND_MODE q12, q13, q14, q15
+            TYPE_CVT vcvt.s32.f32, q8, q9, q10, q11
+            TYPE_CVT vcvt.s32.f32, q12, q13, q14, q15
+            CLAMP vmin.s32, q8, q9, q10, q11, q6
+            CLAMP vmin.s32, q12, q13, q14, q15, q6
+            CLAMP vmax.s32, q8, q9, q10, q11, q7
+            CLAMP vmax.s32, q12, q13, q14, q15, q7
+            vqmovn.s32 d0, q8
+            vqmovn.s32 d1, q9
+            vqmovn.s32 d2, q10
+            vqmovn.s32 d3, q11
+            vqmovn.s32 d4, q12
+            vqmovn.s32 d5, q13
+            vqmovn.s32 d6, q14
+            vqmovn.s32 d7, q15
+            vqmovn.s16 d0, q0
+            vqmovn.s16 d1, q1
+            vqmovn.s16 d2, q2
+            vqmovn.s16 d3, q3
+            vst1.8 {q0, q1}, [r0], r10
+            subs r11, r11, #1
+        bne loop_e8h4
+        pop {r0-r2, r10}
+        add r0, r0, #32
+        add r1, r1, #8
+    b loop_e8
+
+loop_e4:
+    cmp r10, #4
+    blt loop_e2
+    sub r10, r10, #4
+    ldr r5, [sp, #(push_registers_bytes + 4)]
+    ldr r6, [sp, #(push_registers_bytes + 8)]
+    ldr r7, [r4]
+    ldr r8, [r4, #4]
+    push {r0-r2, r10}
+    ldr r10, [r3, #20] // cStride
+    ldr lr, [r6], #4 // dataOffset
+    add r1, r1, lr
+    ldr r11, [r3, #16] // h
+    lsr r11, r11, #2 // hDiv4 (C4)
+    loop_e4h4:
+        vld1.32 q8, [r8]!
+        vmov q9, q8
+        vmov q10, q8
+        vmov q11, q8
+        ldr r12, [r5], #4
+        cmp r12, #0
+        beq loop_e4h4_end
+        loop_e4h4l1:
+            vld1.32 d0[0], [r2]!
+            vld1.32 d2[0], [r1]
+            vmovl.s8 q0, d0
+            vmovl.s8 q1, d2
+            ldr lr, [r6], #4
+            add r1, r1, lr
+            subs r12, r12, #1
+
+            vmlal.s16 q8, d0, d2[0]
+            vmlal.s16 q9, d0, d2[1]
+            vmlal.s16 q10, d0, d2[2]
+            vmlal.s16 q11, d0, d2[3]
+            bne loop_e4h4l1
+        loop_e4h4_end:
+            vld1.32 q0, [r7]!
+            TYPE_CVT vcvt.f32.s32, q8, q9, q10, q11
+            SCALE q8, q9, q10, q11, q0
+            ROUND_MODE q8, q9, q10, q11
+            TYPE_CVT vcvt.s32.f32, q8, q9, q10, q11
+            CLAMP vmin.s32, q8, q9, q10, q11, q6
+            CLAMP vmax.s32, q8, q9, q10, q11, q7
+            vqmovn.s32 d0, q8
+            vqmovn.s32 d1, q9
+            vqmovn.s32 d2, q10
+            vqmovn.s32 d3, q11
+            vqmovn.s16 d0, q0
+            vqmovn.s16 d1, q1
+            vst1.8 {q0}, [r0], r10
+            subs r11, r11, #1
+        bne loop_e4h4
+        pop {r0-r2, r10}
+        add r0, r0, #16
+        add r1, r1, #4
+    b loop_e4
+
+loop_e2:
+    cmp r10, #2
+    blt loop_e1
+    sub r10, r10, #2
+    ldr r5, [sp, #(push_registers_bytes + 4)]
+    ldr r6, [sp, #(push_registers_bytes + 8)]
+    ldr r7, [r4]
+    ldr r8, [r4, #4]
+    push {r0-r2, r10}
+    ldr r10, [r3, #20] // cStride
+    ldr lr, [r6], #4 // dataOffset
+    add r1, r1, lr
+    ldr r11, [r3, #16] // h
+    lsr r11, r11, #2 // hDiv4 (C4)
+    loop_e2h4:
+        vld1.32 q8, [r8]!
+        vmov q9, q8
+        ldr r12, [r5], #4
+        cmp r12, #0
+        beq loop_e2h4_end
+        loop_e2h4l1:
+            vld1.32 d0[0], [r2]!
+            vld1.16 d2[0], [r1]
+            vmovl.s8 q0, d0
+            vmovl.s8 q1, d2
+            ldr lr, [r6], #4
+            add r1, r1, lr
+            subs r12, r12, #1
+
+            vmlal.s16 q8, d0, d2[0]
+            vmlal.s16 q9, d0, d2[1]
+            bne loop_e2h4l1
+        loop_e2h4_end:
+            vld1.32 q0, [r7]!
+            vcvt.f32.s32 q8, q8
+            vcvt.f32.s32 q9, q9
+            vmul.f32 q8, q8, q0
+            vmul.f32 q9, q9, q0
+            vcgt.f32 q1, q8, #0
+            vcgt.f32 q2, q9, #0
+            vbsl.f32 q1, q4, q5
+            vbsl.f32 q2, q4, q5
+            vadd.f32 q8, q8, q1
+            vadd.f32 q9, q9, q2
+            vcvt.s32.f32 q8, q8
+            vcvt.s32.f32 q9, q9
+            vmin.s32 q8, q8, q6
+            vmin.s32 q9, q9, q6
+            vmax.s32 q8, q8, q7
+            vmax.s32 q9, q9, q7
+            vqmovn.s32 d0, q8
+            vqmovn.s32 d1, q9
+            vqmovn.s16 d0, q0
+            vst1.8 {d0}, [r0], r10
+            subs r11, r11, #1
+        bne loop_e2h4
+        pop {r0-r2, r10}
+        add r0, r0, #8
+        add r1, r1, #2
+    b loop_e2
+
+loop_e1:
+    cmp r10, #1
+    blt End
+    sub r10, r10, #1
+    ldr r5, [sp, #(push_registers_bytes + 4)]
+    ldr r6, [sp, #(push_registers_bytes + 8)]
+    ldr r7, [r4]
+    ldr r8, [r4, #4]
+    push {r0-r2, r10}
+    ldr r10, [r3, #20] // cStride
+    ldr lr, [r6], #4 // dataOffset
+    add r1, r1, lr
+    ldr r11, [r3, #16] // h
+    lsr r11, r11, #2 // hDiv4 (C4)
+    loop_e1h4:
+        vld1.32 q8, [r8]!
+        ldr r12, [r5], #4
+        cmp r12, #0
+        beq loop_e1h4_end
+        loop_e1h4l1:
+            vld1.32 d0[0], [r2]!
+            vld1.8 d2[0], [r1]
+            vmovl.s8 q0, d0
+            vmovl.s8 q1, d2
+            ldr lr, [r6], #4
+            add r1, r1, lr
+            subs r12, r12, #1
+
+            vmlal.s16 q8, d0, d2[0]
+            bne loop_e1h4l1
+        loop_e1h4_end:
+            vld1.32 q0, [r7]!
+            vcvt.f32.s32 q8, q8
+            vmul.f32 q8, q8, q0
+            vcgt.f32 q1, q8, #0
+            vbsl.f32 q1, q4, q5
+            vadd.f32 q8, q8, q1
+            vcvt.s32.f32 q8, q8
+            vmin.s32 q8, q8, q6
+            vmax.s32 q8, q8, q7
+            vqmovn.s32 d0, q8
+            vqmovn.s16 d0, q0
+            vst1.32 {d0[0]}, [r0], r10
+            subs r11, r11, #1
+        bne loop_e1h4
+        pop {r0-r2, r10}
+        add r0, r0, #4
+        add r1, r1, #1
+    b loop_e1
+
+End:
+vpop {q4-q7}
+pop {r4-r8, r10, r11, pc}
+
+#undef push_registers_bytes
+#undef sizeof_value
+#undef sizeof_value_lg2
+#undef sparse_blockoc
+
+#endif
+#endif
+
+

+ 15 - 11
source/backend/cpu/arm/arm64/MNNExpC8.S

@@ -12,28 +12,29 @@
 .text
 .align 5
 
-//void MNNExpC8(float* dest, const float* source, const float* parameters, size_t countC8)
+//void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8)
 asm_function MNNExpC8
 
-//x0: dest, x1:source, x2:parameters, x3:countC8
-
-ld1 {v0.4s, v1.4s}, [x2]
+//x0: dest, x1:source, x2: offset, x3:parameters, x4:countC8
+ldr w5, [x2, #0]
+ldr w6, [x2, #4]
+ld1 {v0.4s, v1.4s}, [x3]
 movi v2.4s, #23
 movi v3.4s, #87
 scvtf v3.4s, v3.4s
 fneg v4.4s, v3.4s
+dup v30.4s, w5
+dup v31.4s, w6
 
 Loop:
 
 ld1 {v16.4s, v17.4s}, [x1], #32
-
+fmul v16.4s, v16.4s, v30.4s
+fmul v17.4s, v17.4s, v30.4s
 fmin v16.4s, v16.4s, v3.4s
 fmin v17.4s, v17.4s, v3.4s
-fmax v16.4s, v16.4s, v4.4s
-fmax v17.4s, v17.4s, v4.4s
-
-fneg v18.4s, v16.4s
-fneg v19.4s, v17.4s
+fmax v18.4s, v16.4s, v4.4s
+fmax v19.4s, v17.4s, v4.4s
 
 fmul v16.4s, v18.4s, v0.s[1]
 fmul v17.4s, v19.4s, v0.s[1]
@@ -69,9 +70,12 @@ ushl v17.4s, v17.4s, v2.4s
 add v20.4s, v20.4s, v16.4s
 add v21.4s, v21.4s, v17.4s
 
+fadd v20.4s, v20.4s, v31.4s
+fadd v21.4s, v21.4s, v31.4s
+
 st1 {v20.4s, v21.4s}, [x0], #32
 
-subs x3, x3, #1
+subs x4, x4, #1
 bne Loop
 
 ret

+ 12 - 1
source/backend/cpu/arm/arm64/MNNPackC4ForMatMul_A.S

@@ -30,12 +30,16 @@
 asm_function MNNPackC4ForMatMul_A
 //void MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el)
 //Auto: x0: dest, x1:sourceGroup, x2: info, x3:el
+
+str x19, [sp, #(-16 * 1)]
+
 ldr w10, [x2, #0] // number
 mov x4, #0
 mov x11, #0
 mov x6, #0
 ldr w4, [x2, #4] // eReal
 ldr w11, [x2, #8] // eDest
+mov x19, x11      // eP
 ldr w6, [x2, #12] // xOffset
 // xOffset -> xOffset * 4 * sizeof(float)
 // eReal -> eReal * 4 * sizeof(float)
@@ -65,10 +69,13 @@ mul x8, x12, x8
 add x0, x0, x7
 add x0, x0, x8
 
+cmp w19, #16
+bne E12Body
+
 Body:
 
 cmp w2, #16
-bne E12Body // caution: w2 = 16 or w2 = 12 is not continutious with LoopE1  at 'Right' segment. should not use 'blt'
+bne Right
     cmp w5, #4
     blt LoopE16L3
     LoopE16L4:
@@ -151,6 +158,8 @@ bne E12Body // caution: w2 = 16 or w2 = 12 is not continutious with LoopE1  at '
 b End
 
 E12Body:
+cmp w19, #12
+bne Right
 cmp w2, #12
 bne Right
     cmp w5, #4
@@ -322,6 +331,8 @@ add x1, x1, #8
 
 bne LoopNumber
 
+ldr x19, [sp, #(-16 * 1)]
+
 ret
 
 #endif
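The added x19 spill is required because x19 is callee-saved under the AArch64 AAPCS; it now carries the pack width eP (info[2], i.e. eDest), so the 12-wide and 16-wide bodies are only entered when the destination is actually packed that wide, rather than being chosen purely from the remaining element count. A rough dispatch sketch (editorial, not part of the patch; names are illustrative):

    // eP is the pack width the kernel was built for (info[2]); e is the current block size.
    static const char* packBodyFor(int eP, int e) {
        if (eP == 16 && e == 16) return "Body";     // full 16-wide pack
        if (eP == 12 && e == 12) return "E12Body";  // full 12-wide pack
        return "Right";                             // per-element tail loop
    }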

+ 10 - 10
source/backend/cpu/arm/arm64/MNNPackedSparseMatMulEpx4.S

@@ -91,7 +91,7 @@ loop_e16:
         mul x20, x20, x12
         add x19, x26, x20 // x19: c = blockC + ihpack * cStride
         cbz x6, load_e16h4_zero
-            ldr q16, [x24], #(4 * sizeof_value)
+            ldr q16, [x24], #(4 * sparse_blockoc)
             b load_e16h4_end
         load_e16h4_zero:
             movi v16.4s, #0000000000000000
@@ -204,7 +204,7 @@ loop_e16:
         add x19, x26, x20, lsl #sizeof_value_lg2 // x19: c = blockC + isubIndex
 
         cbz x6, load_e16h1_zero
-            ld1r {v16.4s}, [x24], #(sizeof_value)
+            ld1r {v16.4s}, [x24], #(4)
             b load_e16h1_end
         load_e16h1_zero:
             movi v16.4s, #0000000000000000
@@ -376,7 +376,7 @@ beq loop_e4
         mul x20, x20, x12
         add x19, x26, x20 // x19: c = blockC + ihpack * cStride
         cbz x6, load_e8h4_zero
-            ldr q16, [x24], #(4 * sizeof_value)
+            ldr q16, [x24], #(4 * sparse_blockoc)
             b load_e8h4_end
         load_e8h4_zero:
             movi v16.4s, #0000000000000000
@@ -450,7 +450,7 @@ beq loop_e4
         add x19, x26, x20, lsl #sizeof_value_lg2 // x19: c = blockC + isubIndex
 
         cbz x6, load_e8h1_zero
-            ld1r {v16.4s}, [x24], #(sizeof_value)
+            ld1r {v16.4s}, [x24], #(4)
             b load_e8h1_end
         load_e8h1_zero:
             movi v16.4s, #0000000000000000
@@ -524,7 +524,7 @@ beq loop_e2
         mul x20, x20, x12
         add x19, x26, x20 // x19: c = blockC + ihpack * cStride
         cbz x6, load_e4h4_zero
-            ldr q16, [x24], #(4 * sizeof_value)
+            ldr q16, [x24], #(4 * sparse_blockoc)
             b load_e4h4_end
         load_e4h4_zero:
             movi v16.4s, #0000000000000000
@@ -578,7 +578,7 @@ beq loop_e2
         add x19, x26, x20, lsl #sizeof_value_lg2 // x20: c = blockC + isubIndex
 
         cbz x6, load_e4h1_zero
-            ld1r {v16.4s}, [x24], #(sizeof_value)
+            ld1r {v16.4s}, [x24], #(4)
             b load_e4h1_end
         load_e4h1_zero:
             movi v16.4s, #0000000000000000
@@ -638,7 +638,7 @@ beq loop_e1
         mul x20, x20, x12
         add x19, x26, x20 // x19: c = blockC + ihpack * cStride
         cbz x6, load_e2h4_zero
-            ldr q16, [x24], #(4 * sizeof_value)
+            ldr q16, [x24], #(4 * sparse_blockoc)
             b load_e2h4_end
         load_e2h4_zero:
             movi v16.4s, #0000000000000000
@@ -684,7 +684,7 @@ beq loop_e1
         add x19, x26, x20, lsl #sizeof_value_lg2 // x20: c = blockC + isubIndex
 
         cbz x6, load_e2h1_zero
-            ld1r {v16.2s}, [x24], #(sizeof_value)
+            ld1r {v16.2s}, [x24], #(sparse_blockoc)
             b load_e2h1_end
         load_e2h1_zero:
             movi v16.4s, #0000000000000000
@@ -743,7 +743,7 @@ beq loop_end
         mul x20, x20, x12
         add x19, x26, x20 // x19: c = blockC + ihpack * cStride
         cbz x6, load_e1h4_zero
-            ldr q16, [x24], #(4 * sizeof_value)
+            ldr q16, [x24], #(4 * sparse_blockoc)
             b load_e1h4_end
         load_e1h4_zero:
             movi v16.4s, #0000000000000000
@@ -785,7 +785,7 @@ beq loop_end
         add x19, x26, x20, lsl #sizeof_value_lg2 // x20: c = blockC + isubIndex
 
         cbz x6, load_e1h1_zero
-            ld1 {v16.s}[0], [x24], #(sizeof_value)
+            ld1 {v16.s}[0], [x24], #(4)
             b load_e1h1_end
         load_e1h1_zero:
             movi v16.4s, #0000000000000000

+ 520 - 0
source/backend/cpu/arm/arm64/MNNPackedSparseQuantMatMulEpx1.S

@@ -0,0 +1,520 @@
+//
+//  MNNPackedSparseQuantMatMulEpx1.S
+//  MNN
+//
+//  Created by MNN on 2021/06/20.
+//  Copyright © 2018-2021 Alibaba Group Holding Limited
+//
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+#define sizeof_value 1
+#define sizeof_value_lg2 0
+#define sparse_blockoc 4
+
+.text
+.align 5
+// 16 * 4 MatMul
+asm_function MNNPackedSparseQuantMatMulEpx1
+// void MNNPackedSparseQuantMatMulEpx1(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam,
+// const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap) {
+// x0: C, x1:A, x2:B, x3:sparseQuantParam, x4:QuanPostTreatParameters, x5:NNZMap, x6:dataOffsetMap
+
+stp x19, x20, [sp, #(-16 * 1)]
+stp x21, x22, [sp, #(-16 * 2)]
+stp x23, x24, [sp, #(-16 * 3)]
+stp x25, x26, [sp, #(-16 * 4)]
+stp x27, x28, [sp, #(-16 * 5)]
+
+stp d8,  d9,  [sp, #(-16 * 6)]
+stp d10, d11, [sp, #(-16 * 7)]
+stp d12, d13, [sp, #(-16 * 8)]
+str d14, [sp, #(-16 * 9)]
+
+ldp x13, x10, [x3, #16]     // x13: aStride, x10: l
+ldp x11, x12, [x3, #32]     // x11: h, x12: cStride
+ldp x3, x9, [x3]            // x3: eSize, x9: eP
+
+mov x8, x6                  // x8: dataOffsetMap
+mov x7, x5                  // x7: NNZMap
+ldp x24, x6, [x4], #16      // x24: scale, x6: bias
+lsr x14, x11, #2
+lsl x14, x14, #2            // x14:  (h / 4) * 4
+ld2r {v13.4s, v14.4s}, [x4] // the first two fields of the post parameters are pointers (consumed by the ldp above); maxValue and minValue sit at [2], [3]
+
+
+//x0:C,
+//x1:A,
+//x2:B,
+//x3:eSize,
+//x4:parameter,      // free
+//x5:postParameters, // free
+//x6:bias
+// x7, x15: unsigned int* NNZMap,
+// x8, x26: int* dataOffsetMap
+// x9: eP,
+// x10: l             // free
+// x11: h,
+// x12: cStride with sizeof
+// x13: aStride with sizeof
+// x14: (h / 4) * 4
+// x24: scale
+
+// v0-v3: A
+// v4:  B
+// v13: maxValue
+// v14: minValue
+// v16-v31: C
+// sparse_blockoc = 4
+
+
+// x4 as ie
+// x5 as ih
+// w20 as il
+
+mov x10, x2
+mov x4, xzr
+cmp x9, x3
+bgt loop_e8
+
+loop_e16:
+
+    mov x26, x8
+    ldrsw x27, [x26], #4
+    add x1, x1, x27, lsl #sizeof_value_lg2 // a += diff * sizeof(float)
+
+    mov x2, x10
+    mov x15, x7
+    add x27, x0, x4, lsl #(sizeof_value_lg2 + 2) // int8_t* blockC = C + (ie << 2);
+
+    mov x5, xzr
+    mov x28, x6 // bias
+    mov x25, x24 // scale
+    loop_e16h1:
+
+        lsr x21, x5, #2
+        and x20, x5, #0x03 // NC4HW4
+        mul x21, x21, x12
+        add x19, x27, x20, lsl #sizeof_value_lg2
+        add x19, x19, x21
+        cbz x6, load_e16h1_zero
+            ld1r {v16.4s}, [x28], #(4)
+            b load_e16h1_end
+        load_e16h1_zero:
+            movi v16.4s, #0000000000000000
+
+        load_e16h1_end:
+        ldr w20, [x15], #4
+        mov v17.16b, v16.16b
+        mov v18.16b, v16.16b
+        mov v19.16b, v16.16b
+        cbz w20, loop_e16h1l1_end
+
+        loop_e16h1l1:
+
+          ldr q0, [x1]
+          ld1r {v1.16b}, [x2], #(sizeof_value)
+          ldrsw x21, [x26], #4
+          subs w20, w20, #1
+          add x1, x1, x21, lsl #sizeof_value_lg2 // a += diff * sizeof(float)
+
+
+            smull v5.8h, v0.8b, v1.8b
+            smull2 v9.8h, v0.16b, v1.16b
+
+            saddw v16.4s, v16.4s, v5.4h
+            saddw v18.4s, v18.4s, v9.4h
+            saddw2 v17.4s, v17.4s, v5.8h
+            saddw2 v19.4s, v19.4s, v9.8h
+
+          bne loop_e16h1l1
+
+    loop_e16h1l1_end:
+
+    cbz x24, clamp_noscale_e16h1
+        // deal with scale
+        ldr s0, [x25], #(4)
+        scvtf v16.4s, v16.4s
+        scvtf v17.4s, v17.4s
+        scvtf v18.4s, v18.4s
+        scvtf v19.4s, v19.4s
+        fmul v16.4s, v16.4s, v0.s[0]
+        fmul v17.4s, v17.4s, v0.s[0]
+        fmul v18.4s, v18.4s, v0.s[0]
+        fmul v19.4s, v19.4s, v0.s[0]
+        fcvtas v16.4s, v16.4s
+        fcvtas v17.4s, v17.4s
+        fcvtas v18.4s, v18.4s
+        fcvtas v19.4s, v19.4s
+
+    clamp_noscale_e16h1:
+    smin v16.4s, v16.4s, v13.4s
+    smin v17.4s, v17.4s, v13.4s
+    smin v18.4s, v18.4s, v13.4s
+    smin v19.4s, v19.4s, v13.4s
+    add x5, x5, #1
+    smax v16.4s, v16.4s, v14.4s
+    smax v17.4s, v17.4s, v14.4s
+    smax v18.4s, v18.4s, v14.4s
+    smax v19.4s, v19.4s, v14.4s
+
+    sqxtn v0.4h, v16.4s
+    sqxtn2 v0.8h, v17.4s
+    sqxtn v1.4h, v18.4s
+    sqxtn2 v1.8h, v19.4s
+
+    sqxtn v16.8b, v0.8h
+    sqxtn2 v16.16b, v1.8h
+
+    mov x23, #(4 * 4 * sizeof_value)
+    add x20, x19, #(4 * sizeof_value)
+    add x21, x19, #(8 * sizeof_value)
+    add x22, x20, #(8 * sizeof_value)
+    cmp x5, x11
+
+    st1 {v16.b}[0], [x19], x23 // st1 does not support an immediate post-increment other than the size of the stored element
+    st1 {v16.b}[1], [x20], x23
+    st1 {v16.b}[2], [x21], x23
+    st1 {v16.b}[3], [x22], x23
+    st1 {v16.b}[4], [x19], x23
+    st1 {v16.b}[5], [x20], x23
+    st1 {v16.b}[6], [x21], x23
+    st1 {v16.b}[7], [x22], x23
+    st1 {v16.b}[8], [x19], x23
+    st1 {v16.b}[9], [x20], x23
+    st1 {v16.b}[10], [x21], x23
+    st1 {v16.b}[11], [x22], x23
+    st1 {v16.b}[12], [x19]
+    st1 {v16.b}[13], [x20]
+    st1 {v16.b}[14], [x21]
+    st1 {v16.b}[15], [x22]
+
+    blt loop_e16h1
+
+    loop_e16h_end:
+
+    add x4, x4, x9
+    add x1, x1, x13
+
+    add x5, x4, x9
+    cmp x5, x3
+    ble loop_e16
+
+loop_e8:
+ands x5, x3, #0x08
+beq loop_e4
+
+    mov x26, x8
+    ldrsw x27, [x26], #4
+    add x1, x1, x27, lsl #sizeof_value_lg2 // a += diff * sizeof(float)
+
+    mov x2, x10
+    mov x15, x7
+    add x27, x0, x4, lsl #(sizeof_value_lg2 + 2) // int8_t* blockC = C + (ie << 2);
+
+    mov x5, xzr
+    mov x28, x6 // bias
+    mov x25, x24 // scale
+
+    loop_e8h1:
+        lsr x21, x5, #2
+        and x20, x5, #0x03 // NC4HW4
+        mul x21, x21, x12
+        add x19, x27, x20, lsl #sizeof_value_lg2
+        add x19, x19, x21
+
+        cbz x6, load_e8h1_zero
+            ld1r {v16.4s}, [x28], #(4)
+            b load_e8h1_end
+        load_e8h1_zero:
+            movi v16.4s, #0000000000000000
+
+        load_e8h1_end:
+        ldr w20, [x15], #4
+        mov v17.16b, v16.16b
+        cbz w20, loop_e8h1l1_end
+
+        loop_e8h1l1:
+          ldr d0, [x1]
+          ld1r {v1.8b}, [x2], #(sizeof_value)
+          ldrsw x21, [x26], #4
+          subs w20, w20, #1
+          add x1, x1, x21, lsl #sizeof_value_lg2 // a += diff * sizeof(float)
+          smull v5.8h, v0.8b, v1.8b
+          saddw v16.4s, v16.4s, v5.4h
+          saddw2 v17.4s, v17.4s, v5.8h
+          bne loop_e8h1l1
+
+    loop_e8h1l1_end:
+    cbz x24, clamp_noscale_e8h1
+        // deal with scale
+        ldr s0, [x25], #(4)
+        scvtf v16.4s, v16.4s
+        scvtf v17.4s, v17.4s
+        fmul v16.4s, v16.4s, v0.s[0]
+        fmul v17.4s, v17.4s, v0.s[0]
+        fcvtas v16.4s, v16.4s
+        fcvtas v17.4s, v17.4s
+    clamp_noscale_e8h1:
+    smin v16.4s, v16.4s, v13.4s
+    smin v17.4s, v17.4s, v13.4s
+    add x5, x5, #1
+    smax v16.4s, v16.4s, v14.4s
+    smax v17.4s, v17.4s, v14.4s
+
+    sqxtn v0.4h, v16.4s
+    sqxtn2 v0.8h, v17.4s
+    sqxtn v16.8b, v0.8h
+
+    mov x23, #(4 * 4 * sizeof_value)
+    add x20, x19, #(4 * sizeof_value)
+    add x21, x19, #(8 * sizeof_value)
+    add x22, x20, #(8 * sizeof_value)
+
+    cmp x5, x11
+    st1 {v16.b}[0], [x19], x23 // st1 does not support an immediate post-increment other than the size of the stored element
+    st1 {v16.b}[1], [x20], x23
+    st1 {v16.b}[2], [x21], x23
+    st1 {v16.b}[3], [x22], x23
+    st1 {v16.b}[4], [x19]
+    st1 {v16.b}[5], [x20]
+    st1 {v16.b}[6], [x21]
+    st1 {v16.b}[7], [x22]
+    blt loop_e8h1
+
+    loop_e8h_end:
+
+    add x4, x4, #8 // e8
+    add x1, x1, #(8 * sizeof_value) // Has not exceed one aStride, just 8
+
+loop_e4:
+ands x5, x3, #0x04
+beq loop_e2
+
+    mov x26, x8
+    ldrsw x27, [x26], #4
+    add x1, x1, x27, lsl #sizeof_value_lg2 // a += diff * sizeof(float)
+
+    mov x2, x10
+    mov x15, x7
+    add x27, x0, x4, lsl #(sizeof_value_lg2 + 2) // int8_t* blockC = C + (ie << 2);
+    mov x5, xzr
+    mov x28, x6 // bias
+    mov x25, x24 // scale
+
+    loop_e4h1:
+        lsr x21, x5, #2
+        and x20, x5, #0x03 // NC4HW4
+        mul x21, x21, x12
+        add x19, x27, x20, lsl #sizeof_value_lg2
+        add x19, x19, x21
+
+        cbz x6, load_e4h1_zero
+            ld1r {v16.4s}, [x28], #(4)
+            b load_e4h1_end
+        load_e4h1_zero:
+            movi v16.4s, #0000000000000000
+
+        load_e4h1_end:
+        ldr w20, [x15], #4
+        cbz w20, loop_e4h1l1_end
+
+        loop_e4h1l1:
+
+          ldr s0, [x1]
+          ld1r {v1.8b}, [x2], #(sizeof_value) // try 4b
+          ldrsw x21, [x26], #4
+          subs w20, w20, #1
+          add x1, x1, x21, lsl #sizeof_value_lg2 // a += diff * sizeof(float)
+
+          smull v5.8h, v0.8b, v1.8b
+          saddw v16.4s, v16.4s, v5.4h
+          bne loop_e4h1l1
+
+    loop_e4h1l1_end:
+    cbz x24, clamp_noscale_e4h1
+        // deal with scale
+        ldr s0, [x25], #(4)
+        scvtf v16.4s, v16.4s
+        fmul v16.4s, v16.4s, v0.s[0]
+        fcvtas v16.4s, v16.4s
+    clamp_noscale_e4h1:
+    smin v16.4s, v16.4s, v13.4s
+    add x5, x5, #1
+    smax v16.4s, v16.4s, v14.4s
+
+    sqxtn v0.4h, v16.4s
+    sqxtn v16.8b, v0.8h // 4b is valid
+
+    add x20, x19, #(4 * sizeof_value)
+    add x21, x19, #(8 * sizeof_value)
+    add x22, x20, #(8 * sizeof_value)
+
+    cmp x5, x11
+    st1 {v16.b}[0], [x19] // st1 does not support an immediate post-increment other than the size of the stored element
+    st1 {v16.b}[1], [x20]
+    st1 {v16.b}[2], [x21]
+    st1 {v16.b}[3], [x22]
+    blt loop_e4h1
+
+    loop_e4h_end:
+
+    add x4, x4, #4 // e4
+    add x1, x1, #(4 * sizeof_value) // Has not exceed one aStride, just 4
+
+
+loop_e2:
+ands x5, x3, #0x02
+beq loop_e1
+
+    mov x26, x8
+    ldrsw x27, [x26], #4
+    add x1, x1, x27, lsl #sizeof_value_lg2 // a += diff * sizeof(float)
+
+    mov x2, x10
+    mov x15, x7
+    add x27, x0, x4, lsl #(sizeof_value_lg2 + 2) // int8_t* blockC = C + (ie << 2);
+    mov x5, xzr
+    mov x28, x6 // bias
+    mov x25, x24 // scale
+    cbz x14, loop_e2h1
+
+    loop_e2h1:
+        lsr x21, x5, #2
+        and x20, x5, #0x03 // NC4HW4
+        mul x21, x21, x12
+        add x19, x27, x20, lsl #sizeof_value_lg2
+        add x19, x19, x21
+
+        cbz x6, load_e2h1_zero
+            ld1r {v16.2s}, [x28], #(4)
+            b load_e2h1_end
+        load_e2h1_zero:
+            movi v16.4s, #0000000000000000
+        load_e2h1_end:
+        ldr w20, [x15], #4
+        cbz w20, loop_e2h1l1_end
+        loop_e2h1l1:
+
+          ld1 {v0.h}[0], [x1]
+          ld1r {v1.8b}, [x2], #(sizeof_value) // try 2b
+          ldrsw x21, [x26], #4
+          subs w20, w20, #1
+          add x1, x1, x21, lsl #sizeof_value_lg2 // a += diff * sizeof(float)
+          smull v5.8h, v0.8b, v1.8b // only 2b valid
+          saddw v16.4s, v16.4s, v5.4h
+          bne loop_e2h1l1
+
+    loop_e2h1l1_end:
+
+        cbz x24, clamp_noscale_e2h1
+        // deal with scale
+        ldr s0, [x25], #(4)
+        scvtf v16.2s, v16.2s
+        fmul v16.2s, v16.2s, v0.s[0]
+        fcvtas v16.2s, v16.2s
+    clamp_noscale_e2h1:
+    smin v16.2s, v16.2s, v13.2s
+    add x5, x5, #1
+    smax v16.2s, v16.2s, v14.2s
+    add x20, x19, #(4 * sizeof_value)
+    sqxtn v0.4h, v16.4s
+    sqxtn v16.8b, v0.8h // 2h -> 2b is valid
+    cmp x5, x11
+    st1 {v16.b}[0], [x19] // st1 does not support an immediate post-increment other than the size of the stored element
+    st1 {v16.b}[1], [x20]
+    blt loop_e2h1
+
+    loop_e2h_end:
+    add x4, x4, #2 // e2
+    add x1, x1, #(2 * sizeof_value) // Has not exceed one aStride, just 2
+
+
+loop_e1:
+ands x5, x3, #0x01
+beq loop_end
+
+    mov x26, x8
+    ldrsw x27, [x26], #4
+    add x1, x1, x27, lsl #sizeof_value_lg2 // a += diff * sizeof(float)
+
+    mov x2, x10
+    mov x15, x7
+    add x27, x0, x4, lsl #(sizeof_value_lg2 + 2) // int8_t* blockC = C + (ie << 2);
+
+    mov x5, xzr
+    mov x28, x6 // bias
+    mov x25, x24 // scale
+    loop_e1h1:
+        lsr x21, x5, #2
+        and x20, x5, #0x03 // NC4HW4
+        mul x21, x21, x12
+        add x19, x27, x20, lsl #sizeof_value_lg2
+        add x19, x19, x21
+
+        cbz x6, load_e1h1_zero
+            ld1 {v16.s}[0], [x28], #(4)
+            b load_e1h1_end
+        load_e1h1_zero:
+            movi v16.4s, #0000000000000000
+
+        load_e1h1_end:
+        ldr w20, [x15], #4
+
+        cbz w20, loop_e1h1l1_end
+
+        loop_e1h1l1:
+
+          ld1 {v0.b}[0], [x1]
+          ld1 {v1.b}[0], [x2], #(sizeof_value)
+          ldrsw x21, [x26], #4
+          subs w20, w20, #1
+          add x1, x1, x21, lsl #sizeof_value_lg2 // a += diff * sizeof(float)
+          smull v5.8h, v0.8b, v1.8b // only 1h valid
+          saddw v16.4s, v16.4s, v5.4h // only 1s is valid
+          bne loop_e1h1l1
+
+    loop_e1h1l1_end:
+
+    cbz x24, clamp_noscale_e1h1
+     // deal with scale
+      ldr s0, [x25], #(4)
+      scvtf s16, s16
+      fmul s16, s16, v0.s[0]
+      fcvtas s16, s16
+    clamp_noscale_e1h1:
+
+    smin v16.2s, v16.2s, v13.2s
+    add x5, x5, #1
+    smax v16.2s, v16.2s, v14.2s
+    sqxtn v0.4h, v16.4s
+    sqxtn v16.8b, v0.8h // 1b is valid
+    cmp x5, x11
+    st1 {v16.b}[0], [x19]
+    blt loop_e1h1
+
+    loop_e1h_end:
+    add x4, x4, #1 // e1
+
+loop_end:
+
+ldp x19, x20, [sp, #(-16 * 1)]
+ldp x21, x22, [sp, #(-16 * 2)]
+ldp x23, x24, [sp, #(-16 * 3)]
+ldp x25, x26, [sp, #(-16 * 4)]
+ldp x27, x28, [sp, #(-16 * 5)]
+ldp d8,  d9,  [sp, #(-16 * 6)]
+ldp d10, d11, [sp, #(-16 * 7)]
+ldp d12, d13, [sp, #(-16 * 8)]
+ldr d14, [sp, #(-16 * 9)]
+
+ret
+
+#undef sizeof_value
+#undef sizeof_value_lg2
+#undef sparse_blockoc
+
+
+#endif
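A scalar reference (editorial, not part of the patch) for what one output element of the Epx1 quant kernel computes. The dataOffsetMap-driven walk over A is simplified to a plain array here, and the real kernel vectorizes this over blocks of 16/8/4/2/1 input columns (the loop_e16 ... loop_e1 labels). Names are illustrative only.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    int8_t sparseQuantDotOne(const int8_t* a, const int8_t* bNonZero, unsigned int nnz,
                             const int32_t* bias, const float* scale,
                             int32_t maxValue, int32_t minValue) {
        int32_t acc = bias ? *bias : 0;                       // ld1r {v16.4s} / movi #0
        for (unsigned int k = 0; k < nnz; ++k) {              // nnz comes from NNZMap
            acc += (int32_t)a[k] * (int32_t)bNonZero[k];      // smull + saddw
        }
        if (scale) {                                          // optional requantize
            acc = (int32_t)std::lround((double)acc * *scale); // scvtf + fmul + fcvtas
        }
        acc = std::min(acc, maxValue);                        // smin against maxValue
        acc = std::max(acc, minValue);                        // smax against minValue
        return (int8_t)acc;                                   // sqxtn, stored in NC4HW4
    }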

File diff suppressed because it is too large
+ 1086 - 0
source/backend/cpu/arm/arm64/MNNPackedSparseQuantMatMulEpx4.S


+ 12 - 8
source/backend/cpu/arm/arm64/MNNSoftmax.S

@@ -14,6 +14,8 @@
 
 //void MNNSoftmax(float* dest, const float* source, size_t countC8)
 asm_function MNNSoftmax
+    stp x19, x20, [sp, #(-16 * 1)]
+    stp x21, x22, [sp, #(-16 * 2)]
     sxtw    x8, w2
     lsr     w9, w2, #2
     and     x8, x8, #-4
@@ -123,24 +125,24 @@ Loop_13:
     cmp     w8, w2
     fmov    s2, #1.0
     b.hs    Loop_16
-    lsl     x18, x8, #2
+    lsl     x21, x8, #2
     mov     w12, #29208
     mov     w14, #34953
     mov     w15, #43691
-    mov     w16, #43691
+    mov     w19, #43691
     mov     w10, #-1028784128
     mov     w11, #1118699520
     movk    w12, #16177, lsl #16
     mov     w13, #1065353216
     movk    w14, #15368, lsl #16
     movk    w15, #15658, lsl #16
-    movk    w16, #15914, lsl #16
-    add     x17, x1, x18
-    add     x18, x0, x18
+    movk    w19, #15914, lsl #16
+    add     x20, x1, x21
+    add     x21, x0, x21
     fmov    s3, #0.5
     mov     w1, w8
 Loop_15:
-    ldr     s4, [x17], #4
+    ldr     s4, [x20], #4
     fmov    s5, w10
     fmov    s6, w11
     fmov    s7, w12
@@ -156,7 +158,7 @@ Loop_15:
     fmov    s7, w15
     fmul    s5, s4, s5
     fadd    s5, s5, s7
-    fmov    s6, w16
+    fmov    s6, w19
     fmul    s5, s4, s5
     fadd    s5, s5, s6
     fmul    s5, s4, s5
@@ -170,7 +172,7 @@ Loop_15:
     add     w1, w1, #1
     fmul    s4, s4, s7
     cmp     w1, w2
-    str     s4, [x18], #4
+    str     s4, [x21], #4
     fadd    s1, s1, s4
     b.lo    Loop_15
 Loop_16:
@@ -196,6 +198,8 @@ Loop_19:
     str     s1, [x9], #4
     b.lo    Loop_21
 Loop_22:
+    ldp x19, x20, [sp, #(-16 * 1)]
+    ldp x21, x22, [sp, #(-16 * 2)]
     ret
 #endif
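Background for the register changes above (editorial): x18 is the platform register on several AArch64 ABIs (reserved on Apple and Windows targets) and x16/x17 are intra-procedure-call scratch registers, which is likely why the loop temporaries now live in x19-x21 instead; those registers are callee-saved, hence the stp/ldp pairs added at entry and exit.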
 

+ 1 - 0
source/backend/cpu/bf16/BF16Functions.cpp

@@ -734,6 +734,7 @@ bool BF16Functions::init() {
     gInstance->MNNAddC4WithStride = _MNNAddC4WithStride;
     gInstance->chooseWinoDestTransform = (decltype(gInstance->chooseWinoDestTransform))(WinogradFunctionHalf::chooseDestTransform);
     gInstance->chooseWinoSourceTransform = (decltype(gInstance->chooseWinoSourceTransform))(WinogradFunctionHalf::chooseSourceTransform);
+    gInstance->chooseWinoSourceTransformPack =  (decltype(gInstance->chooseWinoSourceTransformPack))(WinogradFunctionHalf::chooseWinoSourceTransformPack);
     gInstance->MNNDeconvRunForLineDepthwise = (decltype(gInstance->MNNDeconvRunForLineDepthwise))_MNNDeconvRunForLineDepthwise;
     gInstance->MNNDeconvRunForUnitDepthWise = (decltype(gInstance->MNNDeconvRunForUnitDepthWise))_MNNDeconvRunForUnitDepthWise;
     gInstance->MNNSelectBinaryFunctionForFloat = BF16BinaryFloatSelect;

+ 10 - 7
source/backend/cpu/bf16/BF16Unary.cpp

@@ -89,19 +89,22 @@ struct _Exp {
     void operator()(void* outRaw, const void* inpRaw, int realSize) const {
         auto out = (float*)outRaw;
         auto inp = (const float*)inpRaw;
-        MNNScaleAndAddBiasScalar(out, inp, 0.0f, -1.0f, realSize);
-        MNNExp(out, out, realSize);
+        float offset[2] = {
+            1.0f,
+            0.0f
+        };
+        MNNExp(out, inp, offset, realSize);
     }
 };
 struct _ExpM1 {
     void operator()(void* outRaw, const void* inpRaw, int realSize) const {
         auto out = (float*)outRaw;
         auto inp = (const float*)inpRaw;
-        MNNScaleAndAddBiasScalar(out, inp, 0.0f, -1.0f, realSize);
-        MNNExp(out, out, realSize);
-        for (int i=0; i<realSize; ++i) {
-            out[i] = out[i] - 1.0f;
-        }
+        float offset[2] = {
+            1.0f,
+            -1.0f
+        };
+        MNNExp(out, inp, offset, realSize);
     }
 };
 

+ 110 - 0
source/backend/cpu/bf16/VecHalf.hpp

@@ -115,6 +115,20 @@ struct VecHalf {
         }
         return dst;
     }
+    static inline void transpose4(VecType& vec0, VecType& vec1, VecType& vec2, VecType& vec3) {
+        VecType source[4] = {vec0, vec1, vec2, vec3};
+        for (int i = 0; i < N; ++i) {
+            vec0.value[i] = source[i % 4].value[i >> 2];
+            vec1.value[i] = source[i % 4].value[(i + N)>> 2];
+            vec2.value[i] = source[i % 4].value[(i + 2 * N)>> 2];
+            vec3.value[i] = source[i % 4].value[(i + 3 * N)>> 2];
+        }
+    }
+
+    static inline void transpose12(int16_t* srcPtr, const size_t packCUnit) {
+
+        MNN_ASSERT(false);
+    }
 };
 
 #if defined(MNN_USE_SSE)
@@ -228,6 +242,22 @@ struct VecHalf<4> {
         VecType dst = { _mm_min_ps(v1.value, v2.value) };
         return dst;
     }
+    static inline void transpose4(VecType& vec0, VecType& vec1, VecType& vec2, VecType& vec3) {
+        __m128 tmp3, tmp2, tmp1, tmp0;
+        tmp0   = _mm_unpacklo_ps((vec0.value), (vec1.value));
+        tmp2   = _mm_unpacklo_ps((vec2.value), (vec3.value));
+        tmp1   = _mm_unpackhi_ps((vec0.value), (vec1.value));
+        tmp3   = _mm_unpackhi_ps((vec2.value), (vec3.value));
+        vec0.value = _mm_movelh_ps(tmp0, tmp2);
+        vec1.value = _mm_movehl_ps(tmp2, tmp0);
+        vec2.value = _mm_movelh_ps(tmp1, tmp3);
+        vec3.value = _mm_movehl_ps(tmp3, tmp1);
+    }
+
+    // transpose12 is never used on the x86 VecHalf path
+    static inline void transpose12(int16_t* srcPtr, const size_t packCUnit) {
+        MNN_ASSERT(false);
+    }
 };
 #endif
 
@@ -311,6 +341,86 @@ struct VecHalf<4> {
         VecType dst = { vminq_f32(v1.value, v2.value) };
         return dst;
     }
+    static inline void transpose4(VecType& vec0, VecType& vec1, VecType& vec2, VecType& vec3) {
+#ifdef __aarch64__
+        auto m0 = vtrn1q_s32(vec0.value, vec1.value);
+        auto m1 = vtrn2q_s32(vec0.value, vec1.value);
+        auto m2 = vtrn1q_s32(vec2.value, vec3.value);
+        auto m3 = vtrn2q_s32(vec2.value, vec3.value);
+        vec0.value = vtrn1q_s64(m0, m2);
+        vec1.value = vtrn1q_s64(m1, m3);
+        vec2.value = vtrn2q_s64(m0, m2);
+        vec3.value = vtrn2q_s64(m1, m3);
+#else
+        auto m0m1 = vtrnq_s32(vec0.value, vec1.value);
+        auto m2m3 = vtrnq_s32(vec2.value, vec3.value);
+        vec0.value = m0m1.val[0];
+        vec1.value = m0m1.val[1];
+        vec2.value = m2m3.val[0];
+        vec3.value = m2m3.val[1];
+        vec0.value = vsetq_lane_s64(vgetq_lane_s64(m2m3.val[0], 0), vec0.value, 1);
+        vec1.value = vsetq_lane_s64(vgetq_lane_s64(m2m3.val[1], 0), vec1.value, 1);
+        vec2.value = vsetq_lane_s64(vgetq_lane_s64(m0m1.val[0], 1), vec2.value, 0);
+        vec3.value = vsetq_lane_s64(vgetq_lane_s64(m0m1.val[1], 1), vec3.value, 0);
+        /*
+        generated arm32 assembly code is almost the same as:
+            vtrn.32 d0, d2
+            vtrn.32 d1, d3
+            vtrn.32 d4, d6
+            vtrn.32 d5, d7
+            vswp d1, d4
+            vswp d3, d6
+        */
+
+#endif
+    }
+    static inline void transpose4(int16x4_t& vec0, int16x4_t& vec1, int16x4_t& vec2, int16x4_t& vec3) {
+        auto trans0 = vtrn_s16(vec0, vec1);
+        auto m0 = trans0.val[0];
+        auto m1 = trans0.val[1];
+        auto trans1 = vtrn_s16(vec2, vec3);
+        auto m2 = trans1.val[0];
+        auto m3 = trans1.val[1];
+        auto trans2 = vtrn_s32(m0, m2);
+        vec0 = trans2.val[0];
+        vec2 = trans2.val[1];
+        auto trans3 = vtrn_s32(m1, m3);
+        vec1 = trans3.val[0];
+        vec3 = trans3.val[1];
+
+    }
+    static inline void transpose12(int16_t* srcPtr, const size_t packCUnit) {
+        auto s0  = vld1_s16(srcPtr + 0 * packCUnit);
+        auto s3  = vld1_s16(srcPtr + 1 * packCUnit);
+        auto s6  = vld1_s16(srcPtr + 2 * packCUnit);
+        auto s9  = vld1_s16(srcPtr + 3 * packCUnit);
+        auto s1  = vld1_s16(srcPtr + 4 * packCUnit);
+        auto s4  = vld1_s16(srcPtr + 5 * packCUnit);
+        auto s7  = vld1_s16(srcPtr + 6 * packCUnit);
+        auto s10 = vld1_s16(srcPtr + 7 * packCUnit);
+        auto s2  = vld1_s16(srcPtr + 8 * packCUnit);
+        auto s5  = vld1_s16(srcPtr + 9 * packCUnit);
+        auto s8  = vld1_s16(srcPtr + 10 * packCUnit);
+        auto s11 = vld1_s16(srcPtr + 11 * packCUnit);
+
+        transpose4(s0, s3, s6, s9);
+        transpose4(s1, s4, s7, s10);
+        transpose4(s2, s5, s8, s11);
+
+        vst1_s16(srcPtr + 0 * packCUnit, s0);
+        vst1_s16(srcPtr + 1 * packCUnit, s1);
+        vst1_s16(srcPtr + 2 * packCUnit, s2);
+        vst1_s16(srcPtr + 3 * packCUnit, s3);
+        vst1_s16(srcPtr + 4 * packCUnit, s4);
+        vst1_s16(srcPtr + 5 * packCUnit, s5);
+        vst1_s16(srcPtr + 6 * packCUnit, s6);
+        vst1_s16(srcPtr + 7 * packCUnit, s7);
+        vst1_s16(srcPtr + 8 * packCUnit, s8);
+        vst1_s16(srcPtr + 9 * packCUnit, s9);
+        vst1_s16(srcPtr + 10 * packCUnit, s10);
+        vst1_s16(srcPtr + 11 * packCUnit, s11);
+
+    }
 };
 #endif
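A scalar sketch (editorial, not part of the patch) of what transpose12 does: the 12 x packCUnit int16 tile is transposed in place, so lane c of input row e ends up at position e of output row c. The NEON version above realises this with three vtrn-based 4x4 transposes; packCUnit is assumed to be 4 here, as in the callers.

    #include <cstdint>
    #include <cstring>
    static void transpose12Reference(int16_t* srcPtr, size_t packCUnit /* assumed 4 */) {
        constexpr int ePack = 12;
        int16_t tmp[ePack * 4];
        std::memcpy(tmp, srcPtr, ePack * packCUnit * sizeof(int16_t));
        for (size_t c = 0; c < packCUnit; ++c) {
            for (int e = 0; e < ePack; ++e) {
                srcPtr[c * ePack + e] = tmp[e * packCUnit + c];
            }
        }
    }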
 

+ 356 - 0
source/backend/cpu/bf16/WinogradOptFunctionHalf.cpp

@@ -9,11 +9,342 @@
 #include "WinogradOptFunctionHalf.hpp"
 #include <cstring>
 #include <memory>
+#include <map>
 #include "core/Macro.h"
 #include "VecHalf.hpp"
 using BFVec4 = MNN::Math::VecHalf<4>;
+using VecType = BFVec4;
+using ElementType = int16_t;
+
+// to be optimized into VecType::transpose12
+#define TRANSPOSE_12X4_SAVE()                             \
+    VecType s0  = VecType::load(srcPtr + 0 * packCUnit);  \
+    VecType s3  = VecType::load(srcPtr + 1 * packCUnit);  \
+    VecType s6  = VecType::load(srcPtr + 2 * packCUnit);  \
+    VecType s9  = VecType::load(srcPtr + 3 * packCUnit);  \
+    VecType s1  = VecType::load(srcPtr + 4 * packCUnit);  \
+    VecType s4  = VecType::load(srcPtr + 5 * packCUnit);  \
+    VecType s7  = VecType::load(srcPtr + 6 * packCUnit);  \
+    VecType s10 = VecType::load(srcPtr + 7 * packCUnit);  \
+    VecType s2  = VecType::load(srcPtr + 8 * packCUnit);  \
+    VecType s5  = VecType::load(srcPtr + 9 * packCUnit);  \
+    VecType s8  = VecType::load(srcPtr + 10 * packCUnit); \
+    VecType s11 = VecType::load(srcPtr + 11 * packCUnit); \
+    VecType::transpose4(s0, s3, s6, s9);                  \
+    VecType::transpose4(s1, s4, s7, s10);                 \
+    VecType::transpose4(s2, s5, s8, s11);                 \
+    VecType::save(srcPtr + 0 * packCUnit, s0);            \
+    VecType::save(srcPtr + 1 * packCUnit, s1);            \
+    VecType::save(srcPtr + 2 * packCUnit, s2);            \
+    VecType::save(srcPtr + 3 * packCUnit, s3);            \
+    VecType::save(srcPtr + 4 * packCUnit, s4);            \
+    VecType::save(srcPtr + 5 * packCUnit, s5);            \
+    VecType::save(srcPtr + 6 * packCUnit, s6);            \
+    VecType::save(srcPtr + 7 * packCUnit, s7);            \
+    VecType::save(srcPtr + 8 * packCUnit, s8);            \
+    VecType::save(srcPtr + 9 * packCUnit, s9);            \
+    VecType::save(srcPtr + 10 * packCUnit, s10);          \
+    VecType::save(srcPtr + 11 * packCUnit, s11);
 
 namespace MNN {
+
+
+static void _sourceTransformUnit4x4Pack12(ElementType* srcBlock, ElementType* dstStart, size_t dstStep) {
+    // register number: (srcUnit + 1) * EPack/packCUnit
+    constexpr int Nh = 4; // srcUnit
+    constexpr int ePack = 12;
+    constexpr size_t packCUnit = 4;
+    const size_t loadTransposeStride = packCUnit * ePack;
+    ElementType* srcPtr = srcBlock;
+    for (int iNh = 0; iNh < Nh; ++iNh)
+    {
+        // register number : ePack
+        // TRANSPOSE_12X4_SAVE();
+        VecType::transpose12(srcPtr, packCUnit);
+        srcPtr += loadTransposeStride;
+    }
+    srcPtr = srcBlock;
+    ElementType* dstPtr = dstStart;
+    for (int i4c = 0; i4c < packCUnit; ++i4c)
+    {
+        // source transform D * B. register number : srcUnit * (EPack/4 + 1)
+        VecType s00 = VecType::load(srcPtr + 0 * loadTransposeStride + 0 * packCUnit);
+        VecType s01 = VecType::load(srcPtr + 0 * loadTransposeStride + 1 * packCUnit);
+        VecType s02 = VecType::load(srcPtr + 0 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s10 = VecType::load(srcPtr + 1 * loadTransposeStride + 0 * packCUnit);
+        VecType s11 = VecType::load(srcPtr + 1 * loadTransposeStride + 1 * packCUnit);
+        VecType s12 = VecType::load(srcPtr + 1 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s20 = VecType::load(srcPtr + 2 * loadTransposeStride + 0 * packCUnit);
+        VecType s21 = VecType::load(srcPtr + 2 * loadTransposeStride + 1 * packCUnit);
+        VecType s22 = VecType::load(srcPtr + 2 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s30 = VecType::load(srcPtr + 3 * loadTransposeStride + 0 * packCUnit);
+        VecType s31 = VecType::load(srcPtr + 3 * loadTransposeStride + 1 * packCUnit);
+        VecType s32 = VecType::load(srcPtr + 3 * loadTransposeStride + 2 * packCUnit);
+
+        // dstStep =  ePack * pack * ic_4
+        auto ep0 = s00 - s20;
+        auto ep1 = s01 - s21;
+        auto ep2 = s02 - s22;
+        VecType::save(dstPtr + 0 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 0 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 0 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 + s20;
+        ep1 = s11 + s21;
+        ep2 = s12 + s22;
+        VecType::save(dstPtr + 1 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 1 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 1 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s20 - s10;
+        ep1 = s21 - s11;
+        ep2 = s22 - s12;
+        VecType::save(dstPtr + 2 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 2 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 2 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s30 - s10;
+        ep1 = s31 - s11;
+        ep2 = s32 - s12;
+        VecType::save(dstPtr + 3 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 3 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 3 * dstStep + 2 * packCUnit, ep2);
+
+        // VecType::save(dstPtr + 0 * dstStep + 0 * packCUnit, s00);
+        // VecType::save(dstPtr + 0 * dstStep + 1 * packCUnit, s01);
+        // VecType::save(dstPtr + 0 * dstStep + 2 * packCUnit, s02);
+
+        // VecType::save(dstPtr + 1 * dstStep + 0 * packCUnit, s10);
+        // VecType::save(dstPtr + 1 * dstStep + 1 * packCUnit, s11);
+        // VecType::save(dstPtr + 1 * dstStep + 2 * packCUnit, s12);
+
+        // VecType::save(dstPtr + 2 * dstStep + 0 * packCUnit, s20);
+        // VecType::save(dstPtr + 2 * dstStep + 1 * packCUnit, s21);
+        // VecType::save(dstPtr + 2 * dstStep + 2 * packCUnit, s22);
+
+        // VecType::save(dstPtr + 3 * dstStep + 0 * packCUnit, s30);
+        // VecType::save(dstPtr + 3 * dstStep + 1 * packCUnit, s31);
+        // VecType::save(dstPtr + 3 * dstStep + 2 * packCUnit, s32);
+
+        // MNN_PRINT("\nwinograd in BT*D*B, iNh:0-3, i4c:%d\n", i4c);
+        // formatMatrix(dstPtr + 0 * dstStep , {ePack});
+        // formatMatrix(dstPtr + 1 * dstStep , {ePack});
+        // formatMatrix(dstPtr + 2 * dstStep , {ePack});
+        // formatMatrix(dstPtr + 3 * dstStep , {ePack});
+
+        srcPtr += ePack;
+        dstPtr += ePack;
+    }
+}
+
+static void _sourceTransformUnit8x8Pack12(ElementType* srcBlock, ElementType* dstStart, size_t dstStep) {
+
+    // source transform D * B. register number : (srcUnit + 1) * EPack/packCUnit = 27
+    // todo: implement
+    constexpr int Nh = 8; // srcUnit
+    constexpr int ePack = 12;
+    constexpr size_t packCUnit = 4;
+    const size_t loadTransposeStride = packCUnit * ePack;
+    ElementType* srcPtr = srcBlock;
+    for (int iNh = 0; iNh < Nh; ++iNh)
+    {
+        // register number : ePack
+        VecType::transpose12(srcPtr, packCUnit);
+        srcPtr += loadTransposeStride;
+    }
+
+    srcPtr = srcBlock;
+    ElementType* dstPtr = dstStart;
+    for (int i4c = 0; i4c < packCUnit; ++i4c)
+    {
+        VecType s00 = VecType::load(srcPtr + 0 * loadTransposeStride + 0 * packCUnit);
+        VecType s01 = VecType::load(srcPtr + 0 * loadTransposeStride + 1 * packCUnit);
+        VecType s02 = VecType::load(srcPtr + 0 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s10 = VecType::load(srcPtr + 1 * loadTransposeStride + 0 * packCUnit);
+        VecType s11 = VecType::load(srcPtr + 1 * loadTransposeStride + 1 * packCUnit);
+        VecType s12 = VecType::load(srcPtr + 1 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s20 = VecType::load(srcPtr + 2 * loadTransposeStride + 0 * packCUnit);
+        VecType s21 = VecType::load(srcPtr + 2 * loadTransposeStride + 1 * packCUnit);
+        VecType s22 = VecType::load(srcPtr + 2 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s30 = VecType::load(srcPtr + 3 * loadTransposeStride + 0 * packCUnit);
+        VecType s31 = VecType::load(srcPtr + 3 * loadTransposeStride + 1 * packCUnit);
+        VecType s32 = VecType::load(srcPtr + 3 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s40 = VecType::load(srcPtr + 4 * loadTransposeStride + 0 * packCUnit);
+        VecType s41 = VecType::load(srcPtr + 4 * loadTransposeStride + 1 * packCUnit);
+        VecType s42 = VecType::load(srcPtr + 4 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s50 = VecType::load(srcPtr + 5 * loadTransposeStride + 0 * packCUnit);
+        VecType s51 = VecType::load(srcPtr + 5 * loadTransposeStride + 1 * packCUnit);
+        VecType s52 = VecType::load(srcPtr + 5 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s60 = VecType::load(srcPtr + 6 * loadTransposeStride + 0 * packCUnit);
+        VecType s61 = VecType::load(srcPtr + 6 * loadTransposeStride + 1 * packCUnit);
+        VecType s62 = VecType::load(srcPtr + 6 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s70 = VecType::load(srcPtr + 7 * loadTransposeStride + 0 * packCUnit);
+        VecType s71 = VecType::load(srcPtr + 7 * loadTransposeStride + 1 * packCUnit);
+        VecType s72 = VecType::load(srcPtr + 7 * loadTransposeStride + 2 * packCUnit);
+
+
+        // to-try: reorder the complicated compute of 8x8
+        auto ep0 = s00 * 36.f - s20 * 49.f + s40 * 14.f - s60;
+        auto ep1 = s01 * 36.f - s21 * 49.f + s41 * 14.f - s61;
+        auto ep2 = s02 * 36.f - s22 * 49.f + s42 * 14.f - s62;
+        VecType::save(dstPtr + 0 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 0 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 0 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = (s10 + s20) * 36.f - (s30 + s40) * 13.f + (s50 + s60);
+        ep1 = (s11 + s21) * 36.f - (s31 + s41) * 13.f + (s51 + s61);
+        ep2 = (s12 + s22) * 36.f - (s32 + s42) * 13.f + (s52 + s62);
+        VecType::save(dstPtr + 1 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 1 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 1 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = (s20 - s10) * 36.f + (s30 - s40) * 13.f + (s60 - s50);
+        ep1 = (s21 - s11) * 36.f + (s31 - s41) * 13.f + (s61 - s51);
+        ep2 = (s22 - s12) * 36.f + (s32 - s42) * 13.f + (s62 - s52);
+        VecType::save(dstPtr + 2 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 2 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 2 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * 18.f + s20 * 9.f - s30 * 20.f - s40 * 10.f + s50 * 2.f + s60;
+        ep1 = s11 * 18.f + s21 * 9.f - s31 * 20.f - s41 * 10.f + s51 * 2.f + s61;
+        ep2 = s12 * 18.f + s22 * 9.f - s32 * 20.f - s42 * 10.f + s52 * 2.f + s62;
+        VecType::save(dstPtr + 3 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 3 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 3 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s20 * 9.f - s10 * 18.f + s30 * 20.f - s40 * 10.f - s50 * 2.f + s60;
+        ep1 = s21 * 9.f - s11 * 18.f + s31 * 20.f - s41 * 10.f - s51 * 2.f + s61;
+        ep2 = s22 * 9.f - s12 * 18.f + s32 * 20.f - s42 * 10.f - s52 * 2.f + s62;
+        VecType::save(dstPtr + 4 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 4 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 4 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * 12.f + s20 * 4.f - s30 * 15.f - s40 * 5.f + s50 * 3.f + s60;
+        ep1 = s11 * 12.f + s21 * 4.f - s31 * 15.f - s41 * 5.f + s51 * 3.f + s61;
+        ep2 = s12 * 12.f + s22 * 4.f - s32 * 15.f - s42 * 5.f + s52 * 3.f + s62;
+        VecType::save(dstPtr + 5 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 5 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 5 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s20 * 4.f - s10 * 12.f + s30 * 15.f - s40 * 5.f - s50 * 3.f + s60;
+        ep1 = s21 * 4.f - s11 * 12.f + s31 * 15.f - s41 * 5.f - s51 * 3.f + s61;
+        ep2 = s22 * 4.f - s12 * 12.f + s32 * 15.f - s42 * 5.f - s52 * 3.f + s62;
+        VecType::save(dstPtr + 6 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 6 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 6 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s30 * 49.f - s10 * 36.f - s50 * 14.f + s70;
+        ep1 = s31 * 49.f - s11 * 36.f - s51 * 14.f + s71;
+        ep2 = s32 * 49.f - s12 * 36.f - s52 * 14.f + s72;
+        VecType::save(dstPtr + 7 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 7 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 7 * dstStep + 2 * packCUnit, ep2);
+        srcPtr += ePack;
+        dstPtr += ePack;
+    }
+}
+
+static void _sourceTransformUnit6x6Pack12(ElementType* srcBlock, ElementType* dstStart, size_t dstStep) {
+
+    // source transform D * B. register number : (srcUnit + 1) * EPack/packCUnit
+    constexpr int Nh = 6; // srcUnit
+    constexpr int ePack = 12;
+    constexpr size_t packCUnit = 4;
+    const size_t loadTransposeStride = packCUnit * ePack;
+    ElementType* srcPtr = srcBlock;
+    for (int iNh = 0; iNh < Nh; ++iNh)
+    {
+        // register number : ePack
+        VecType::transpose12(srcPtr, packCUnit);
+        srcPtr += loadTransposeStride;
+    }
+
+    srcPtr = srcBlock;
+    ElementType* dstPtr = dstStart;
+    for (int i4c = 0; i4c < packCUnit; ++i4c)
+    {
+        VecType s00 = VecType::load(srcPtr + 0 * loadTransposeStride + 0 * packCUnit);
+        VecType s01 = VecType::load(srcPtr + 0 * loadTransposeStride + 1 * packCUnit);
+        VecType s02 = VecType::load(srcPtr + 0 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s10 = VecType::load(srcPtr + 1 * loadTransposeStride + 0 * packCUnit);
+        VecType s11 = VecType::load(srcPtr + 1 * loadTransposeStride + 1 * packCUnit);
+        VecType s12 = VecType::load(srcPtr + 1 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s20 = VecType::load(srcPtr + 2 * loadTransposeStride + 0 * packCUnit);
+        VecType s21 = VecType::load(srcPtr + 2 * loadTransposeStride + 1 * packCUnit);
+        VecType s22 = VecType::load(srcPtr + 2 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s30 = VecType::load(srcPtr + 3 * loadTransposeStride + 0 * packCUnit);
+        VecType s31 = VecType::load(srcPtr + 3 * loadTransposeStride + 1 * packCUnit);
+        VecType s32 = VecType::load(srcPtr + 3 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s40 = VecType::load(srcPtr + 4 * loadTransposeStride + 0 * packCUnit);
+        VecType s41 = VecType::load(srcPtr + 4 * loadTransposeStride + 1 * packCUnit);
+        VecType s42 = VecType::load(srcPtr + 4 * loadTransposeStride + 2 * packCUnit);
+
+        VecType s50 = VecType::load(srcPtr + 5 * loadTransposeStride + 0 * packCUnit);
+        VecType s51 = VecType::load(srcPtr + 5 * loadTransposeStride + 1 * packCUnit);
+        VecType s52 = VecType::load(srcPtr + 5 * loadTransposeStride + 2 * packCUnit);
+
+        // to-try: reorder
+        auto ep0 = s00 * 4.f - s20 * 5.f + s40;
+        auto ep1 = s01 * 4.f - s21 * 5.f + s41;
+        auto ep2 = s02 * 4.f - s22 * 5.f + s42;
+        VecType::save(dstPtr + 0 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 0 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 0 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = (s10 + s20) * (-4.f) + s30 + s40;
+        ep1 = (s11 + s21) * (-4.f) + s31 + s41;
+        ep2 = (s12 + s22) * (-4.f) + s32 + s42;
+        VecType::save(dstPtr + 1 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 1 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 1 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = (s10 - s20) * (4.f) + s40 - s30;
+        ep1 = (s11 - s21) * (4.f) + s41 - s31;
+        ep2 = (s12 - s22) * (4.f) + s42 - s32;
+        VecType::save(dstPtr + 2 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 2 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 2 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * (-2.f) - s20 + s30 * 2.f + s40;
+        ep1 = s11 * (-2.f) - s21 + s31 * 2.f + s41;
+        ep2 = s12 * (-2.f) - s22 + s32 * 2.f + s42;
+        VecType::save(dstPtr + 3 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 3 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 3 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * 2.f - s20 - s30 * 2.f + s40;
+        ep1 = s11 * 2.f - s21 - s31 * 2.f + s41;
+        ep2 = s12 * 2.f - s22 - s32 * 2.f + s42;
+        VecType::save(dstPtr + 4 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 4 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 4 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * 4.f - s30 * 5.f + s50;
+        ep1 = s11 * 4.f - s31 * 5.f + s51;
+        ep2 = s12 * 4.f - s32 * 5.f + s52;
+        VecType::save(dstPtr + 5 * dstStep + 0 * packCUnit, ep0);
+        VecType::save(dstPtr + 5 * dstStep + 1 * packCUnit, ep1);
+        VecType::save(dstPtr + 5 * dstStep + 2 * packCUnit, ep2);
+
+        srcPtr += ePack;
+        dstPtr += ePack;
+    }
+}
+
 static void _sourceTransformUnit4x4(const int16_t* srcBlock, int16_t* dstStart, size_t srcStep, size_t dstStep) {
     BFVec4 s0 = BFVec4::load(srcBlock + 0 * srcStep);
     BFVec4 s1 = BFVec4::load(srcBlock + 1 * srcStep);
@@ -30,6 +361,7 @@ static void _sourceTransformUnit4x4(const int16_t* srcBlock, int16_t* dstStart,
     BFVec4::save(dstStart + 2 * dstStep, m2);
     BFVec4::save(dstStart + 3 * dstStep, m3);
 }
+
 static void _destTransformUnit4x2(const int16_t* srcBlock, int16_t* dstStart, size_t srcStep, size_t dstStep) {
     BFVec4 s0 = BFVec4::load(srcBlock + 0 * srcStep);
     BFVec4 s1 = BFVec4::load(srcBlock + 1 * srcStep);
@@ -168,6 +500,26 @@ static WinogradFunctionHalf::TransformFunc gProcUnit6[] = {
     _destTransformUnit6x5,
 };
 
+WinogradFunctionHalf::TransformPackFunc WinogradFunctionHalf::chooseWinoSourceTransformPack(int k, int w, int ePack, int lPack, int packCUnit) {
+
+    if (ePack == 12 && lPack == 1 && packCUnit == 4) {
+        if (k == 4 && w == 4) {
+            return _sourceTransformUnit4x4Pack12;
+        }
+        if (k == 6 && w == 6) {
+            return _sourceTransformUnit6x6Pack12;
+        }
+        if (k == 8 && w == 8) {
+            return _sourceTransformUnit8x8Pack12;
+        }
+        // other packing size
+    }
+    // if (ePack == 3 && lPack == 8 && packCUnit == 4): no transformPack is needed for the x86 bf16 pack format of 3 x 8 x 4; ConvolutionWinograd.cpp does not call this path when allow_x86_bf16_winograd is in effect
+    MNN_ERROR("WinogradFunctionHalf Can not find function for ePack:%d, packCUnit:%d\n", ePack, packCUnit);
+    MNN_ASSERT(false);
+    return nullptr;
+}
+
 
 WinogradFunctionHalf::TransformFunc WinogradFunctionHalf::chooseSourceTransform(int k, int w) {
     if (6 == k && 6 == w) {
@@ -197,3 +549,7 @@ WinogradFunctionHalf::TransformFunc WinogradFunctionHalf::chooseDestTransform(in
 }
 
 } // namespace MNN
+
+#undef TRANSPOSE_12X4_SAVE
+
+
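For reference (editorial, not part of the patch): in the 4x4 Pack12 routine the four per-column output expressions (s0 - s2, s1 + s2, s2 - s1, s3 - s1) are the rows of the Winograd F(2x2, 3x3) input-transform matrix, up to the usual sign convention on the last row:

$$
B^{T} =
\begin{pmatrix}
1 & 0 & -1 & 0 \\
0 & 1 & 1 & 0 \\
0 & -1 & 1 & 0 \\
0 & -1 & 0 & 1
\end{pmatrix}
$$

The 6x6 and 8x8 routines apply the larger input-transform matrices of their units in the same per-column fashion.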

+ 2 - 1
source/backend/cpu/bf16/WinogradOptFunctionHalf.hpp

@@ -16,10 +16,11 @@ namespace MNN {
 class WinogradFunctionHalf {
 public:
     typedef void (*TransformFunc)(const int16_t* srcBlock, int16_t* dstStart, size_t srcStep, size_t dstStep);
-
+    typedef void (*TransformPackFunc)(int16_t* srcBlock, int16_t* dstStart, size_t dstStep);
     /*Use the generator with interp 0.5*/
     static TransformFunc chooseSourceTransform(int k, int w);
     static TransformFunc chooseDestTransform(int k, int h);
+    static TransformPackFunc chooseWinoSourceTransformPack(int k, int h, int ePack, int lPack, int packCUnit);
 };
 } // namespace MNN
 

+ 53 - 25
source/backend/cpu/compute/CommonOptFunction.cpp

@@ -9,6 +9,7 @@
 #include "CommonOptFunction.h"
 #include "ConvOpt.h"
 #include "WinogradOptFunction.hpp"
+#include "Int8FunctionsOpt.h"
 #include <string.h>
 #include <algorithm>
 #include <cmath>
@@ -16,8 +17,8 @@
 #include "math/Vec.hpp"
 #include <vector>
 #include "../CPURuntime.hpp"
-#include "core/MemoryFormater.h"
-#include "core/OpCommonUtils.hpp"
+#include "common/MemoryFormater.h"
+#include "common/CommonCompute.hpp"
 // TODO: Find better way to optimize it
 #include "../CPUBinary.hpp"
 #include "../CPUUnary.hpp"
@@ -106,8 +107,8 @@ void MNNUnpackC4Common(T* dst, const T* src, size_t area, size_t depth, int* are
     transpose: if false, export the compressed matrix as h x l, otherwise export it as l x h.
  */
 void MNNPackForSparseMatMul_B(float* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const float* source, size_t h, size_t l, const int eP, bool transpose) {
-    // 1. in convolution, source B layout is OC x (IC * KH * KW),
-    //    the dest layout of weight is BCSC(block compressed sparse colum) format, which is OC(!=0) x (IC*KH*KW!=0), as a canceled result, just do BCSR, transpose should be false.
+    // 1. in convolution, source B layout is OC x (KH * KW * IC),
+    //    the dest layout of the weight is BCSC (block compressed sparse column) format, which is OC(!=0) x (KH*KW*IC!=0); in effect this reduces to BCSR, so transpose should be false.
     // 2. in ordinary sparse MatMul, transpose is corresponding to BCSR or BCSC
 
     // BCSR
@@ -116,7 +117,7 @@ void MNNPackForSparseMatMul_B(float* dest, unsigned int* NNZMap, int* dataOffset
         for (int i = 0; i < l; i += 1) {
             *NNZMap = 0;
             for(int j = 0; j < h; j += sparseBlockOC) {
-                if(!MNN::OpCommonUtils::checkAllZeros(source + j * l + i, l, sparseBlockOC, 1)) {
+                if(!MNN::CommonCompute::checkAllZeros(source + j * l + i, l, sparseBlockOC, 1)) {
                     *dest = *(source + j * l + l);
                     dest++;
                     *NNZMap = *NNZMap + 1;
@@ -135,7 +136,7 @@ void MNNPackForSparseMatMul_B(float* dest, unsigned int* NNZMap, int* dataOffset
         for (; i + sparseBlockOC <= h; i += sparseBlockOC) {
             *NNZMap = 0;
             for(int j = 0; j < l; j += 1) {
-                if (!MNN::OpCommonUtils::checkAllZeros(source, l, sparseBlockOC, 1)) {
+                if (!MNN::CommonCompute::checkAllZeros(source, l, sparseBlockOC, 1)) {
                     for (int ioc = 0; ioc < sparseBlockOC; ioc++) {
                         *dest = *(source + ioc * l);
                         dest++;
@@ -1364,12 +1365,12 @@ void MNNUnpackC4(float* dst, const float* src, size_t area, size_t depth, int* a
     MNNUnpackC4Common<float>(dst, src, area, depth, areaOffset);
 }
 
-void MNNExpC8(float* dest, const float* source, const float* parameters, size_t countC8) {
+void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8) {
     auto count = countC8 * 8;
     auto param = parameters[0];
     float xLimit = 87;
     for (int i = 0; i < count; ++i) {
-        auto x         = -source[i];
+        auto x         = source[i] * offset[0];
         x = ALIMAX(x, -xLimit);
         x = ALIMIN(x, xLimit);
         int div        = (x * parameters[1]);
@@ -1380,7 +1381,7 @@ void MNNExpC8(float* dest, const float* source, const float* parameters, size_t
         auto expRemain =
             ((((parameters[7] * t + parameters[6]) * t + parameters[5]) * t + parameters[4]) * t + parameters[3]) * t +
             parameters[2];
-        dest[i] = expBasic * expRemain;
+        dest[i] = expBasic * expRemain + offset[1];
     }
 }
 
@@ -1859,23 +1860,25 @@ void MNNPackTranspose(float* dst, const float* src, size_t area, size_t depth, i
     }
 }
 
-void MNNExp(float* dst, const float* src, size_t dataSize) {
+void MNNExp(float* dst, const float* src, const float* offset, size_t dataSize) {
     int countC8        = (int)dataSize / 8;
     if (countC8 > 0) {
         // Align to eight so asm is easier to write
-        static float parameters[] = {
-            (float)log(2.0f), 1.0f / (float)log(2.0f), 1.0f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f};
-        MNNExpC8(dst, src, parameters, countC8);
+        float parameters[] = {
+            (float)logf(2.0f), 1.0f / (float)logf(2.0f), 1.0f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f};
+        MNNExpC8(dst, src, offset, parameters, countC8);
     }
+    float alpha = offset[0];
+    float beta = offset[1];
     int remain = countC8 * 8;
-    auto param = log(2.0f);
+    auto param = logf(2.0f);
     float xLimit = 87;
     for (int i = remain; i < dataSize; i++) {
         /*Origin Function*/
-        //dst[i] = expf(-src[i]);
+        //dst[i] = expf(src[i] * alpha) + beta;
         /*Approximate Function*/
 
-        auto x         = -src[i];
+        auto x         = alpha * src[i];
         x = ALIMAX(x, -xLimit);
         x = ALIMIN(x, xLimit);
 
@@ -1886,7 +1889,7 @@ void MNNExp(float* dst, const float* src, size_t dataSize) {
 
         auto t         = xReamin;
         auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f;
-        dst[i]  = expBasic * expRemain;
+        dst[i]  = expBasic * expRemain + beta;
     }
 }
 
@@ -1912,10 +1915,11 @@ void MNNTanh(float* dst, const float* src, size_t dataSize) {
         dst[i] = tanhf_poly(src[i]);
     }
      */
-    for (int i = 0; i < dataSize; ++i) {
-        dst[i] = src[i] + src[i];
-    }
-    MNNExp(dst, dst, dataSize);
+    float offset[2] = {
+        -2.0f,
+        0.0f
+    };
+    MNNExp(dst, src, offset, dataSize);
     for (int i = 0; i < dataSize; i++) {
         // outputData[i] = 1 - 2 / (expf(2 * inputData[i]) + 1);
         auto expX2 = dst[i];
@@ -1979,6 +1983,12 @@ void MNNHardSwishCommon(float* dst, const float* src, size_t size) {
     }
 }
 
+void MNNGeluStandardCommon(float* dst, const float* src, size_t size) {
+    for (int i = 0; i < size; i++) {
+        dst[i] = (erf(src[i] * 0.7071067932881648) + 1) * src[i] * 0.5;
+    }
+}
+
 void MNNGeluCommon(float* dst, const float* src, size_t size) {
     int sizeQuad = size / 8;
     int start = 0;
@@ -2335,7 +2345,11 @@ void MNNSin(float* dst, const float* src, size_t dataSize) {
 }
 
 void MNNSigmoid(float* dst, const float* src, size_t dataSize) {
-    MNNExp(dst, src, dataSize);
+    float offset[2] = {
+       -1.0f,
+        0.0f
+    };
+    MNNExp(dst, src, offset, dataSize);
     for (int i = 0; i < dataSize; ++i) {
         dst[i] = 1.0f / (1.0f + dst[i]);
     }
@@ -2346,7 +2360,11 @@ void MNNSigmoid(float* dst, const float* src, size_t dataSize) {
  Thanks for https://github.com/hroken
  */
 void MNNSigmoidLowp(float* dst, const float* src, size_t dataSize) {
-    MNNExp(dst, src, dataSize);
+    float offset[2] = {
+       -1.0f,
+        0.0f
+    };
+    MNNExp(dst, src, offset, dataSize);
 #ifdef MNN_USE_NEON
     int dataC4 = (int)dataSize / 4;
     if(dataC4 > 0) {
@@ -2595,12 +2613,20 @@ void MNNCoreFunctionInit() {
 
     // Packed Function
     gCoreFunction->pack = 4;
+    // FIXME: MNNPackTranspose and MNNUnpackTranspose are swapped here
     gCoreFunction->MNNPackCUnit = MNNPackC4;
     gCoreFunction->MNNUnpackCUnit = MNNUnpackC4;
-
-    // FIXME: MNNPackTranspose and MNNUnpackTranspose is reverted
     gCoreFunction->MNNUnpackCUnitTranspose = MNNPackTranspose;
     gCoreFunction->MNNPackCUnitTranspose = MNNUnpackTranspose;
+    gCoreFunction->MNNPackCUnitInt8 = decltype(gCoreFunction->MNNPackCUnitInt8)(MNNPackC4Uint8);
+    gCoreFunction->MNNUnpackCUnitInt8 = decltype(gCoreFunction->MNNUnpackCUnitInt8)(MNNUnpackC4Uint8);
+    gCoreFunction->MNNPackCUnitTransposeInt8 = decltype(gCoreFunction->MNNPackCUnitTransposeInt8)(MNNUnpackTransposeUint8);
+    gCoreFunction->MNNUnpackCUnitTransposeInt8 = decltype(gCoreFunction->MNNUnpackCUnitTransposeInt8)(MNNPackTransposeUint8);
+    gCoreFunction->MNNPackCUnitInt16 = MNNPackC4Int16;
+    gCoreFunction->MNNUnpackCUnitInt16 = MNNUnpackC4Int16;
+    gCoreFunction->MNNPackCUnitTransposeInt16 = MNNUnpackTransposeInt16;
+    gCoreFunction->MNNUnpackCUnitTransposeInt16 = MNNPackTransposeInt16;
+
     gCoreFunction->MNNAxByClampBroadcastUnit = MNNAxByClampBroadcastUnit;
     gCoreFunction->MNNConvRunForLineDepthwise = MNNConvRunForLineDepthwise;
     gCoreFunction->MNNConvRunForUnitDepthWise = MNNConvRunForUnitDepthWise;
@@ -2618,6 +2644,7 @@ void MNNCoreFunctionInit() {
     gCoreFunction->MNNCopyC4WithStride = MNNCopyC4WithStride;
 
     gCoreFunction->chooseWinoSourceTransform = WinogradFunction::chooseSourceTransform;
+    gCoreFunction->chooseWinoSourceTransformPack = WinogradFunction::chooseWinoSourceTransformPack;
     gCoreFunction->chooseWinoDestTransform = WinogradFunction::chooseDestTransform;
     gCoreFunction->MNNDeconvRunForLineDepthwise = MNNDeconvRunForLineDepthwise;
     gCoreFunction->MNNDeconvRunForUnitDepthWise = MNNDeconvRunForUnitDepthWise;
@@ -2634,6 +2661,7 @@ void MNNCoreFunctionInit() {
     gCoreFunction->supportFp16arith = gCPUInfo.fp16arith;
     gCoreFunction->supportSDot = gCPUInfo.dot;
 #endif
+    MNNCoreInt8FunctionInit();
     MNNFunctionInit();
 }
 CoreFunctions* MNNGetCoreFunctions() {
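The offsets used by the MNNExp callers touched in this commit (editorial summary), under the contract dst[i] = exp(src[i] * offset[0]) + offset[1] sketched after the MNNExpC8.S hunk above:

    // exp(x)     : offset = { 1.0f,  0.0f }   (BF16Unary _Exp)
    // expm1(x)   : offset = { 1.0f, -1.0f }   (BF16Unary _ExpM1)
    // sigmoid(x) : offset = {-1.0f,  0.0f },  then dst[i] = 1 / (1 + dst[i])
    // tanh(x)    : offset = {-2.0f,  0.0f },  then dst[i] = (1 - dst[i]) / (1 + dst[i])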

+ 28 - 3
source/backend/cpu/compute/CommonOptFunction.h

@@ -63,10 +63,10 @@ void MNNUInt8ToInt16WithOffsetC4Fast(int16_t* dst, const uint8_t* src, size_t ze
                                      size_t depthQuad, size_t dstZStep, size_t srcZStep);
 void MNNMaxFloat(float* input, float* maxBuffer, int32_t inputCountUnit);
 void MNNMinFloat(float* input, float* maxBuffer, int32_t inputCountUnit);
-void MNNExpC8(float* dest, const float* source, const float* parameters, size_t countC8);
 void MNNPowC8(float* dest, const float* source, const float* powfParam, size_t betaInt, size_t countC8);
 
-void MNNExp(float* dst, const float* src, size_t dataSize);
+void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8);
+void MNNExp(float* dst, const float* src, const float* offset, size_t dataSize);
 void MNNSin(float* dst, const float* src, size_t dataSize);
 void MNNTanh(float* dst, const float* src, size_t dataSize);
 void MNNSigmoid(float* dst, const float* src, size_t dataSize);
@@ -74,6 +74,7 @@ void MNNSigmoidLowp(float* dst, const float* src, size_t dataSize);
 void MNNReluWithSlopeCommon(float* dst, const float* src, size_t size, float slope);
 void MNNHardSwishCommon(float* dst, const float* src, size_t size);
 void MNNGeluCommon(float* dst, const float* src, size_t size);
+void MNNGeluStandardCommon(float* dst, const float* src, size_t size);
 void MNNSoftmax(float* dest, const float* source, size_t size);
 void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size);
 
@@ -104,6 +105,14 @@ void MNNFunctionInit();
 void MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias);
 
 void MNNPackForSparseMatMul_B(float* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const float* source, size_t h, size_t l, const int eP, bool transpose);
+struct SparseMatMulParas
+{
+    float* C;
+    const float* A;
+    const float* B;
+    unsigned int* NNZMap;
+    int* dataOffsetMap;
+};
 void MNNPackedSparseMatMulEpx1(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap);
 
 void MNNPackedSparseMatMulEpx4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap);
@@ -145,6 +154,7 @@ void MNNGridSampleInterp(float* outputPtr, const float* inputPtr, const float* c
 
 typedef void(*MNNBinaryExecute)(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, int broadcastIndex);
 typedef void(*MNNUnaryExecute)(void* outputRaw, const void* inputRaw, int elementSize);
+typedef void(*MNNCopyWithStride)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds);
 
 namespace MNN {
 struct CoreFunctions {
@@ -180,11 +190,25 @@ struct CoreFunctions {
 
     /**NC4HW4's Functions*/
     int pack;
+    // For pack * bytes > 16
+    MNNCopyWithStride(*MNNSelectBlitFunction)(int blitBytes) = nullptr;
+
+    void(*MNNPackCUnitInt16)(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset);
+    void(*MNNUnpackCUnitInt16)(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset);
+    void(*MNNPackCUnitTransposeInt16)(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset);
+    void(*MNNUnpackCUnitTransposeInt16)(int16_t* dst, const int16_t* src, size_t area, size_t depth, int* areaOffset);
+
+    void(*MNNPackCUnitInt8)(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset);
+    void(*MNNUnpackCUnitInt8)(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset);
+    void(*MNNPackCUnitTransposeInt8)(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset);
+    void(*MNNUnpackCUnitTransposeInt8)(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset);
+
     void(*MNNPackCUnit)(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
     void(*MNNUnpackCUnit)(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
     void(*MNNPackCUnitTranspose)(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
     void(*MNNUnpackCUnitTranspose)(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
 
+    // NC4HW4's compute function
     void(*MNNConvRunForUnitDepthWise)(float* dst, const float* src, const float* weight, size_t fw, size_t fh,
                                         size_t weight_y_step, size_t dilateX_step, size_t dilateY_step);
     void(*MNNConvRunForLineDepthwise)(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
@@ -208,7 +232,9 @@ struct CoreFunctions {
     void(*MNNAddC4WithStride)(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count);
 
     typedef void (*WinoTransFunc)(const float* srcBlock, float* dstStart, size_t srcStep, size_t dstStep);
+    typedef void (*WinoTransPackFunc)(float* srcBlock, float* dstStart, size_t dstStep);
     WinoTransFunc(*chooseWinoSourceTransform)(int k, int w);
+    WinoTransPackFunc(*chooseWinoSourceTransformPack)(int k, int w, int ePack, int lPack, int packCUnit);
     WinoTransFunc(*chooseWinoDestTransform)(int k, int h);
 
     void(*MNNDeconvRunForUnitDepthWise)(const float* dst, float* src, const float* weight, size_t fw, size_t fh,
@@ -222,7 +248,6 @@ struct CoreFunctions {
     void(*MNNPoolingMax)(const void* channelInput, int inputWidth, int inputHeight, void *channelOutput,
                            int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, int strideWidth,
                            int strideHeight, int padWidth, int padHeight, int padType, int countType);
-
 };
 void MNNCoreFunctionInit();
 CoreFunctions* MNNGetCoreFunctions();
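
The Int8/Int16 slots added above mirror the float pack/unpack entries, so NC4HW4 layout conversion can run at a tensor's native element width instead of casting through float. Below is a sketch of how a caller might pick a slot by byte size; PackFn, PackTable and dispatchPack are illustrative names for this example, not members of CoreFunctions.

#include <cstddef>

// Illustrative dispatch by element width, mirroring how a layout-conversion
// routine could consult width-specific pack functions.
using PackFn = void (*)(void* dst, const void* src, size_t area, size_t depth, int* areaOffset);

struct PackTable {
    PackFn pack1; // 1-byte elements (int8 / uint8)
    PackFn pack2; // 2-byte elements (int16 / fp16 / bf16)
    PackFn pack4; // 4-byte elements (float / int32)
};

inline PackFn dispatchPack(const PackTable& table, int bytes) {
    switch (bytes) {
        case 1:  return table.pack1;
        case 2:  return table.pack2;
        default: return table.pack4;
    }
}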

+ 115 - 65
source/backend/cpu/compute/ConvInt8TiledExecutor.cpp

@@ -15,13 +15,82 @@
 #include "core/Concurrency.h"
 #include "core/TensorUtils.hpp"
 #include <math.h>
-#ifdef MNN_USE_SSE
-extern "C" {
-void MNNInt8ToUInt8(void* ptr, int count);
-}
-#endif
 namespace MNN {
 
+ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr<ResourceInt8> res): CPUConvolution(convOp->common(), backend), mResource(res) {
+}
+
+ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, std::shared_ptr<Tensor> weight, bool fastgemm)
+: CPUConvolution(common, backend) {
+}
+
+ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, const ConvInt8TiledExecutor& exe)
+    : CPUConvolution(common, backend),
+    mDoPostProcess(exe.mDoPostProcess),
+    mResource(exe.mResource) {
+
+}
+
+ConvInt8TiledExecutor::~ConvInt8TiledExecutor() {
+    // Do nothing
+}
+
+bool ConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution** dst) {
+    return false;
+}
+
+ErrorCode ConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
+    if (mDoPostProcess) {
+        mResource->updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), TensorUtils::getQuantInfo(outputs[0]));
+    } else {
+        mResource->mInputZeroPoint = 0;
+    }
+    CPUConvolution::onResize(inputs, outputs);
+    auto input  = inputs[0];
+    auto output = outputs[0];
+
+    int UNIT = static_cast<CPUBackend*>(backend())->functions()->pack;
+    auto convCommon = mCommon;
+    const auto kernelCount = convCommon->kernelX() * convCommon->kernelY();
+    const auto srcCountUnit = UP_DIV(input->channel(), UNIT);
+
+    mIm2ColParamter.dilateX         = convCommon->dilateX();
+    mIm2ColParamter.dilateY         = convCommon->dilateY();
+    mIm2ColParamter.strideX         = convCommon->strideX();
+    mIm2ColParamter.strideY         = convCommon->strideY();
+    mIm2ColParamter.padX            = convCommon->padX();
+    mIm2ColParamter.padY            = convCommon->padY();
+    mIm2ColParamter.icDiv4          = srcCountUnit;
+    mIm2ColParamter.kernelX         = convCommon->kernelX();
+    mIm2ColParamter.kernelY         = convCommon->kernelY();
+    mIm2ColParamter.padX = mPadX;
+    mIm2ColParamter.padY = mPadY;
+
+    mIm2ColParamter.ih = input->height();
+    mIm2ColParamter.iw = input->width();
+    mIm2ColParamter.oh = output->height();
+    mIm2ColParamter.ow = output->width();
+    mIm2ColParamter.srcZStep = input->stride(1) * UNIT * input->batch();
+    mIm2ColParamter.srcYStep = input->stride(2) * UNIT;
+    mIm2ColParamter.packCUnit = UNIT;
+
+    int SRC_UNIT, DynamicDestUnit;
+    auto core = static_cast<CPUBackend*>(backend())->int8Functions();
+    getPackParameter(&UNIT, &SRC_UNIT, &DynamicDestUnit, core);
+    mTileCount        = UP_DIV(output->height() * output->width(), DynamicDestUnit);
+    const int threads = std::max(static_cast<CPUBackend*>(backend())->threadNumber(), 1);
+    mThreadNums       = std::min(threads, mTileCount);
+    return NO_ERROR;
+}
+
+//
+//  DenseConvInt8TiledExecutor.cpp
+//  MNN
+//
+//  Created by MNN on 2019/5/17.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
 static bool reorderWeight(Backend* bn, const Convolution2DCommon* common,
                           const std::shared_ptr<Tensor>& weightOrigin,
                           std::shared_ptr<Tensor>& weight) {
@@ -31,9 +100,9 @@ static bool reorderWeight(Backend* bn, const Convolution2DCommon* common,
     // reorder weight, [oc, ic, k^2] => [oc/unit, ((ic/unit)*k^2)/(src_unit/unit), unit(oc), (src_unit/unit), unit(ic)]
     int oc = common->outputCount(), ic = common->inputCount(), kernelCount = common->kernelX() * common->kernelY();
     std::vector<int> shape = {UP_DIV(oc, UNIT), UP_DIV(UP_DIV(ic, UNIT) * kernelCount, SRC_UNIT / UNIT), UNIT, SRC_UNIT};
-    
+
     weight.reset(Tensor::createDevice<int8_t>(shape));
-    
+
     bool succ = bn->onAcquireBuffer(weight.get(), Backend::STATIC);
     if (!succ) {
         MNN_ERROR("Memory not enough");
@@ -50,7 +119,7 @@ static bool reorderWeight(Backend* bn, const Convolution2DCommon* common,
             const int yIndex      = yOutSide + k * UP_DIV(ic, UNIT);
             const int ySubOutSide = yIndex / (SRC_UNIT / UNIT);
             const int ySubInSide  = yIndex % (SRC_UNIT / UNIT);
-            
+
             auto dstY       = weightDst + ySubOutSide * weight->stride(1) + ySubInSide * UNIT + yInSide;
             const auto srcY = srcK + y * kernelCount;
             for (int x = 0; x < oc; ++x) {
@@ -65,10 +134,11 @@ static bool reorderWeight(Backend* bn, const Convolution2DCommon* common,
     return true;
 }
 
-ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr<ResourceInt8> res): CPUConvolution(convOp->common(), backend), mResource(res) {
+DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr<ResourceInt8> res) : ConvInt8TiledExecutor(backend, convOp, res) {
     std::shared_ptr<Tensor> weightOrigin;
     weightOrigin.swap(mResource->mWeightInt8);
     mValid = reorderWeight(backend, convOp->common(), weightOrigin, mResource->mWeightInt8);
+    backend->onReleaseBuffer(weightOrigin.get(), Backend::STATIC);
     if(!mValid) {
         return;
     }
@@ -87,8 +157,8 @@ ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Convolution
 #endif
 }
 
-ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, std::shared_ptr<Tensor> weight, bool fastgemm)
-: CPUConvolution(common, backend) {
+DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, std::shared_ptr<Tensor> weight, bool fastgemm)
+    : ConvInt8TiledExecutor(backend, common, weight, fastgemm) {
     auto core = static_cast<CPUBackend*>(backend)->int8Functions();
     int UNIT, SRC_UNIT, DST_XUNIT;
     core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
@@ -124,21 +194,20 @@ ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Convolution
     mDoPostProcess = false;
 }
 
-ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, const ConvInt8TiledExecutor& exe)
-    : CPUConvolution(common, backend), mGemmKernel(exe.mGemmKernel),
-    mDoPostProcess(exe.mDoPostProcess), mResource(exe.mResource) {
-    
+DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, const DenseConvInt8TiledExecutor& exe)
+    : ConvInt8TiledExecutor(backend, common, exe), mGemmKernel(exe.mGemmKernel) {
+
 }
 
-ConvInt8TiledExecutor::~ConvInt8TiledExecutor() {
+DenseConvInt8TiledExecutor::~DenseConvInt8TiledExecutor() {
     // Do nothing
 }
 
-bool ConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution** dst) {
+bool DenseConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution** dst) {
     if (nullptr == dst) {
         return true;
     }
-    auto exe = new ConvInt8TiledExecutor(bn, op->main_as_Convolution2D()->common(), *this);
+    auto exe = new DenseConvInt8TiledExecutor(bn, op->main_as_Convolution2D()->common(), *this);
     if (!exe->valid()) {
         return false;
     }
@@ -146,47 +215,22 @@ bool ConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution** dst)
     return true;
 }
 
-ErrorCode ConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
-    if (mDoPostProcess) {
-        mResource->updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), TensorUtils::getQuantInfo(outputs[0]));
-    } else {
-        mResource->mInputZeroPoint = 0;
-    }
-    CPUConvolution::onResize(inputs, outputs);
-    auto input  = inputs[0];
-    auto output = outputs[0];
-    
-    auto core = static_cast<CPUBackend*>(backend())->int8Functions();
-    int UNIT, SRC_UNIT, DST_XUNIT;
-    core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
-    auto convCommon = mCommon;
-    const auto kernelCount = convCommon->kernelX() * convCommon->kernelY();
-    const auto srcCountUnit = UP_DIV(input->channel(), UNIT);
-    const auto totalKernelCountD8Div2 = UP_DIV(srcCountUnit * kernelCount, SRC_UNIT / UNIT);
+void DenseConvInt8TiledExecutor::getPackParameter(int* Unit, int* srcUnit, int* DestUnit, const CoreInt8Functions* core) {
+    core->MNNGetGemmUnit(Unit, srcUnit, DestUnit);
+}
 
-    mIm2ColParamter.dilateX         = convCommon->dilateX();
-    mIm2ColParamter.dilateY         = convCommon->dilateY();
-    mIm2ColParamter.strideX         = convCommon->strideX();
-    mIm2ColParamter.strideY         = convCommon->strideY();
-    mIm2ColParamter.padX            = convCommon->padX();
-    mIm2ColParamter.padY            = convCommon->padY();
-    mIm2ColParamter.icDiv4          = srcCountUnit;
-    mIm2ColParamter.kernelX         = convCommon->kernelX();
-    mIm2ColParamter.kernelY         = convCommon->kernelY();
-    mIm2ColParamter.kernelCountUnit = totalKernelCountD8Div2;
-    mIm2ColParamter.padX = mPadX;
-    mIm2ColParamter.padY = mPadY;
 
-    mIm2ColParamter.ih = input->height();
-    mIm2ColParamter.iw = input->width();
-    mIm2ColParamter.oh = output->height();
-    mIm2ColParamter.ow = output->width();
-    mIm2ColParamter.srcZStep = input->stride(1) * UNIT * input->batch();
-    mIm2ColParamter.srcYStep = input->stride(2) * UNIT;
+ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
+    // Timer kernelTimer;
+    ConvInt8TiledExecutor::onResize(inputs, outputs);
+    auto core = static_cast<CPUBackend*>(backend())->int8Functions();
 
-    mTileCount        = UP_DIV(output->height() * output->width(), DST_XUNIT);
-    const int threads = std::max(static_cast<CPUBackend*>(backend())->threadNumber(), 1);
-    mThreadNums       = std::min(threads, mTileCount);
+    int UNIT, SRC_UNIT, DST_XUNIT;
+    getPackParameter(&UNIT, &SRC_UNIT, &DST_XUNIT, core);
+    auto input  = inputs[0];
+    const auto kernelCount = mCommon->kernelX() * mCommon->kernelY();
+    const auto srcCountUnit = UP_DIV(input->channel(), UNIT);
+    mIm2ColParamter.kernelCountUnit = UP_DIV(srcCountUnit * kernelCount, SRC_UNIT / UNIT);
 
     // set im2col tensor info
     mTempIm2ColBuffer.reset(Tensor::createDevice<int8_t>({mThreadNums, DST_XUNIT, mResource->mWeightInt8->length(1) * SRC_UNIT}));
@@ -195,17 +239,19 @@ ErrorCode ConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& inputs, co
         return OUT_OF_MEMORY;
     }
     backend()->onReleaseBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC);
+    // MNN_PRINT("dense conv2d int8 resize: cost time: %llu us\n", kernelTimer.durationInUs());
     return NO_ERROR;
 }
 
-ErrorCode ConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
+ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
+    // Timer kernelTimer;
     const auto input = inputs[0];
     auto output      = outputs[0];
     auto core = static_cast<CPUBackend*>(backend())->int8Functions();
-    
+
     int UNIT, SRC_UNIT, DST_XUNIT;
     core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
-    
+
     auto im2ColProcess = core->chooseIm2Col(&mIm2ColParamter, input->channel());
 
     const int outputPlaneLen = output->height() * output->width();
@@ -220,7 +266,7 @@ ErrorCode ConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inputs, c
 
     const auto inputDataPtr = input->host<int8_t>();
     const auto weightDataPtr = mResource->mWeightInt8->host<int8_t>();
-    
+
     auto im2colPtr           = mTempIm2ColBuffer->host<int8_t>();
     auto outputDataPtr       = output->host<int8_t>();
     QuanPostTreatParameters quanParam;
@@ -237,7 +283,7 @@ ErrorCode ConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inputs, c
         quanParam.scale = nullptr;
     }
     //MNN_PRINT("max: %d, min: %d\n", quanParam.maxValue, quanParam.minValue);
-    
+
     const int bytes = (mDoPostProcess ? 1 : 4); // int8_t or float
 
     auto threadFunction = [&](int tId) {
@@ -250,10 +296,10 @@ ErrorCode ConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inputs, c
                 const int xIndexStart  = tIndex * DST_XUNIT;
                 const int realDstCount = ALIMIN(outputPlaneLen - xIndexStart, DST_XUNIT);
                 // im2col
-                im2ColProcess(colAddr, srcPtr, mResource->mInputZeroPoint, &mIm2ColParamter, xIndexStart, realDstCount);
 #ifdef MNN_USE_SSE
-                const int col_buffer_size = mIm2ColParamter.kernelCountUnit * DST_XUNIT * SRC_UNIT;
-                MNNInt8ToUInt8(colAddr, col_buffer_size);
+                im2ColProcess(colAddr, srcPtr, mResource->mInputZeroPoint + 128, &mIm2ColParamter, xIndexStart, realDstCount);
+#else
+                im2ColProcess(colAddr, srcPtr, mResource->mInputZeroPoint, &mIm2ColParamter, xIndexStart, realDstCount);
 #endif
                 auto outputInTilePtr = dstPtr + xIndexStart * UNIT * bytes;
                 mGemmKernel(outputInTilePtr, colAddr, weightDataPtr, kernelCountUnitDouble, dstZStep * bytes, ocDiv4, &quanParam, realDstCount);
@@ -264,8 +310,12 @@ ErrorCode ConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inputs, c
         threadFunction((int)tId);
     }
     MNN_CONCURRENCY_END();
-
+    // MNN_PRINT("dense conv2d int8 execute: cost time: %llu us\n", kernelTimer.durationInUs());
     return NO_ERROR;
 }
 
+
+
+
+
 } // namespace MNN
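
In the SSE branch of onExecute above, the separate MNNInt8ToUInt8 pass is gone and the input zero point is biased by 128 instead, keeping the data in the unsigned domain the SSE int8 GEMM kernels expect. The standalone check below only verifies the byte-level identity behind that +128 bias; it is an illustration, not MNN code.

#include <cassert>
#include <cstdint>

// Adding 128 to a signed 8-bit value produces the same byte as flipping its
// sign bit, which is the per-byte conversion the removed pass performed.
int main() {
    for (int v = -128; v <= 127; ++v) {
        uint8_t flipped = static_cast<uint8_t>(v) ^ 0x80u; // old explicit int8 -> uint8 conversion
        uint8_t biased  = static_cast<uint8_t>(v + 128);   // bias folded into the zero point
        assert(flipped == biased);
    }
    return 0;
}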

+ 33 - 4
source/backend/cpu/compute/ConvInt8TiledExecutor.hpp

@@ -23,19 +23,48 @@ public:
     ConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, std::shared_ptr<Tensor> weight, bool fastgemm);
     virtual ~ConvInt8TiledExecutor();
     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
-    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
     virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
-private:
+    virtual void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) = 0;
+protected:
     ConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, const ConvInt8TiledExecutor& exe);
+    friend class ConvInt8Winograd;
+
+protected:
     ConvolutionCommon::Im2ColParameter mIm2ColParamter;
     int mTileCount;
     int mThreadNums;
     std::shared_ptr<Tensor> mTempIm2ColBuffer;
-    decltype(CoreInt8Functions::Int8GemmKernel) mGemmKernel;
     bool mDoPostProcess = true; //whether quan post process (add bias, min/max then scale to int8)
     std::shared_ptr<CPUConvolution::ResourceInt8> mResource;
-    
+
+};
+
+//
+//  DenseConvInt8TiledExecutor.hpp
+//  MNN
+//
+//  Created by MNN on 2019/5/17.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+
+class DenseConvInt8TiledExecutor : public ConvInt8TiledExecutor {
+public:
+    // given weight+bias+scale, do post process
+    DenseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr<ResourceInt8> res);
+    // only given weight, not do post process
+    DenseConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, std::shared_ptr<Tensor> weight, bool fastgemm);
+    virtual ~DenseConvInt8TiledExecutor();
+    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+    virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
+    void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) override;
+private:
+    DenseConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, const DenseConvInt8TiledExecutor& exe);
     friend class ConvInt8Winograd;
+
+    decltype(CoreInt8Functions::Int8GemmKernel) mGemmKernel;
+
 };
 
 } // namespace MNN
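
The header now separates a ConvInt8TiledExecutor base, which owns the shared im2col and tiling bookkeeping, from DenseConvInt8TiledExecutor, which reports its gemm geometry through the virtual getPackParameter. Below is a stripped-down sketch of that template-method shape; Base, Dense and the placeholder pack sizes are illustrative, not the MNN classes.

#include <cstdio>

// The base class computes the tile count; the derived class only supplies
// its pack/tile sizes.
struct Base {
    int mTileCount = 0;
    virtual ~Base() = default;
    virtual void getPackParameter(int* unit, int* srcUnit, int* destUnit) = 0;
    void onResize(int outputPlane) {
        int unit = 0, srcUnit = 0, destUnit = 1;
        getPackParameter(&unit, &srcUnit, &destUnit);
        mTileCount = (outputPlane + destUnit - 1) / destUnit; // UP_DIV
    }
};

struct Dense : Base {
    void getPackParameter(int* unit, int* srcUnit, int* destUnit) override {
        *unit = 4; *srcUnit = 16; *destUnit = 20; // placeholder gemm geometry
    }
};

int main() {
    Dense d;
    d.onResize(33 * 33);
    std::printf("tiles: %d\n", d.mTileCount); // 55 tiles of 20 output positions
    return 0;
}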

+ 16 - 16
source/backend/cpu/compute/ConvInt8Winograd.cpp

@@ -31,7 +31,7 @@ bool ConvInt8Winograd::chooseTransformFuncs(int kernelY, int kernelX, int unitY,
     auto core = static_cast<CPUBackend*>(bn)->int8Functions();
     int UNIT, SRC_UNIT, DST_XUNIT;
     core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
-    
+
     int alphaY = kernelY + unitY - 1, alphaX = kernelX + unitX - 1;
     WinoSrcTransFunc srcFuncY = nullptr, srcFuncX = nullptr;
     WinoDstTransFunc dstFuncY = nullptr, dstFuncX = nullptr;
@@ -128,7 +128,7 @@ ConvInt8Winograd::ConvInt8Winograd(Backend *b, const Convolution2D *convOp, std:
 #else
             bool fastgemm = (convOp->symmetricQuan()->method() == QuantizeAlgo_OVERFLOW_AWARE);
 #endif
-            exe.reset(new ConvInt8TiledExecutor(b, subCommon->first, weight, fastgemm));
+            exe.reset(new DenseConvInt8TiledExecutor(b, subCommon->first, weight, fastgemm));
         } else {
             bool fastgemm = false;
 #ifdef MNN_USE_SSE
@@ -173,11 +173,11 @@ bool ConvInt8Winograd::onClone(Backend* bn, const Op* op, Execution** dst) {
 ErrorCode ConvInt8Winograd::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
     mResource->updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), TensorUtils::getQuantInfo(outputs[0]));
     CPUConvolution::onResize(inputs, outputs);
-    
+
     auto core = static_cast<CPUBackend*>(backend())->int8Functions();
     int UNIT, SRC_UNIT, DST_XUNIT;
     core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
-    
+
     auto input = inputs[0], output = outputs[0];
     int batch = input->batch(), ic = input->channel(), oc = output->channel();
     int ih = input->height(), iw = input->width();
@@ -197,11 +197,11 @@ ErrorCode ConvInt8Winograd::onResize(const std::vector<Tensor *> &inputs, const
         }
         unit.common = createCommon(unit.common->first, {}, {ALIMAX(mPadY - unit.attr.kyStart, 0), ALIMAX(mPadX - unit.attr.kxStart, 0)});
         if (unit.attr.unitY == 1 && unit.attr.unitX == 1) {
-            static_cast<ConvInt8TiledExecutor*>(unit.runner.get())->mCommon = unit.common->first;
+            static_cast<DenseConvInt8TiledExecutor*>(unit.runner.get())->mCommon = unit.common->first;
         } else {
             static_cast<WinoExecution*>(unit.runner.get())->mCommon = unit.common->first;
         }
-        
+
         auto res = unit.runner->onResize({unit.input.get()}, {unit.output.get()});
         if (res != NO_ERROR) {
             mValid = false;
@@ -265,7 +265,7 @@ ConvInt8Winograd::WinoExecution::WinoExecution(Backend *bn, const Convolution2DC
     if (fastgemm) {
         mGemmKernel = core->Int8GemmKernelFast;
     }
-    
+
     int UNIT, SRC_UNIT, DST_XUNIT;
     core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
 
@@ -280,7 +280,7 @@ ConvInt8Winograd::WinoExecution::WinoExecution(Backend *bn, const Convolution2DC
 
     chooseTransformFuncs(mKernelY, mKernelX, mUnitY, mUnitX, this, bn);
     WinogradInt8Helper helper(mUnitY, mUnitX, common, core);
-    
+
     mWeight = helper.allocTransformWeight(weight);
     mOffsets.reset(Tensor::createDevice<int32_t>({alpha2, oc4 * UNIT}));
     mValid = backend()->onAcquireBuffer(mWeight.get(), Backend::STATIC);
@@ -312,7 +312,7 @@ ConvInt8Winograd::WinoExecution::WinoExecution(Backend* bn, const Convolution2DC
     mDestTransformY(exe.mDestTransformY), mDestTransformX(exe.mDestTransformX),
     mUnitY(exe.mUnitY), mUnitX(exe.mUnitX), mKernelY(exe.mKernelY), mKernelX(exe.mKernelX),
     mGemmKernel(exe.mGemmKernel), mInputZeroPoint(exe.mInputZeroPoint) {
-    
+
     mTempInputBuffer.reset(Tensor::createDevice<int8_t>(exe.mTempInputBuffer->shape()));
     mTempOutputBuffer.reset(Tensor::createDevice<float>(exe.mTempOutputBuffer->shape()));
     mTransformMidBuffer.reset(Tensor::createDevice<int8_t>(exe.mTransformMidBuffer->shape()));
@@ -348,13 +348,13 @@ ErrorCode ConvInt8Winograd::WinoExecution::onResize(const std::vector<Tensor *>
 }
 ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
     auto core = static_cast<CPUBackend*>(backend())->int8Functions();
-    
+
     auto input = inputs[0], output = outputs[0];
     int alphaY = mKernelY + mUnitY - 1, alphaX = mKernelX + mUnitX - 1, alpha2 = alphaY * alphaX;
     bool conv1d = (alphaY == 1 || alphaX == 1);
     int UNIT, SRC_UNIT, DST_XUNIT;
     core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
-    
+
     int ow = output->width(), oh = output->height();
     int iw = input->width(), ih = input->height();
     int ic = input->channel(), ic_4 = UP_DIV(ic, UNIT);
@@ -368,7 +368,7 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector<Tensor *>
     int threadNumber = std::max(((CPUBackend *)backend())->threadNumber(), 1);
     int tileCount    = UP_DIV(totalCount, DST_XUNIT);
     threadNumber     = std::min(threadNumber, tileCount);
-    
+
     auto srcOrigin = input->host<int8_t>();
     auto dstOrigin = output->host<float>();
 
@@ -406,7 +406,7 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector<Tensor *>
                         int ex    = ALIMIN(srcX + alphaX, iw) - srcX;
                         int count = UNIT * (ex - sx);
                         auto dst_x = dstS + si * SRC_UNIT;
-                        
+
                         int sourceZStep = iw * ih * input->batch() * UNIT;
                         int sourceYStep = iw * UNIT;
                         auto srcStart = srcOrigin + srcY * sourceYStep + srcX * UNIT + bIndex * iw * ih * UNIT;
@@ -430,7 +430,7 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector<Tensor *>
                             sourceZStep = alpha2 * UNIT;
                             sourceYStep = alphaX * UNIT;
                         }
-                        
+
                         if (!conv1d) {
                             for (int i = 0; i < alphaY; ++i) {
                                 mSourceTransformX(srcStart + i * sourceYStep, midBuffer1 + i * SRC_UNIT,
@@ -489,7 +489,7 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector<Tensor *>
                         auto dstStart = dstOrigin + (dstX + dstY * ow + bIndex * ow * oh) * UNIT;
                         int ex = ALIMIN(dstX + mUnitX, ow) - dstX;
                         int count = ex * UNIT;
-                        
+
                         auto _dstStart = dstStart;
                         int dstZStep = oh * ow * output->batch() * UNIT, dstYStep = ow * UNIT;
                         if (ex != mUnitX || (alphaX == 1 && ey != mUnitY)) {
@@ -558,7 +558,7 @@ bool ConvInt8Winograd::bestWinogradUnit(const Convolution2D *convOp, const Tenso
     }
     int kernelY = common->kernelY(), kernelX = common->kernelX();
     int oh = output->height(), ow = output->width(), oc = common->outputCount(), ic = common->inputCount();
-    
+
     const int CONV_WINOGRAD_MAX_KERNEL = 3, CONV_WINOGRAD_ALPHA = 4;
     using Vec = std::vector<std::pair<int, int>>;
     auto partitionKernelFunc = [=](int kernel, bool range = false) -> Vec {

+ 10 - 2
source/backend/cpu/compute/ConvolutionFloatFactory.cpp

@@ -33,14 +33,22 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend
 #endif
 
 #ifdef MNN_USE_SPARSE_COMPUTE
+#ifndef MNN_AVX512 // Currently AVX512 doesn't support sparse
     auto core = static_cast<CPUBackend*>(backend)->functions();
     int bytes = core->bytes;
-    if (bytes == 4 && core->pack == 4 && conv2d->sparseParameter()) {
+#ifdef MNN_USE_SSE
+    const bool onlySSENotAVX = core->pack == 4; // pack == 4 means a plain SSE build without AVX2 or AVX512
+#else
+    const bool onlySSENotAVX = false;
+#endif
+    if (!onlySSENotAVX && bytes == 4 && conv2d->sparseParameter()) {
         if (SparseConvolutionTiledExecutor::shouldUseSparseConvolution(originWeightSize, conv2d->sparseParameter())) {
-            return new SparseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, conv2d->sparseParameter(), bias, biasSize);
+            return new SparseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize,
+                                                      conv2d->sparseParameter(), bias, biasSize);
         }
     }
 #endif
+#endif
     bool fastWay = common->kernelY() == 1 && common->kernelX() == 1
         && output->width() == input->width() && output->height() == input->height()
         && common->strideX() == 1 && common->strideY() == 1;
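
The factory change above keeps the sparse path off for AVX512 builds and for plain-SSE builds, the latter detected as core->pack == 4 under MNN_USE_SSE. Below is a small sketch of that gate pulled out into one function; canUseSparse is an illustrative helper, not an MNN API.

// The build flags are passed in as booleans so the whole decision reads in one place.
inline bool canUseSparse(bool avx512Build, bool sseBuild, int pack, int bytes, bool hasSparseParam) {
    if (avx512Build) {
        return false;                                  // no AVX512 sparse kernels yet
    }
    const bool onlySSENotAVX = sseBuild && pack == 4;  // plain SSE without AVX2/AVX512
    return !onlySSENotAVX && bytes == 4 && hasSparseParam;
}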

+ 1 - 1
source/backend/cpu/compute/ConvolutionTiledExecutor.cpp

@@ -16,7 +16,7 @@
 #include "core/TensorUtils.hpp"
 #include "math/Vec.hpp"
 #include "core/BufferAllocator.hpp"
-#include "core/MemoryFormater.h"
+#include "common/MemoryFormater.h"
 
 using Vec4 = MNN::Math::Vec<float, 4>;
 namespace MNN {

+ 291 - 100
source/backend/cpu/compute/ConvolutionWinograd.cpp

@@ -14,11 +14,16 @@
 #include "core/Macro.h"
 #include "core/TensorUtils.hpp"
 #include "math/WingoradGenerater.hpp"
+#include <MNN/AutoTime.hpp>
+#include "common/MemoryFormater.h"
 #ifdef MNN_USE_NEON
 #include <arm_neon.h>
 #endif
 #define CONVOLUTION_WINOGRAD_MAX_UNIT 8
 #define CONVOLUTION_WINOGRAD_MIN_UNIT 2
+constexpr int FULSE_THRESHHOLD_NUMERATOR = 8;
+constexpr int FULSE_THRESHHOLD_DENOMINATOR = 10;
+
 using namespace MNN::Math;
 
 //#define MNN_WINOGRAD_PRINT_REDUCE_RATE
@@ -44,25 +49,29 @@ ConvolutionWinograd::ConvolutionWinograd(const Convolution2DCommon *convOp, cons
     auto kernelSize = mCommon->kernelY();
     WinogradGenerater generator(unit, kernelSize, 1, true);
 
+    int ePack, hPack, lPack;
+    core->MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
+
     int alpha        = unit + kernelSize - 1;
     int alpha2       = alpha * alpha;
     mSourceTransform = core->chooseWinoSourceTransform(alpha, alpha);
     mDestTransform   = core->chooseWinoDestTransform(alpha, unit);
-
+    mSourceTransformPack = core->chooseWinoSourceTransformPack(alpha, alpha, ePack, lPack, pack);
     int srcCount                       = input->channel();
     int outputCount                    = output->channel();
     auto ic4 = UP_DIV(srcCount, pack);
     auto oc4 = UP_DIV(outputCount, pack);
-    int ePack, hPack, lPack;
-    core->MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
-
     mTempBuffer.reset(Tensor::createDevice<uint8_t>({threadNumber, ePack, ic4 + oc4, pack * alpha2, bytes}));
-    mTransformMidBuffer.reset(Tensor::createDevice<uint8_t>({threadNumber, 2, alpha2, pack, bytes}));
-    mGemmMidBuffer.reset(Tensor::createDevice<uint8_t>({threadNumber, ePack * UP_DIV(srcCount, lPack) * lPack, bytes}));
+    // mTransformMidBuffer.reset(Tensor::createDevice<uint8_t>({threadNumber, 2, alpha2, pack, bytes}));
+    // mGemmMidBuffer.reset(Tensor::createDevice<uint8_t>({threadNumber, ePack * UP_DIV(srcCount, lPack) * lPack, bytes}));
+
+    mTransformMidBuffer.reset(Tensor::createDevice<uint8_t>({threadNumber, (1 + ic4 * ePack), alpha2, pack, bytes})); // the leading 1 keeps the original small alpha2 * pack buffer.
+    mGemmMidBuffer.reset(Tensor::createDevice<uint8_t>({threadNumber, alpha, ePack * UP_DIV(srcCount, pack) * pack, bytes}));
+
 
     mA = generator.A();
     mB = generator.B();
-    
+
 
     // Transform Kernel
     auto G = generator.G();
@@ -70,7 +79,7 @@ ConvolutionWinograd::ConvolutionWinograd(const Convolution2DCommon *convOp, cons
     std::shared_ptr<Tensor> sourceWeight(Tensor::create<float>(
         std::vector<int>{outputCount, srcCount, kernelSize, kernelSize}, (void *)originWeight, Tensor::CAFFE));
     auto tempWeight = generator.allocTransformWeight(sourceWeight.get(), lPack, hPack, true);
-    
+
     auto shape = tempWeight->shape();
     shape.push_back(bytes);
     mResource->mWeight.reset(Tensor::createDevice<uint8_t>(shape));
@@ -105,6 +114,7 @@ bool ConvolutionWinograd::onClone(Backend* bn, const Op* op, Execution** dst) {
     dstExe->mGemmMidBuffer.reset(Tensor::createDevice<uint8_t>(mGemmMidBuffer->shape()));
     dstExe->mSourceTransform = mSourceTransform;
     dstExe->mDestTransform = mDestTransform;
+    dstExe->mSourceTransformPack = mSourceTransformPack;
     dstExe->mPostParameters = mPostParameters;
     *dst = dstExe;
     return true;
@@ -113,15 +123,17 @@ bool ConvolutionWinograd::onClone(Backend* bn, const Op* op, Execution** dst) {
 ErrorCode ConvolutionWinograd::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
     auto core = static_cast<CPUBackend*>(backend())->functions();
     int pack = core->pack, bytes = core->bytes;
-    
+
     auto input   = inputs[0];
     auto output  = outputs[0];
-    auto dstUnit = mA->length(1);
-    auto srcUnit = mA->length(0);
+    auto dstUnit = mA->length(1); // m
+    auto srcUnit = mA->length(0); // n
     int ePack, lPack, hPack;
     core->MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
 
     auto srcUnit2 = srcUnit * srcUnit;
+    auto alphaXStride = srcUnit * ePack * pack;
+    auto IC4alpha2Stride = srcUnit2 * ePack * pack;
 
     int ow   = output->width();
     int oh   = output->height();
@@ -135,8 +147,8 @@ ErrorCode ConvolutionWinograd::onExecute(const std::vector<Tensor *> &inputs, co
     int padY = mPadY;
     int padX = mPadX;
 
-    auto wUnit = UP_DIV(ow, dstUnit);
-    auto hUnit = UP_DIV(oh, dstUnit);
+    auto wUnit = UP_DIV(ow, dstUnit); // ow / m
+    auto hUnit = UP_DIV(oh, dstUnit); // oh / m
 
     auto totalCount   = wUnit * hUnit * batch;
     // MNN_PRINT("ow=%d, oh=%d\n", ow, oh);
@@ -159,132 +171,302 @@ ErrorCode ConvolutionWinograd::onExecute(const std::vector<Tensor *> &inputs, co
     auto outputOrigin = output->host<uint8_t>();
     auto srcOrigin = inputOrigin;
     auto dstOrigin = outputOrigin;
+    auto midBuffer0Bytes = srcUnit2 * pack * bytes;
+
+    bool allow_x86_bf16_winograd = true;
+#ifdef MNN_USE_SSE
+    allow_x86_bf16_winograd = bytes != 2; // on x86 only bf16 uses 2-byte elements; fp16 does not exist here.
+#endif
+
+    // using ElementType = int16_t;
+    // MNN_PRINT("winograd: this:%p, n:%d, ih:%d, iw:%d, ic:%d, oh:%d, ow:%d, oc:%d, kh:%d, kw:%d, totalCount:%d, srcUnit:%d, dstUnit:%d, ePack:%d, pack:%d, bytes:%d\n",
+    //     this, batch, ih, iw, input->channel(), oh, ow, output->channel(), mCommon->kernelX(), mCommon->kernelY(), totalCount, srcUnit, dstUnit, ePack, pack, bytes);
+    // MNN_PRINT("origin data matrix:\n");
+    // formatMatrix((const ElementType*)srcOrigin, {ic_4, batch*ih, iw, pack});
 
     auto weight    = mResource->mWeight->host<uint8_t>();
     auto bias      = mResource->mBias->host<uint8_t>();
     auto tFunction = [&](int tId) {
         auto _srcOrigin = mTempBuffer->host<uint8_t>() + tId * mTempBuffer->stride(0);
-        auto gemmBuffer = (float*)(mGemmMidBuffer->host<uint8_t>() + tId * mGemmMidBuffer->stride(0));
+        auto gemmBuffer = (mGemmMidBuffer->host<uint8_t>() + tId * mGemmMidBuffer->stride(0));
         auto midBuffer0 = mTransformMidBuffer->host<uint8_t>() + tId * mTransformMidBuffer->stride(0);
-        auto midBuffer1 = midBuffer0 + mTransformMidBuffer->stride(1);
+        auto midBuffer1 = midBuffer0 + midBuffer0Bytes;
         for (int tIndex = (int)tId; tIndex < tileCount; tIndex += threadNumber) {
             int xIndex  = (int)tIndex * ePack;
             int xReamin = totalCount - xIndex;
             int xC      = xReamin > ePack ? ePack : xReamin;
+            const bool fuseTransformPack = (xC * FULSE_THRESHHOLD_DENOMINATOR > FULSE_THRESHHOLD_NUMERATOR * ePack) && allow_x86_bf16_winograd;
+            // const bool fuseTransformPack = false;
+
+            // Timer timer;
+            // uint64_t durationSourceTrans1 = 0;
+            // uint64_t durationSourceTrans2 = 0;
+            // uint64_t durationMul = 0;
+            // uint64_t packATime = 0;
+            // uint64_t durationDestTrans1 = 0;
+            // uint64_t durationDestTrans2 = 0;
 
             /*Source Transform Begin*/
 #ifndef MNN_WINO_TRANFORM_TEST_CLOSE
             {
                 int sourceZStep = iw * ih * batch * pack;
-                int dstZStep    = xC * pack;
-                int unitStep    = ic_4 * xC * pack;
                 int oyBegin = xIndex / wUnit;
                 int oxBegin = xIndex % wUnit;
                 int oyEnd = (xIndex + xC-1) / wUnit;
                 int remain = xC;
-                auto dstS = _srcOrigin;
-                for (int hbIndex=oyBegin; hbIndex <= oyEnd; ++hbIndex) {
-                    int hIndex = hbIndex % hUnit;
-                    int bIndex = hbIndex / hUnit;
-                    int step = std::min(wUnit - oxBegin, remain);
-                    int srcY  = hIndex * dstUnit - padY;
-                    int ey    = ALIMIN(srcY + srcUnit, ih) - srcY;
-                    int sy    = ALIMAX(0, srcY) - srcY;
-                    for (int si=0; si<step; ++si) {
-                        auto wIndex = si + oxBegin;
-                        int srcX  = wIndex * dstUnit - padX;
-                        int sx    = ALIMAX(0, srcX) - srcX;
-                        int ex    = ALIMIN(srcX + srcUnit, iw) - srcX;
-                        int count = pack * (ex - sx);
-                        auto dst_x = dstS + si * pack * bytes;
-                        auto srcStart = srcOrigin + (srcX + srcY * iw + bIndex * iw * ih) * pack * bytes;
-                        if (ex - sx == srcUnit && ey - sy == srcUnit) {
-                            for (int z = 0; z < ic_4; ++z) {
-                                auto srcZ = srcStart + z * sourceZStep * bytes;
-                                // Transform
-                                for (int i = 0; i < srcUnit; ++i) {
-                                    auto srcFloatPtr = (const float*)(srcZ + i * iw * pack * bytes);
-                                    auto dstFloatPtr = (float*)(midBuffer1 + i * pack * bytes);
-                                    mSourceTransform(srcFloatPtr, dstFloatPtr, pack, pack * srcUnit);
+                int destSOffset = 0;
+                if (fuseTransformPack) {
+                    for (int hbIndex=oyBegin; hbIndex <= oyEnd; ++hbIndex) {
+                        int hIndex = hbIndex % hUnit;
+                        int bIndex = hbIndex / hUnit;
+                        int step = std::min(wUnit - oxBegin, remain);
+                        int srcY  = hIndex * dstUnit - padY;
+                        int ey    = ALIMIN(srcY + srcUnit, ih) - srcY;
+                        int sy    = ALIMAX(0, srcY) - srcY;
+                        for (int si=0; si<step; ++si) {
+                            auto wIndex = si + oxBegin;
+                            int srcX  = wIndex * dstUnit - padX;
+                            int sx    = ALIMAX(0, srcX) - srcX;
+                            int ex    = ALIMIN(srcX + srcUnit, iw) - srcX;
+                            int count = pack * (ex - sx);
+
+                            auto srcStart = srcOrigin + (srcX + srcY * iw + bIndex * iw * ih) * pack * bytes;
+                            // MNN_PRINT("\nxIndex:%d, xC:%d, alphaXStride:%d, srcUnit:%d, destUnit:%d, hUnit:%d, wUnit:%d, srcY:%d, hStart:%d ,hEnd:%d, wStart:%d, wEnd:%d, i_oh:%d, i_ow:%d, srcOffset:%ld, destSOffset:%d\n",
+                            //     xIndex, xC, alphaXStride, srcUnit, dstUnit, hUnit, wUnit, srcY, sy, ey, sx, ey, hIndex - oyBegin, si, (srcStart - srcOrigin)/bytes, (destSOffset)/bytes);
+                            // timer.reset();
+                            auto midBuffer1Offset = midBuffer1 + destSOffset;
+
+                            if (ex - sx == srcUnit && ey - sy == srcUnit) {
+                                for (int z = 0; z < ic_4; ++z) {
+                                    auto srcZ = srcStart + z * sourceZStep * bytes;
+                                    // Transform
+                                    // MNN_PRINT("z:%d, srcOffset:%ld, destSOffset:%ld, \n", z, ((unsigned const char*)srcZ - srcOrigin)/bytes, ((unsigned const char*)midBuffer1Offset - midBuffer1)/bytes);
+                                    // MNN_PRINT("winograd source sub matrix:\n");
+                                    // formatMatrix((const float*)srcZ, {srcUnit, 4});
+                                    for (int i = 0; i < srcUnit; ++i) { // i_Nh
+                                        auto srcFloatPtr = (const float*)(srcZ + i * iw * pack * bytes);
+                                        auto dstFloatPtr = (float*)(midBuffer1Offset + i * ePack * pack * bytes);
+                                        mSourceTransform(srcFloatPtr, dstFloatPtr, pack, alphaXStride); // tranform srcUnit*4 elements in one time
+                                        // MNN_PRINT("z:%d, 1 stage i_Nh:%d th srcOffset:%ld, destOffset:%ld, \n", z, i, ((unsigned const char*)srcFloatPtr - srcOrigin)/bytes, ((unsigned const char*)dstFloatPtr - midBuffer1)/bytes);
+                                        // MNN_PRINT("winograd source sub matrix:\n");
+                                        // formatMatrix(srcFloatPtr, {srcUnit, 4});
+                                    }
+                                    midBuffer1Offset += IC4alpha2Stride * bytes;
                                 }
-                                auto dstZ = dst_x + z * dstZStep * bytes;
-                                for (int i = 0; i < srcUnit; ++i) {
-                                    auto srcFloatPtr = (const float*)(midBuffer1 + i * srcUnit * pack * bytes);
-                                    auto dstFloatPtr = (float*)(dstZ + i * unitStep * bytes);
-                                    mSourceTransform(srcFloatPtr, dstFloatPtr, pack,
-                                                     unitStep * srcUnit);
+                            } else {
+                                for (int z = 0; z < ic_4; ++z) {
+                                    // Extract
+                                    auto srcZ = srcStart + z * sourceZStep * bytes;
+                                    ::memset(midBuffer0, 0, midBuffer0Bytes);
+                                    if (count > 0) {
+                                        for (int yy = sy; yy < ey; ++yy) {
+                                            auto dst_yy = midBuffer0 + (yy * srcUnit + sx) * pack * bytes;
+                                            auto src_yy = srcZ + (iw * yy + sx) * pack * bytes;
+                                            ::memcpy(dst_yy, src_yy, count * bytes);
+                                        }
+                                    }
+                                    // Transform
+                                    for (int i = 0; i < srcUnit; ++i) {
+                                        auto srcFloatPtr = (const float*)(midBuffer0 + i * srcUnit * pack * bytes);
+                                        auto dstFloatPtr = (float*)(midBuffer1Offset + i * ePack * pack * bytes);
+                                        mSourceTransform(srcFloatPtr, dstFloatPtr, pack, alphaXStride);
+                                    }
+                                    midBuffer1Offset += IC4alpha2Stride * bytes;
                                 }
                             }
-                        } else {
-                            for (int z = 0; z < ic_4; ++z) {
-                                // Extract
-                                auto srcZ = srcStart + z * sourceZStep * bytes;
-                                ::memset(midBuffer0, 0, mTransformMidBuffer->stride(1));
-                                if (count > 0) {
-                                    for (int yy = sy; yy < ey; ++yy) {
-                                        auto dst_yy = midBuffer0 + (yy * srcUnit + sx) * pack * bytes;
-                                        auto src_yy = srcZ + (iw * yy + sx) * pack * bytes;
-                                        ::memcpy(dst_yy, src_yy, count * bytes);
+                            // durationSourceTrans1 += timer.durationInUs();
+
+                            destSOffset += pack * bytes;
+                        }
+                        oxBegin = 0;
+                        remain -= step;
+                    }
+                } else {
+                    int dstZStep    = xC * pack;  // hUnit*wUnit * 4
+                    int unitStep    = ic_4 * xC * pack; //  C/4 * hUnit*wUnit * 4
+                    for (int hbIndex=oyBegin; hbIndex <= oyEnd; ++hbIndex) {
+                        int hIndex = hbIndex % hUnit;
+                        int bIndex = hbIndex / hUnit;
+                        int step = std::min(wUnit - oxBegin, remain);
+                        int srcY  = hIndex * dstUnit - padY;
+                        int ey    = ALIMIN(srcY + srcUnit, ih) - srcY; // number of valid rows in the source tile (h dim)
+                        int sy    = ALIMAX(0, srcY) - srcY;  // first valid row in the source tile
+                        for (int si=0; si<step; ++si) {
+                            auto wIndex = si + oxBegin;
+                            int srcX  = wIndex * dstUnit - padX;
+                            int sx    = ALIMAX(0, srcX) - srcX;
+                            int ex    = ALIMIN(srcX + srcUnit, iw) - srcX;
+                            int count = pack * (ex - sx);
+
+                            auto srcStart = srcOrigin + (srcX + srcY * iw + bIndex * iw * ih) * pack * bytes;
+                            // MNN_PRINT("\nxIndex:%d, xC:%d, alphaXStride:%d, srcUnit:%d, destUnit:%d, hUnit:%d, wUnit:%d, srcY:%d, hStart:%d ,hEnd:%d, wStart:%d, wEnd:%d, i_oh:%d, i_ow:%d, srcOffset:%ld, destSOffset:%d\n",
+                            //     xIndex, xC, alphaXStride, srcUnit, dstUnit, hUnit, wUnit, srcY, sy, ey, sx, ey, hIndex - oyBegin, si, (srcStart - srcOrigin)/bytes, (destSOffset)/bytes);
+                            // timer.reset();
+                            auto dst_x = _srcOrigin + destSOffset;
+                            if (ex - sx == srcUnit && ey - sy == srcUnit) {
+                                for (int z = 0; z < ic_4; ++z) {
+                                    auto srcZ = srcStart + z * sourceZStep * bytes;
+                                    // Transform
+
+                                    for (int i = 0; i < srcUnit; ++i) {
+                                        auto srcFloatPtr = (const float*)(srcZ + i * iw * pack * bytes);
+                                        auto dstFloatPtr = (float*)(midBuffer1 + i * pack * bytes);
+                                        // MNN_PRINT("z:%d, 1 stage i_Nh:%d th srcOffset:%ld, destOffset:%ld, \n", z, i, ((unsigned const char*)srcFloatPtr - srcOrigin)/bytes, ((unsigned const char*)dstFloatPtr - midBuffer1)/bytes);
+                                        // MNN_PRINT("winograd source sub matrix:\n");
+                                        // formatMatrix(srcFloatPtr, {srcUnit, pack});
+                                        mSourceTransform(srcFloatPtr, dstFloatPtr, pack, pack * srcUnit);
+
+                                    }
+                                    auto dstZ = dst_x + z * dstZStep * bytes;
+                                    for (int i = 0; i < srcUnit; ++i) {
+                                        auto srcFloatPtr = (const float*)(midBuffer1 + i * srcUnit * pack * bytes);
+                                        auto dstFloatPtr = (float*)(dstZ + i * unitStep * bytes);
+                                        mSourceTransform(srcFloatPtr, dstFloatPtr, pack,
+                                                         unitStep * srcUnit);
                                     }
                                 }
-                                // Transform
-                                for (int i = 0; i < srcUnit; ++i) {
-                                    auto srcFloatPtr = (const float*)(midBuffer0 + i * srcUnit * pack * bytes);
-                                    auto dstFloatPtr = (float*)(midBuffer1 + i * pack * bytes);
-                                    mSourceTransform(srcFloatPtr, dstFloatPtr, pack, pack * srcUnit);
-                                }
-                                auto dstZ = dst_x + z * dstZStep * bytes;
-                                for (int i = 0; i < srcUnit; ++i) {
-                                    auto srcFloatPtr = (const float*)(midBuffer1 + i * srcUnit * pack * bytes);
-                                    auto dstFloatPtr = (float*)(dstZ + i * unitStep * bytes);
-                                    mSourceTransform(srcFloatPtr, dstFloatPtr, pack, unitStep * srcUnit);
+                            } else {
+                                for (int z = 0; z < ic_4; ++z) {
+                                    // Extract
+                                    auto srcZ = srcStart + z * sourceZStep * bytes;
+                                    ::memset(midBuffer0, 0, mTransformMidBuffer->stride(1));
+                                    if (count > 0) {
+                                        for (int yy = sy; yy < ey; ++yy) {
+                                            auto dst_yy = midBuffer0 + (yy * srcUnit + sx) * pack * bytes;
+                                            auto src_yy = srcZ + (iw * yy + sx) * pack * bytes;
+                                            ::memcpy(dst_yy, src_yy, count * bytes);
+                                        }
+                                    }
+
+                                    // Transform
+                                    for (int i = 0; i < srcUnit; ++i) {
+                                        auto srcFloatPtr = (const float*)(midBuffer0 + i * srcUnit * pack * bytes);
+                                        auto dstFloatPtr = (float*)(midBuffer1 + i * pack * bytes);
+                                        mSourceTransform(srcFloatPtr, dstFloatPtr, pack, pack * srcUnit);
+                                    }
+                                    auto dstZ = dst_x + z * dstZStep * bytes;
+                                    for (int i = 0; i < srcUnit; ++i) {
+                                        auto srcFloatPtr = (const float*)(midBuffer1 + i * srcUnit * pack * bytes);
+                                        auto dstFloatPtr = (float*)(dstZ + i * unitStep * bytes);
+                                        mSourceTransform(srcFloatPtr, dstFloatPtr, pack, unitStep * srcUnit);
+                                    }
                                 }
                             }
+                            // durationSourceTrans1 += timer.durationInUs();
+                            destSOffset += pack * bytes;
                         }
+                        oxBegin = 0;
+                        remain -= step;
                     }
-                    oxBegin = 0;
-                    remain -= step;
-                    dstS += pack * step * bytes;
                 }
             }
-            /*Source Transform End*/
+
 #endif
-            // Multi
-            auto _dstOrigin = _srcOrigin + xC * srcUnit2 * ic_4 * pack * bytes;
-
-            int32_t info[4];
-            info[0] = 1;
-            info[1] = xC;
-            info[2] = xC;
-            info[3] = 1;
-            int32_t el[4];
-            el[0] = xC;
-            el[1] = parameters[1];
-            el[2] = 0;
-            el[3] = 0;
-            if (xC == ePack) {
-                for (int i = 0; i < srcUnit2; ++i) {
-                    auto srcTemp = (const float*)(_srcOrigin + i * ic_4 * pack * xC * bytes);
-                    auto _dstFloatPtr = (float*)(_dstOrigin + i * dc_4 * pack * xC * bytes);
-                    auto _weightFloatPtr = (const float*)(weight + i * mResource->mWeight->stride(0));
-                    core->MNNPackC4ForMatMul_A(gemmBuffer, &srcTemp, info, el);
-                    core->MNNPackedMatMul(_dstFloatPtr, gemmBuffer, _weightFloatPtr, parameters.data(), nullptr, nullptr);
+            auto* _dstOrigin = _srcOrigin;
+            if (fuseTransformPack) {
+                _dstOrigin += ePack * srcUnit2 * ic_4 * pack * bytes;
+                if (xC != ePack) {
+                    auto midTransformPtr = midBuffer1 + xC * pack * bytes;
+                    for (int i = 0; i < ic_4 * srcUnit2; ++i) {
+                        memset(midTransformPtr, 0, (ePack - xC) * pack * bytes);
+                        midTransformPtr += ePack * pack * bytes;
+                    }
+                }
+                // MNN_PRINT("winograd source matrix transform 1 D*B:\n");
+                // formatMatrix((const ElementType*)midBuffer1, {ic_4, srcUnit, srcUnit, ePack, pack});
+                for (int iNw = 0; iNw < srcUnit; ++iNw) { // i_Nw
+                    // timer.reset();
+                    auto midTransformPtr = midBuffer1 + iNw * alphaXStride * bytes;
+                    auto unitsGemmbuffer = gemmBuffer;
+                    for (int z = 0; z < ic_4; ++z) { // ic_4
+                        mSourceTransformPack((float*)midTransformPtr, (float*)unitsGemmbuffer, ePack * pack * ic_4);
+                        unitsGemmbuffer += ePack * pack * bytes;
+                        midTransformPtr += IC4alpha2Stride * bytes;
+                    }
+
+                    // durationSourceTrans2 += timer.durationInUs();
+                    // timer.reset();
+                    // MNN_PRINT("winograd source matrix transform 2 BT*D*B, iNw:%d\n", iNw);
+                    // formatMatrix((const ElementType*)gemmBuffer, {srcUnit, ic_4 * pack, ePack});
+
+                    // The previous transform requires xC aligned with ePack, so xC should equal ePack here.
+                    for (int iNh = 0; iNh < srcUnit; ++iNh) { // i_Nh, gemm
+                        auto unitsGemmbuffer = gemmBuffer + iNh * ic_4 * pack * ePack * bytes;
+                        auto _dstFloatPtr = (float*)(_dstOrigin + (iNh * srcUnit + iNw) * dc_4 * pack * ePack * bytes);
+                        auto _weightFloatPtr = (const float*)(weight + (iNh * srcUnit + iNw) * mResource->mWeight->stride(0));
+                        core->MNNPackedMatMul(_dstFloatPtr, (float*)unitsGemmbuffer, _weightFloatPtr, parameters.data(), nullptr, nullptr);
+                        // MNN_PRINT("winograd MatMul result, iNh%d, iNw:%d\n", iNh, iNw);
+                        // formatMatrix((const ElementType*)_dstFloatPtr, { dc_4, ePack, pack});
+                    }
+                    // durationMul += timer.durationInUs();
                 }
             } else {
-                for (int i = 0; i < srcUnit2; ++i) {
-                    auto srcTemp = (const float*)(_srcOrigin + i * ic_4 * pack * xC * bytes);
-                    auto _dstFloatPtr = (float*)(_dstOrigin + i * dc_4 * pack * xC * bytes);
-                    auto _weightFloatPtr = (const float*)(weight + i * mResource->mWeight->stride(0));
-                    core->MNNPackC4ForMatMul_A(gemmBuffer, &srcTemp, info, el);
-                    core->MNNPackedMatMulRemain(_dstFloatPtr, gemmBuffer, _weightFloatPtr, xC, parametersRemain.data(), nullptr, nullptr);
+                // MNN_PRINT("winograd source matrix after b*d*b:\n");
+                // formatMatrix((const ElementType*)_srcOrigin, {srcUnit, srcUnit, ic_4, hUnit, wUnit, pack});
+                /*Source Transform End*/
+                // Multi
+                _dstOrigin += xC * srcUnit2 * ic_4 * pack * bytes;
+
+                int32_t info[4];
+                info[0] = 1;
+                info[1] = xC;
+                info[2] = xC;
+                info[3] = 1;
+                int32_t el[4];
+                el[0] = xC;
+                el[1] = parameters[1];
+                el[2] = 0;
+                el[3] = 0;
+                if (xC == ePack) {
+                    for (int i = 0; i < srcUnit2; ++i) {
+                        // timer.reset();
+
+                        auto srcTemp = (const float*)(_srcOrigin + i * ic_4 * pack * xC * bytes);
+                        auto _dstFloatPtr = (float*)(_dstOrigin + i * dc_4 * pack * xC * bytes);
+                        auto _weightFloatPtr = (const float*)(weight + i * mResource->mWeight->stride(0));
+                        // MNN_PRINT("winograd i_n:%d, xC:%d, ePack:%d, before packA:\n", i, xC, ePack);
+                        // formatMatrix((const ElementType*)srcTemp, {ic_4, xC, pack});
+
+                        core->MNNPackC4ForMatMul_A((float*)gemmBuffer, &srcTemp, info, el);
+
+                        // packATime += timer.durationInUs();
+                        // timer.reset();
+                        // MNN_PRINT("winograd i_n:%d, after packA:\n", i);
+                        // formatMatrix((const ElementType*)gemmBuffer, {1, ic_4 * pack, xC});
+                        core->MNNPackedMatMul(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, parameters.data(), nullptr, nullptr);
+                        // MNN_PRINT("winograd MatMul result, iNh:%d, iNw:%d\n", i/srcUnit, i % srcUnit);
+                        // formatMatrix((const ElementType*)_dstFloatPtr, { dc_4, xC, pack});
+                        // durationMul += timer.durationInUs();
+                    }
+                } else {
+                    for (int i = 0; i < srcUnit2; ++i) {
+                        // timer.reset();
+                        auto srcTemp = (const float*)(_srcOrigin + i * ic_4 * pack * xC * bytes);
+                        auto _dstFloatPtr = (float*)(_dstOrigin + i * dc_4 * pack * xC * bytes);
+                        auto _weightFloatPtr = (const float*)(weight + i * mResource->mWeight->stride(0));
+                        // MNN_PRINT("winograd i_n:%d, xC:%d, ePack:%d, before packA:\n", i, xC, ePack);
+                        // formatMatrix((const ElementType*)srcTemp, {ic_4, xC, pack});
+
+                        core->MNNPackC4ForMatMul_A((float*)gemmBuffer, &srcTemp, info, el);
+                        // packATime += timer.durationInUs();
+                        // timer.reset();
+                        // MNN_PRINT("winograd i_n:%d, after packA:\n", i);
+                        // formatMatrix((const ElementType*)gemmBuffer, {1, ic_4 * pack, xC});
+                        core->MNNPackedMatMulRemain(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, xC, parametersRemain.data(), nullptr, nullptr);
+                        // MNN_PRINT("winograd MatMul result, iNh:%d, iNw:%d\n", i/srcUnit, i % srcUnit);
+                        // formatMatrix((const ElementType*)_dstFloatPtr, { dc_4, xC, pack});
+                        // durationMul += timer.durationInUs();
+                    }
                 }
             }
+
 #ifndef MNN_WINO_TRANFORM_TEST_CLOSE
             /* Dest Transform And Post Treat Begin */
             {
+
+                int srcZStep = (fuseTransformPack ? ePack : xC) * pack;
+                int unitStep = (fuseTransformPack ? ePack : xC) * dc_4 * pack;
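+                // In the fused path each (iNh, iNw) tile of the GEMM output keeps ePack columns
+                // (the tail tile is zero-padded), so the dest transform walks it with ePack-based
+                // strides; the original path keeps xC columns per tile.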
                 int dstZStep = ow * oh * pack * batch;
-                int srcZStep = xC * pack;
-                int unitStep = dc_4 * xC * pack;
                 int oyBegin = xIndex / wUnit;
                 int oxBegin = xIndex % wUnit;
                 int oyEnd = (xIndex + xC-1) / wUnit;
@@ -350,6 +532,15 @@ ErrorCode ConvolutionWinograd::onExecute(const std::vector<Tensor *> &inputs, co
             }
 #endif
             /*Dest Transform And Post Treat End*/
+            // if (fuseTransformPack) {
+            //     MNN_PRINT(
+            //         "\n relayout fused:\n\tdurationSourceTrans1: %lu us \n\tdurationSourceTrans2: %lu us \n\tdurationMul: %lu us\n\ttotal: %lu us\n",
+            //         durationSourceTrans1, durationSourceTrans2, durationMul, durationSourceTrans1 + durationSourceTrans2 + durationMul);
+            // } else {
+            //     MNN_PRINT(
+            //         "\n origin:\n\tdurationSourceTrans1+2: %lu us \n\t packA:%lu us \n\t durationMul:%lu us\n\ttotal: %lu us\n",
+            //         durationSourceTrans1, packATime, durationMul, durationSourceTrans1 + durationSourceTrans2 + durationMul + packATime);
+            // }
         }
     };
 

+ 1 - 0
source/backend/cpu/compute/ConvolutionWinograd.hpp

@@ -41,6 +41,7 @@ private:
 
     CoreFunctions::WinoTransFunc mSourceTransform;
     CoreFunctions::WinoTransFunc mDestTransform;
+    CoreFunctions::WinoTransPackFunc mSourceTransformPack;
     std::vector<float> mPostParameters;
 };
 } // namespace MNN

+ 3 - 3
source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp

@@ -16,7 +16,7 @@
 #include "core/TensorUtils.hpp"
 #include "math/Vec.hpp"
 #include "core/BufferAllocator.hpp"
-#include "core/MemoryFormater.h"
+#include "common/MemoryFormater.h"
 
 using Vec4 = MNN::Math::Vec<float, 4>;
 namespace MNN {
@@ -219,7 +219,7 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector<Tensor*>& inputs
     TensorUtils::setLinearLayout(&mTempBufferTranspose);
     auto plane    = width * height * batch;
     int tileCount = UP_DIV(plane, eP);
-                                              
+
     bool success = backend()->onAcquireBuffer(&mTempBufferTranspose, Backend::DYNAMIC);
     if (!success) {
         return OUT_OF_MEMORY;
@@ -243,7 +243,7 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector<Tensor*>& inputs
         auto srcPtr     = (float const **)((uint8_t *)tempPtr.first + tempPtr.second +
                                        tId * kernelSize * maxLine * (4 * sizeof(int32_t) + sizeof(float *)));
         auto el         = (int32_t *)(srcPtr + kernelSize * maxLine);
-                                        
+
         int32_t info[4];
         info[1] = src_width * src_height * batch;
         info[2] = eP;

File diff suppressed because it is too large
+ 1407 - 11
source/backend/cpu/compute/Int8FunctionsOpt.cpp


+ 23 - 6
source/backend/cpu/compute/Int8FunctionsOpt.h

@@ -24,10 +24,10 @@ typedef SSIZE_T ssize_t;
 #define GEMM_INT8_UNIT 4
 #define GEMM_INT8_SRC_UNIT 16
 #ifndef MNN_USE_SSE
-#ifdef __aarch64__
-#define GEMM_INT8_DST_XUNIT 4
-#else
-#define GEMM_INT8_DST_XUNIT 2
+    #ifdef __aarch64__
+    #define GEMM_INT8_DST_XUNIT 4
+    #else
+    #define GEMM_INT8_DST_XUNIT 2
 #endif
 #else
 #define GEMM_INT8_DST_XUNIT 4
@@ -36,6 +36,7 @@ typedef SSIZE_T ssize_t;
 #ifdef __cplusplus
 extern "C" {
 #endif
+
 struct QuanPostTreatParameters {
     const float* scale;
     const int32_t* bias;
@@ -48,6 +49,10 @@ void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float*
                    ssize_t maxValue, ssize_t zeroPoint);
 void MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint);
 void MNNInt8FunctionInit();
+void MNNPackedSparseQuantMatMulEpx1(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap);
+void MNNPackedSparseQuantMatMulEpx4(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap);
+
+
 #ifdef __cplusplus
 }
 #endif
@@ -59,18 +64,30 @@ struct CoreInt8Functions {
     void(*Int8GemmKernelFast)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount);
     void(*MNNGetGemmUnit)(int* UNIT, int* SRC_UNIT, int* DST_XUNIT);
     // Im2Col
-    typedef void(*Im2ColFunc)(int8_t* colAddr, const int8_t* inputOrigin, int8_t inputZeroPoint,
+    typedef void(*Im2ColFunc)(int8_t* colAddr, const int8_t* inputOrigin, int32_t inputZeroPoint,
                               const ConvolutionCommon::Im2ColParameter* im2colParameter, size_t xIndexStart,
                               size_t realDstCount);
     Im2ColFunc(*chooseIm2Col)(const ConvolutionCommon::Im2ColParameter* im2colParam, size_t inputChannel);
+
+    // sparse
+    void(*MNNGetSparseQuantMatMulPackMode)(int* eP, int *lP, int* hP);
+    void(*MNNPackForSparseQuantMatMul_B)(int8_t* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const int8_t* source, size_t h, size_t kernelCount, size_t icCount, const int eP);
+    void(*MNNPackedSparseQuantMatMulEpx1)(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap);
+    void(*MNNPackedSparseQuantMatMulEpx4)(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap);
+    void(*MNNSparseQuantIm2col)(int8_t* colAddr, const int8_t* inputOrigin, int8_t inputZeroPoint,
+                              const ConvolutionCommon::Im2ColParameter* im2colParameter, const size_t* sparseQuantParam, size_t xIndexStart);
     // winograd
     using WinoSrcTransFunc = WinogradInt8Helper::SrcTransFunc;
     using WinoDstTransFunc = WinogradInt8Helper::DstTransFunc;
     WinoSrcTransFunc(*chooseWinoSourceTransform)(int alpha, int inPack, int outPack);
     WinoDstTransFunc(*chooseWinoDestTransform)(int alpha, int unit);
-    
+
     void(*ConvDepthwiseLineInt8)(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width,
                                  size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step);
+    void(*MNNFloat2Int8)(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue,
+                       ssize_t maxValue, ssize_t zeroPoint);
+    void(*MNNInt8ScaleToFloat)(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint);
+
 };
 void MNNCoreInt8FunctionInit();
 CoreInt8Functions* MNNGetInt8CoreFunctions();

+ 249 - 0
source/backend/cpu/compute/SparseConvInt8TiledExecutor.cpp

@@ -0,0 +1,249 @@
+//
+//  SparseConvInt8TiledExecutor.cpp
+//  MNN
+//
+//  Created by MNN on 2021/6/09.
+//  Copyright © 2018 - 2021, Alibaba Group Holding Limited
+//
+
+#include "SparseConvInt8TiledExecutor.hpp"
+#include "core/Macro.h"
+
+#include <math.h>
+#include "backend/cpu/CPUBackend.hpp"
+#include "backend/cpu/compute/CommonOptFunction.h"
+#include "core/Concurrency.h"
+#include "core/TensorUtils.hpp"
+#include "common/MemoryFormater.h"
+#include "MNN/AutoTime.hpp"
+#ifdef MNN_USE_SSE
+extern "C" {
+void MNNInt8ToUInt8(void* ptr, int count);
+}
+#endif
+namespace MNN {
+
+bool SparseConvInt8TiledExecutor::reorderWeight(Backend* b, const Convolution2DCommon* common,
+                          const std::shared_ptr<Tensor>& weightOrigin,
+                          std::shared_ptr<Tensor>& weight, const SparseCommon* sparseCommon) {
+
+    int eP, lP, hP;
+    auto core = static_cast<CPUBackend*>(b)->int8Functions();
+    core->MNNGetSparseQuantMatMulPackMode(&eP, &lP, &hP);
+
+    int oc = common->outputCount(), ic = common->inputCount(), kernelCount = common->kernelX() * common->kernelY();
+    auto sparseBlockOC = sparseCommon->args()->LookupByKey("sparseBlockOC")->i();
+    size_t weightNNZElement = sparseCommon->args()->LookupByKey("NNZElement")->i();
+    size_t weightBlockNumber = sparseCommon->args()->LookupByKey("blockNumber")->i();
+
+    // MNN_PRINT("1x%d weightNNZElement%zu, weightBlockNumber:%zu\n", sparseBlockOC, weightNNZElement, weightBlockNumber);
+    weight.reset(Tensor::createDevice<uint8_t>({ static_cast<int>(weightNNZElement + 1)}));   // one extra element in case the weights are all zero
+    mNNZMap.reset(Tensor::createDevice<unsigned int>({oc / sparseBlockOC + oc % sparseBlockOC}));
+    mDataOffsetMap.reset(Tensor::createDevice<int>({static_cast<int>(weightBlockNumber + 1)}));
+
+    mValid = backend()->onAcquireBuffer(weight.get(), Backend::STATIC);
+    mValid = mValid && backend()->onAcquireBuffer(mNNZMap.get(), Backend::STATIC);
+    mValid = mValid && backend()->onAcquireBuffer(mDataOffsetMap.get(), Backend::STATIC);
+    if(!mValid) {
+        MNN_PRINT("in: %s, out of memory!\n", __FUNCTION__);
+        return false;
+    }
+    // MNN_PRINT("oc:%d, sparseBlockOC:%d,\n", oc, sparseBlockOC);
+    core->MNNPackForSparseQuantMatMul_B(weight->host<int8_t>(), mNNZMap->host<unsigned int>(),
+                                       mDataOffsetMap->host<int>(), sparseBlockOC, weightOrigin->host<int8_t>(), oc, kernelCount, ic, eP);
+
+    // MNN_PRINT("\nBCSR int8 weight:");
+    // formatMatrix(weight->host<int8_t>(), {static_cast<int>(weightNNZElement)});
+    // MNN_PRINT("\nBCSR int8 weight nnzmap:");
+    // formatMatrix(mNNZMap->host<unsigned int>(), {oc / sparseBlockOC + oc % sparseBlockOC});
+    // MNN_PRINT("\nBCSR int8 weight dataOffsetMap:");
+    // formatMatrix(mDataOffsetMap->host<int>(), {static_cast<int>(weightBlockNumber + 1)});
+
+    return true;
+}
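+// Rough picture of the packed layout produced above (illustrative only, inferred from the
+// buffer sizes requested here rather than from MNNPackForSparseQuantMatMul_B itself):
+//   weight         : the weightNNZElement non-zero int8 values (+1 spare slot), stored in
+//                    1 x sparseBlockOC blocks along the output-channel dimension
+//   mNNZMap        : one non-zero-block counter per block row (oc / sparseBlockOC full rows
+//                    plus oc % sparseBlockOC leftover single rows)
+//   mDataOffsetMap : weightBlockNumber + 1 relative offsets used to step through the packed
+//                    input when walking the non-zero blocks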
+
+SparseConvInt8TiledExecutor::SparseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr<ResourceInt8> res) : ConvInt8TiledExecutor(backend, convOp, res) {
+
+    std::shared_ptr<Tensor> weightOrigin;
+    weightOrigin.swap(mResource->mWeightInt8);
+    const SparseCommon* sparseCommon = convOp->sparseParameter();
+    mValid = reorderWeight(backend, convOp->common(), weightOrigin, mResource->mWeightInt8, sparseCommon);
+    if(!mValid) {
+        return;
+    }
+
+    // choose int8 sparse gemm kernel
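+    // sparseBlockOC == 4 means the weights were pruned in 1x4 output-channel blocks, so the
+    // Epx4 kernel (4 output channels per stored block) applies; otherwise fall back to the
+    // per-channel Epx1 kernel.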
+    auto sparseBlockOC = sparseCommon->args()->LookupByKey("sparseBlockOC")->i();
+    auto core = static_cast<CPUBackend*>(backend)->int8Functions();
+    mSparseQuantMatMulKernel = sparseBlockOC == 4 ? core->MNNPackedSparseQuantMatMulEpx4 : core->MNNPackedSparseQuantMatMulEpx1;
+
+}
+
+SparseConvInt8TiledExecutor::SparseConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, std::shared_ptr<Tensor> weight, const SparseCommon* sparseCommon, bool fastgemm)
+    : ConvInt8TiledExecutor(backend, common, weight, fastgemm) {
+
+    auto core = static_cast<CPUBackend*>(backend)->int8Functions();
+    int UNIT, SRC_UNIT, DST_XUNIT;
+    core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
+    int oc = common->outputCount(), ic = common->inputCount(), kernel = common->kernelY() * common->kernelX();
+    mResource.reset(new ResourceInt8);
+    mResource->backend = backend;
+    mResource->mBiasInt32.reset(Tensor::createDevice<int32_t>({ROUND_UP(oc, UNIT)}));
+    mValid = backend->onAcquireBuffer(mResource->mBiasInt32.get(), Backend::STATIC);
+    if (!mValid) {
+        MNN_ERROR("Memory not enough\n");
+        return;
+    }
+    ::memset(mResource->mBiasInt32->host<int32_t>(), 0, mResource->mBiasInt32->size());
+#ifdef MNN_USE_SSE
+    for (int oz = 0; oz < oc; ++oz) {
+        int32_t offset = 0;
+        for (int i = 0; i < ic * kernel; ++i) {
+            offset += (int32_t)(weight->host<int8_t>()[oz * ic * kernel + i]) * (-128);
+        }
+        mResource->mBiasInt32->host<int32_t>()[oz] = offset;
+    }
+#endif
+
+    mValid = reorderWeight(backend, common, weight, mResource->mWeightInt8, sparseCommon);
+    if(!mValid) {
+        MNN_ERROR("Memory not enough\n");
+        return;
+    }
+    // choose int8 gemm kernel
+    auto sparseBlockOC = sparseCommon->args()->LookupByKey("sparseBlockOC")->i();
+    mSparseQuantMatMulKernel = sparseBlockOC == 4 ? core->MNNPackedSparseQuantMatMulEpx4 : core->MNNPackedSparseQuantMatMulEpx1;
+    mDoPostProcess = false;
+}
+
+SparseConvInt8TiledExecutor::SparseConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common,
+                                                         const SparseConvInt8TiledExecutor& exe)
+    : ConvInt8TiledExecutor(backend, common, exe),
+      mNNZMap(exe.mNNZMap),
+      mDataOffsetMap(exe.mDataOffsetMap),
+      mSparseBlockOC(exe.mSparseBlockOC),
+      mSparseQuantMatMulKernel(exe.mSparseQuantMatMulKernel) {
+}
+
+SparseConvInt8TiledExecutor::~SparseConvInt8TiledExecutor() {
+    // Do nothing
+}
+
+bool SparseConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution** dst) {
+    if (nullptr == dst) {
+        return true;
+    }
+    auto exe = new SparseConvInt8TiledExecutor(bn, op->main_as_Convolution2D()->common(), *this);
+    if (!exe->valid()) {
+        return false;
+    }
+    *dst = exe;
+    return true;
+}
+
+void SparseConvInt8TiledExecutor::getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) {
+    core->MNNGetSparseQuantMatMulPackMode(DestUnit, Unit, SrcUnit);
+}
+
+ErrorCode SparseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
+
+    // Timer kernelTimer;
+
+    ConvInt8TiledExecutor::onResize(inputs, outputs);
+
+    int eP, lP, hP;
+    auto core = static_cast<CPUBackend*>(backend())->int8Functions();
+    getPackParameter(&lP, &hP, &eP, core);
+    int lSize = mIm2ColParamter.icDiv4 * mIm2ColParamter.packCUnit * mCommon->kernelX() * mCommon->kernelY();
+
+    mIm2ColParamter.destICStride = mIm2ColParamter.icDiv4 * mIm2ColParamter.packCUnit * eP;
+
+    mSparseQuantParam.eP = eP;
+    mSparseQuantParam.l = lSize;
+    mSparseQuantParam.h = mCommon->outputCount();
+    mSparseQuantParam.aStride = eP * mSparseQuantParam.l;
+    mSparseQuantParam.cStride = outputs[0]->batch() * outputs[0]->height() * outputs[0]->width() * static_cast<CPUBackend*>(backend())->functions()->pack;
+
+    mTempIm2ColBuffer.reset(Tensor::createDevice<int8_t>({mThreadNums, eP, UP_DIV(lSize, lP) * lP}));
+    bool success = backend()->onAcquireBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC);
+    if (!success) {
+        return OUT_OF_MEMORY;
+    }
+    backend()->onReleaseBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC);
+
+    // MNN_PRINT("sparse conv2d int8 resize: cost time: %llu us\n", kernelTimer.durationInUs());
+    return NO_ERROR;
+}
+
+ErrorCode SparseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
+    // Timer kernelTimer;
+    const auto input = inputs[0];
+    auto output      = outputs[0];
+    auto core = static_cast<CPUBackend*>(backend())->int8Functions();
+
+    int PackUnit = static_cast<CPUBackend*>(backend())->functions()->pack;
+    auto sparseQuantIm2col = core->MNNSparseQuantIm2col;
+    const int outputPlaneLen = output->height() * output->width();
+    const int inputPlaneLen = input->width() * input->height();
+
+    const int batch = input->batch();
+    const int ocDivPack = UP_DIV(output->channel(), PackUnit);
+
+    const auto inputDataPtr = input->host<int8_t>();
+    const auto weightDataPtr = mResource->mWeightInt8->host<int8_t>();
+    const auto NNZMapPtr     = mNNZMap->host<unsigned int>();
+    const auto dataOffsetPtr = mDataOffsetMap->host<int>();
+    auto im2colPtr           = mTempIm2ColBuffer->host<int8_t>();
+    auto outputDataPtr       = output->host<int8_t>();
+    QuanPostTreatParameters quanParam;
+    quanParam.bias = mResource->mBiasInt32->host<int32_t>();
+    if (mDoPostProcess) {
+        quanParam.scale = mResource->mScaleFloat->host<float>();
+        quanParam.maxValue = mResource->mClampMax;
+        if (mResource->mRelu) {
+            quanParam.minValue = mResource->mOutputZeroPoint;
+        } else {
+            quanParam.minValue = mResource->mClampMin;
+        }
+    } else {
+        quanParam.scale = nullptr;
+    }
+    // MNN_PRINT("outputPlaneLen: %d, reduce l:%zu, minValue:%d, maxValue:%d, mTileCount:%d\n", outputPlaneLen, mSparseQuantParam.l, quanParam.minValue, quanParam.maxValue, mTileCount);
+    const int bytes = (mDoPostProcess ? 1 : 4); // int8_t or float
+    auto threadFunction = [&](int tId) {
+        auto colAddr        = im2colPtr + tId * mTempIm2ColBuffer->stride(0);
+        for (int bIndex = 0; bIndex < batch; ++bIndex) {
+            const auto srcPtr = inputDataPtr + bIndex * PackUnit * bytes * inputPlaneLen;
+            auto dstPtr       = outputDataPtr + bIndex * PackUnit * bytes * outputPlaneLen;
+
+            for (int tIndex = tId; tIndex < mTileCount; tIndex += mThreadNums) {
+                SparseQuantMatMulParam sparseQuantParam = mSparseQuantParam;
+                const int xIndexStart  = tIndex * sparseQuantParam.eP;
+                const int realDstCount = ALIMIN(outputPlaneLen - xIndexStart, sparseQuantParam.eP);
+                sparseQuantParam.eSize = realDstCount;
+                // im2col
+                sparseQuantIm2col(colAddr, srcPtr, mResource->mInputZeroPoint, &mIm2ColParamter, (size_t*)&sparseQuantParam, xIndexStart);
+                // MNN_PRINT("batch:%d, realDstCount:%d, InputZeroPoint:%d, inputdata matrix im2col:\n", bIndex, realDstCount, mResource->mInputZeroPoint);
+                // formatMatrix(colAddr, {static_cast<int>(UP_DIV(realDstCount, sparseQuantParam.eP)), static_cast<int>(sparseQuantParam.l), static_cast<int>(sparseQuantParam.eP)});
+
+#ifdef MNN_USE_SSE
+                const int col_buffer_size = sparseQuantParam.aStride * sizeof(int8_t);
+                MNNInt8ToUInt8(colAddr, col_buffer_size);
+#endif
+                auto outputInTilePtr = dstPtr + xIndexStart * PackUnit * bytes;
+                // MNN_PRINT("bIndex:%d, offset:%zu, spmm sparseMatmul tile:\n", bIndex, outputInTilePtr - outputDataPtr);
+                mSparseQuantMatMulKernel(outputInTilePtr, colAddr, weightDataPtr, (size_t*)&sparseQuantParam, &quanParam, NNZMapPtr, dataOffsetPtr);
+                // formatMatrix(outputInTilePtr, {static_cast<int>(UP_DIV(sparseQuantParam.h, PackUnit)), realDstCount, PackUnit});
+            }
+        }
+    };
+    MNN_CONCURRENCY_BEGIN(tId, mThreadNums) {
+        threadFunction((int)tId);
+    }
+    MNN_CONCURRENCY_END();
+    // MNN_PRINT("sparse conv2d int8 execute: cost time: %llu us\n", kernelTimer.durationInUs());
+    return NO_ERROR;
+}
+
+} // namespace MNN

+ 70 - 0
source/backend/cpu/compute/SparseConvInt8TiledExecutor.hpp

@@ -0,0 +1,70 @@
+//
+//  SparseConvInt8TiledExecutor.hpp
+//  MNN
+//
+//  Created by MNN on 2021/6/09.
+//  Copyright © 2018 - 2021, Alibaba Group Holding Limited
+//
+
+
+#ifndef SparseConvInt8TiledExecutor_hpp
+#define SparseConvInt8TiledExecutor_hpp
+#include "ConvInt8TiledExecutor.hpp"
+#include "backend/cpu/CPUConvolution.hpp"
+#include "ConvInt8Winograd.hpp"
+#include "Int8FunctionsOpt.h"
+
+#define SPARSITY_THRESHOLD (0.2f)
+
+namespace MNN {
+
+
+struct SparseQuantMatMulParam {
+    // members intentionally use size_t only (see note below)
+    size_t eSize;   // valid (real) length of the left matrix tile
+    size_t eP;      // left matrix pack unit
+    size_t aStride; // left matrix stride
+    size_t l;       // left matrix reduce dimension, (kh * kw * ic/4 * 4)
+    size_t h;       // right matrix column count, (oc)
+    size_t cStride; // output matrix stride on highest dim (ow * oh * C4Unit * bytes)
+};
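+// The struct is handed to the sparse matmul kernels as a plain size_t array, which is why every
+// member must stay size_t. A sketch of how SparseConvInt8TiledExecutor::onExecute uses it:
+//   SparseQuantMatMulParam p = mSparseQuantParam;
+//   p.eSize = realDstCount;  // valid columns of this tile
+//   kernel(dst, colAddr, weightPtr, (size_t*)&p, &quanParam, NNZMapPtr, dataOffsetPtr);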
+
+class SparseConvInt8TiledExecutor : public ConvInt8TiledExecutor {
+public:
+    // given weight+bias+scale, do post process
+    SparseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr<ResourceInt8> res);
+    // only given weight, not do post process
+    SparseConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, std::shared_ptr<Tensor> weight, const SparseCommon* sparseCommon, bool fastgemm);
+    virtual ~SparseConvInt8TiledExecutor();
+    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+    virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
+
+    void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) override;
+    bool reorderWeight(Backend* b, const Convolution2DCommon* common, const std::shared_ptr<Tensor>& weightOrigin,
+                       std::shared_ptr<Tensor>& weight, const SparseCommon* sparseCommon);
+
+    static bool shouldUseSparse(const Convolution2D* conv2d) {
+        auto common = conv2d->common();
+        size_t originWeightSize = common->outputCount() * common->inputCount() * common->kernelY() * common->kernelX();
+        const SparseCommon* sparseCommon = conv2d->sparseParameter();
+        // MNN_PRINT("SparseConvInt8TiledExecutor sparsity:%f\n", 1 - float(sparseCommon->args()->LookupByKey("NNZElement")->i())/originWeightSize);
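+        // The check below requires at least SPARSITY_THRESHOLD (20%) of the weights to be zero.
+        // Worked example: a 32x32 3x3 convolution has 32 * 32 * 3 * 3 = 9216 weights, so
+        // NNZElement must be <= 9216 - ceil(9216 * 0.2) = 7372 for the sparse path to be chosen.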
+        return originWeightSize - sparseCommon->args()->LookupByKey("NNZElement")->i() >= originWeightSize * SPARSITY_THRESHOLD;
+    }
+
+private:
+    SparseConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, const SparseConvInt8TiledExecutor& exe);
+    friend class ConvInt8Winograd;
+
+    SparseQuantMatMulParam mSparseQuantParam;
+    decltype(CoreInt8Functions::MNNPackedSparseQuantMatMulEpx1) mSparseQuantMatMulKernel;
+    std::shared_ptr<Tensor> mNNZMap;
+    std::shared_ptr<Tensor> mDataOffsetMap;
+    int mSparseBlockOC;
+};
+
+} // namespace MNN
+
+#undef SPARSITY_THRESHOLD
+
+#endif /* SparseConvInt8TiledExecutor_hpp */

+ 5 - 4
source/backend/cpu/compute/SparseConvolutionTiledExecutor.cpp

@@ -16,7 +16,7 @@
 #include "core/TensorUtils.hpp"
 #include "math/Vec.hpp"
 #include "core/BufferAllocator.hpp"
-#include "core/MemoryFormater.h"
+#include "common/MemoryFormater.h"
 
 using Vec4 = MNN::Math::Vec<float, 4>;
 namespace MNN {
@@ -88,7 +88,8 @@ SparseConvolutionTiledExecutor::SparseConvolutionTiledExecutor(std::shared_ptr<C
     mProxy.reset(new SparseConvolutionTiledImpl(common, sparseCommon, b));
 }
 SparseConvolutionTiledExecutor::~SparseConvolutionTiledExecutor() {
-    // Do nothing
+    backend()->onReleaseBuffer(mNNZMap.get(), Backend::STATIC);
+    backend()->onReleaseBuffer(mDataOffsetMap.get(), Backend::STATIC);
 }
 bool SparseConvolutionTiledExecutor::onClone(Backend* bn, const Op* op, Execution** dst) {
 
@@ -169,7 +170,7 @@ ErrorCode SparseConvolutionTiledImpl::onResize(const std::vector<Tensor*>& input
     TensorUtils::setLinearLayout(&mTempBufferTranspose);
     auto plane    = width * height * batch;
     int tileCount = UP_DIV(plane, eP);
-                                              
+
     bool success = backend()->onAcquireBuffer(&mTempBufferTranspose, Backend::DYNAMIC);
     if (!success) {
         return OUT_OF_MEMORY;
@@ -193,7 +194,7 @@ ErrorCode SparseConvolutionTiledImpl::onResize(const std::vector<Tensor*>& input
         auto srcPtr     = (float const **)((uint8_t *)tempPtr.first + tempPtr.second +
                                        tId * kernelSize * maxLine * (4 * sizeof(int32_t) + sizeof(float *)));
         auto el         = (int32_t *)(srcPtr + kernelSize * maxLine);
-                                        
+
         int32_t info[4];
         info[1] = src_width * src_height * batch;
         info[2] = eP;

+ 6 - 1
source/backend/cpu/compute/SparseConvolutionTiledExecutor.hpp

@@ -52,9 +52,12 @@ public:
                     float *cache, int depth, int outputCount, int kernelSize, int eP, size_t weightNNZElement,
                     size_t weightBlockNumber, const CoreFunctions *function);
 
-    static  bool shouldUseSparseConvolution(size_t originWeightSize, const SparseCommon* sparseCommon) {
+    static bool shouldUseSparseConvolution(size_t originWeightSize, const SparseCommon* sparseCommon) {
         return originWeightSize - sparseCommon->args()->LookupByKey("NNZElement")->i() >= originWeightSize * SPARSITY_THRESHOLD;
     }
+    static float getSparsityThreshold() {
+        return SPARSITY_THRESHOLD;
+    }
 protected:
     std::shared_ptr<SparseConvolutionTiledImpl> mProxy;
     std::shared_ptr<Tensor> mNNZMap;
@@ -62,4 +65,6 @@ protected:
 };
 } // namespace MNN
 
+#undef SPARSITY_THRESHOLD
+
 #endif /* SparseConvolutionTiledExecutor_hpp */

+ 14 - 12
source/backend/cpu/compute/WinogradInt8Helper.cpp

@@ -36,12 +36,14 @@ static inline void TRANS_4x4(VecType& vec0, VecType& vec1, VecType& vec2, VecTyp
     vec2.value = _mm_castps_si128(m2);
     vec3.value = _mm_castps_si128(m3);
 #else
-    auto m0 = vtrn1q_s32(vec0.value, vec1.value), m1 = vtrn2q_s32(vec0.value, vec1.value);
-    auto m2 = vtrn1q_s32(vec2.value, vec3.value), m3 = vtrn2q_s32(vec2.value, vec3.value);
-    vec0.value = vtrn1q_s64(m0, m2);
-    vec1.value = vtrn1q_s64(m1, m3);
-    vec2.value = vtrn2q_s64(m0, m2);
-    vec3.value = vtrn2q_s64(m1, m3);
+    auto m0 = vtrn1q_s32(reinterpret_cast<int32x4_t>(vec0.value), reinterpret_cast<int32x4_t>(vec1.value));
+    auto m1 = vtrn2q_s32(reinterpret_cast<int32x4_t>(vec0.value), reinterpret_cast<int32x4_t>(vec1.value));
+    auto m2 = vtrn1q_s32(reinterpret_cast<int32x4_t>(vec2.value), reinterpret_cast<int32x4_t>(vec3.value));
+    auto m3 = vtrn2q_s32(reinterpret_cast<int32x4_t>(vec2.value), reinterpret_cast<int32x4_t>(vec3.value));
+    vec0.value = reinterpret_cast<int8x16_t>(vtrn1q_s64(reinterpret_cast<int64x2_t>(m0), reinterpret_cast<int64x2_t>(m2)));
+    vec1.value = reinterpret_cast<int8x16_t>(vtrn1q_s64(reinterpret_cast<int64x2_t>(m1), reinterpret_cast<int64x2_t>(m3)));
+    vec2.value = reinterpret_cast<int8x16_t>(vtrn2q_s64(reinterpret_cast<int64x2_t>(m0), reinterpret_cast<int64x2_t>(m2)));
+    vec3.value = reinterpret_cast<int8x16_t>(vtrn2q_s64(reinterpret_cast<int64x2_t>(m1), reinterpret_cast<int64x2_t>(m3)));
 #endif
 }
 #endif
@@ -91,7 +93,7 @@ static void _sourceTransUnit4x4Pack4x4(const int8_t* srcStart, int8_t* dstStart,
         VecType1::save(dstStart + 1 * dstXStep, s1 + s2);
         VecType1::save(dstStart + 2 * dstXStep, s2 - s1);
         VecType1::save(dstStart + 3 * dstXStep, s3 - s1);
-        
+
         srcStart += srcZStep;
         dstStart += dstZStep;
     }
@@ -194,7 +196,7 @@ static void _sourceTransUnit4x4Pack16x16(const int8_t* srcStart, int8_t* dstStar
         VecType::save(dstStart + 1 * dstXStep, s1 + s2);
         VecType::save(dstStart + 2 * dstXStep, s2 - s1);
         VecType::save(dstStart + 3 * dstXStep, s3 - s1);
-        
+
         srcStart += srcZStep;
         dstStart += dstZStep;
     }
@@ -227,7 +229,7 @@ static void _destTransformUnit4x2(const float* srcStart, float* dstStart, size_t
         VecType::mla(m1, x1 - x2, c0);
         VecType::save(dstStart + dstXStep * 0, m0);
         VecType::save(dstStart + dstXStep * 1, m1);
-        
+
         srcStart += srcZStep;
         dstStart += dstZStep;
     }
@@ -247,7 +249,7 @@ static void _destTransformUnit4x3(const float* srcStart, float* dstStart, size_t
         VecType::save(dstStart + dstXStep * 0, m0);
         VecType::save(dstStart + dstXStep * 1, m1);
         VecType::save(dstStart + dstXStep * 2, m2);
-        
+
         srcStart += srcZStep;
         dstStart += dstZStep;
     }
@@ -354,7 +356,7 @@ bool WinogradInt8Helper::transformWeight(const Tensor* weightSrc, Tensor* weight
         dataDstOrigin = weightDst->host<int8_t>();
         memset(dataDstOrigin, 0, weightDst->size());
     }
-    
+
     bool overflow = false;
     for (int oz = 0; oz < oc; ++oz) {
         int oz4 = oz / UNIT, ozRemain = oz % UNIT;
@@ -412,7 +414,7 @@ bool WinogradInt8Helper::featureOverflow(const Tensor* input, int alphaY, int al
     } else if (alphaY == alphaX) {
         iter = limit2D.find(alphaY);
     }
-    
+
     bool overflow = (quantAttr->min < iter->second.first || quantAttr->max > iter->second.second);
     return overflow;
 }

+ 413 - 1
source/backend/cpu/compute/WinogradOptFunction.cpp

@@ -9,10 +9,12 @@
 #include "backend/cpu/compute/WinogradOptFunction.hpp"
 #include <cstring>
 #include <memory>
+#include <map>
 #include "core/Macro.h"
 #include "math/Vec.hpp"
-using Vec4 = MNN::Math::Vec<float, 4>;
+#include "common/MemoryFormater.h"
 
+using Vec4 = MNN::Math::Vec<float, 4>;
 #define DEFAULT_UNIT 8
 extern "C" {
 void MNNWinogradMatrixProductLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k,
@@ -74,6 +76,8 @@ void MNNWinogradMatrixProductRight(const float* S, const float* B, float* M, siz
 #endif
 
 namespace MNN {
+
+
 void WinogradFunction::productLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k,
                                    size_t length) {
     MNNWinogradMatrixProductLeft(S, B, M, w, h, k, length);
@@ -87,6 +91,115 @@ int WinogradFunction::getPreferNumber() {
     return DEFAULT_UNIT;
 }
 
+static void _sourceTransformUnit4x4Pack12(float* srcBlock, float* dstStart, size_t dstStep) {
+
+    // register number: (srcUnit + 1) * EPack/4 = 15
+    constexpr int Nh = 4; // srcUnit
+    constexpr int ePack = 12;
+    constexpr size_t packCUnit = 4;
+    const size_t loadTransposeStride = packCUnit * ePack;
+    float* srcPtr = srcBlock;
+    for (int iNh = 0; iNh < Nh; ++iNh)
+    {
+        // transpose 12x4 to 4x12
+        // register number : EPack
+        Vec4 s0 = Vec4::load(srcPtr + 0 * packCUnit);
+        Vec4 s3 = Vec4::load(srcPtr + 1 * packCUnit);
+        Vec4 s6 = Vec4::load(srcPtr + 2 * packCUnit);
+        Vec4 s9 = Vec4::load(srcPtr + 3 * packCUnit);
+        Vec4 s1 = Vec4::load(srcPtr + 4 * packCUnit);
+        Vec4 s4 = Vec4::load(srcPtr + 5 * packCUnit);
+        Vec4 s7 = Vec4::load(srcPtr + 6 * packCUnit);
+        Vec4 s10 = Vec4::load(srcPtr + 7 * packCUnit);
+        Vec4 s2 = Vec4::load(srcPtr + 8 * packCUnit);
+        Vec4 s5 = Vec4::load(srcPtr + 9 * packCUnit);
+        Vec4 s8 = Vec4::load(srcPtr + 10 * packCUnit);
+        Vec4 s11 = Vec4::load(srcPtr + 11 * packCUnit);
+        Vec4::transpose4(s0, s3, s6, s9);
+        Vec4::transpose4(s1, s4, s7, s10);
+        Vec4::transpose4(s2, s5, s8, s11);
+
+        // to-optimize: interleave load and save in loop
+        // deal with pack when packCUnit is 8
+        Vec4::save(srcPtr + 0 * packCUnit, s0);
+        Vec4::save(srcPtr + 1 * packCUnit, s1);
+        Vec4::save(srcPtr + 2 * packCUnit, s2);
+        Vec4::save(srcPtr + 3 * packCUnit, s3);
+        Vec4::save(srcPtr + 4 * packCUnit, s4);
+        Vec4::save(srcPtr + 5 * packCUnit, s5);
+        Vec4::save(srcPtr + 6 * packCUnit, s6);
+        Vec4::save(srcPtr + 7 * packCUnit, s7);
+        Vec4::save(srcPtr + 8 * packCUnit, s8);
+        Vec4::save(srcPtr + 9 * packCUnit, s9);
+        Vec4::save(srcPtr + 10 * packCUnit, s10);
+        Vec4::save(srcPtr + 11 * packCUnit, s11);
+        srcPtr += loadTransposeStride;
+    }
+
+    // MNN_PRINT("winograd in BT*D*B, transpose, loadTransposeStride:%zu, dstStep:%zu\n", loadTransposeStride, dstStep);
+    // formatMatrix((const float*)srcBlock, {Nh, static_cast<int>(packCUnit), ePack});
+
+    srcPtr = srcBlock;
+    float* dstPtr = dstStart;
+    for (int i4c = 0; i4c < packCUnit; ++i4c)
+    {
+        // source transform D * B. register number : srcUnit * (EPack/4 + 1)
+        Vec4 s00 = Vec4::load(srcPtr + 0 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s01 = Vec4::load(srcPtr + 0 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s02 = Vec4::load(srcPtr + 0 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s10 = Vec4::load(srcPtr + 1 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s11 = Vec4::load(srcPtr + 1 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s12 = Vec4::load(srcPtr + 1 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s20 = Vec4::load(srcPtr + 2 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s21 = Vec4::load(srcPtr + 2 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s22 = Vec4::load(srcPtr + 2 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s30 = Vec4::load(srcPtr + 3 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s31 = Vec4::load(srcPtr + 3 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s32 = Vec4::load(srcPtr + 3 * loadTransposeStride + 2 * packCUnit);
+
+        // dstStep =  ePack * pack * ic_4
+        auto ep0 = s00 - s20;
+        auto ep1 = s01 - s21;
+        auto ep2 = s02 - s22;
+        Vec4::save(dstPtr + 0 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 0 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 0 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 + s20;
+        ep1 = s11 + s21;
+        ep2 = s12 + s22;
+        Vec4::save(dstPtr + 1 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 1 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 1 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s20 - s10;
+        ep1 = s21 - s11;
+        ep2 = s22 - s12;
+        Vec4::save(dstPtr + 2 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 2 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 2 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s30 - s10;
+        ep1 = s31 - s11;
+        ep2 = s32 - s12;
+        Vec4::save(dstPtr + 3 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 3 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 3 * dstStep + 2 * packCUnit, ep2);
+
+        // MNN_PRINT("\nwinograd in BT*D*B, iNh:0-3, i4c:%d\n", i4c);
+        // formatMatrix(dstPtr + 0 * dstStep , {ePack});
+        // formatMatrix(dstPtr + 1 * dstStep , {ePack});
+        // formatMatrix(dstPtr + 2 * dstStep , {ePack});
+        // formatMatrix(dstPtr + 3 * dstStep , {ePack});
+
+        srcPtr += ePack;
+        dstPtr += ePack;
+    }
+}
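+// For reference, the second stage above is the standard F(2x2,3x3) input transform B^T applied
+// per column after the transpose stage:
+//   [ 1  0 -1  0 ]   -> d0 - d2
+//   [ 0  1  1  0 ]   -> d1 + d2
+//   [ 0 -1  1  0 ]   -> d2 - d1
+//   [ 0 -1  0  1 ]   -> d3 - d1
+// evaluated once per ic_4 slice, with the ePack tile kept contiguous in the innermost dimension.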
+
 static void _sourceTransformUnit4x4(const float* srcBlock, float* dstStart, size_t srcStep, size_t dstStep) {
     Vec4 s0 = Vec4::load(srcBlock + 0 * srcStep);
     Vec4 s1 = Vec4::load(srcBlock + 1 * srcStep);
@@ -103,6 +216,7 @@ static void _sourceTransformUnit4x4(const float* srcBlock, float* dstStart, size
     Vec4::save(dstStart + 2 * dstStep, m2);
     Vec4::save(dstStart + 3 * dstStep, m3);
 }
+
 static void _destTransformUnit4x2(const float* srcBlock, float* dstStart, size_t srcStep, size_t dstStep) {
     Vec4 s0 = Vec4::load(srcBlock + 0 * srcStep);
     Vec4 s1 = Vec4::load(srcBlock + 1 * srcStep);
@@ -163,6 +277,163 @@ static void _sourceTransformUnit8x8(const float* srcBlock, float* dstStart, size
     Vec4::save(dstStart + 5 * dstStep, m5);
     Vec4::save(dstStart + 6 * dstStep, m6);
     Vec4::save(dstStart + 7 * dstStep, m7);
+
+    // LOAD8;
+    // Vec4::save(dstStart + 0 * dstStep, s0);
+    // Vec4::save(dstStart + 1 * dstStep, s1);
+    // Vec4::save(dstStart + 2 * dstStep, s2);
+    // Vec4::save(dstStart + 3 * dstStep, s3);
+    // Vec4::save(dstStart + 4 * dstStep, s4);
+    // Vec4::save(dstStart + 5 * dstStep, s5);
+    // Vec4::save(dstStart + 6 * dstStep, s6);
+    // Vec4::save(dstStart + 7 * dstStep, s7);
+}
+
+static void _sourceTransformUnit8x8Pack12(float* srcBlock, float* dstStart, size_t dstStep) {
+
+    // source transform D * B. register number: (srcUnit + 1) * EPack/4 = 27
+    constexpr int Nh = 8; // srcUnit
+    constexpr int ePack = 12;
+    constexpr size_t packCUnit = 4;
+    const size_t loadTransposeStride = packCUnit * ePack;
+    float* srcPtr = srcBlock;
+    for (int iNh = 0; iNh < Nh; ++iNh)
+    {
+        // transpose 12x4 to 4x12
+        // register number : EPack
+        Vec4 s0 = Vec4::load(srcPtr + 0 * packCUnit);
+        Vec4 s3 = Vec4::load(srcPtr + 1 * packCUnit);
+        Vec4 s6 = Vec4::load(srcPtr + 2 * packCUnit);
+        Vec4 s9 = Vec4::load(srcPtr + 3 * packCUnit);
+        Vec4 s1 = Vec4::load(srcPtr + 4 * packCUnit);
+        Vec4 s4 = Vec4::load(srcPtr + 5 * packCUnit);
+        Vec4 s7 = Vec4::load(srcPtr + 6 * packCUnit);
+        Vec4 s10 = Vec4::load(srcPtr + 7 * packCUnit);
+        Vec4 s2 = Vec4::load(srcPtr + 8 * packCUnit);
+        Vec4 s5 = Vec4::load(srcPtr + 9 * packCUnit);
+        Vec4 s8 = Vec4::load(srcPtr + 10 * packCUnit);
+        Vec4 s11 = Vec4::load(srcPtr + 11 * packCUnit);
+        Vec4::transpose4(s0, s3, s6, s9);
+        Vec4::transpose4(s1, s4, s7, s10);
+        Vec4::transpose4(s2, s5, s8, s11);
+
+        // to-optimize: interleave load and save in loop
+        // deal with pack when packCUnit is 8
+        Vec4::save(srcPtr + 0 * packCUnit, s0);
+        Vec4::save(srcPtr + 1 * packCUnit, s1);
+        Vec4::save(srcPtr + 2 * packCUnit, s2);
+        Vec4::save(srcPtr + 3 * packCUnit, s3);
+        Vec4::save(srcPtr + 4 * packCUnit, s4);
+        Vec4::save(srcPtr + 5 * packCUnit, s5);
+        Vec4::save(srcPtr + 6 * packCUnit, s6);
+        Vec4::save(srcPtr + 7 * packCUnit, s7);
+        Vec4::save(srcPtr + 8 * packCUnit, s8);
+        Vec4::save(srcPtr + 9 * packCUnit, s9);
+        Vec4::save(srcPtr + 10 * packCUnit, s10);
+        Vec4::save(srcPtr + 11 * packCUnit, s11);
+        srcPtr += loadTransposeStride;
+    }
+
+    // MNN_PRINT("winograd in BT*D*B, transpose, loadTransposeStride:%zu, dstStep:%zu\n", loadTransposeStride, dstStep);
+    // formatMatrix((const float*)srcBlock, {Nh, static_cast<int>(packCUnit), ePack});
+
+    srcPtr = srcBlock;
+    float* dstPtr = dstStart;
+    for (int i4c = 0; i4c < packCUnit; ++i4c)
+    {
+        Vec4 s00 = Vec4::load(srcPtr + 0 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s01 = Vec4::load(srcPtr + 0 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s02 = Vec4::load(srcPtr + 0 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s10 = Vec4::load(srcPtr + 1 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s11 = Vec4::load(srcPtr + 1 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s12 = Vec4::load(srcPtr + 1 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s20 = Vec4::load(srcPtr + 2 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s21 = Vec4::load(srcPtr + 2 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s22 = Vec4::load(srcPtr + 2 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s30 = Vec4::load(srcPtr + 3 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s31 = Vec4::load(srcPtr + 3 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s32 = Vec4::load(srcPtr + 3 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s40 = Vec4::load(srcPtr + 4 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s41 = Vec4::load(srcPtr + 4 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s42 = Vec4::load(srcPtr + 4 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s50 = Vec4::load(srcPtr + 5 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s51 = Vec4::load(srcPtr + 5 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s52 = Vec4::load(srcPtr + 5 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s60 = Vec4::load(srcPtr + 6 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s61 = Vec4::load(srcPtr + 6 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s62 = Vec4::load(srcPtr + 6 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s70 = Vec4::load(srcPtr + 7 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s71 = Vec4::load(srcPtr + 7 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s72 = Vec4::load(srcPtr + 7 * loadTransposeStride + 2 * packCUnit);
+
+
+        // to-try: reorder the complicated 8x8 compute
+        auto ep0 = s00 * 36.f - s20 * 49.f + s40 * 14.f - s60;
+        auto ep1 = s01 * 36.f - s21 * 49.f + s41 * 14.f - s61;
+        auto ep2 = s02 * 36.f - s22 * 49.f + s42 * 14.f - s62;
+        Vec4::save(dstPtr + 0 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 0 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 0 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = (s10 + s20) * 36.f - (s30 + s40) * 13.f + (s50 + s60);
+        ep1 = (s11 + s21) * 36.f - (s31 + s41) * 13.f + (s51 + s61);
+        ep2 = (s12 + s22) * 36.f - (s32 + s42) * 13.f + (s52 + s62);
+        Vec4::save(dstPtr + 1 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 1 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 1 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = (s20 - s10) * 36.f + (s30 - s40) * 13.f + (s60 - s50);
+        ep1 = (s21 - s11) * 36.f + (s31 - s41) * 13.f + (s61 - s51);
+        ep2 = (s22 - s12) * 36.f + (s32 - s42) * 13.f + (s62 - s52);
+        Vec4::save(dstPtr + 2 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 2 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 2 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * 18.f + s20 * 9.f - s30 * 20.f - s40 * 10.f + s50 * 2.f + s60;
+        ep1 = s11 * 18.f + s21 * 9.f - s31 * 20.f - s41 * 10.f + s51 * 2.f + s61;
+        ep2 = s12 * 18.f + s22 * 9.f - s32 * 20.f - s42 * 10.f + s52 * 2.f + s62;
+        Vec4::save(dstPtr + 3 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 3 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 3 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s20 * 9.f - s10 * 18.f + s30 * 20.f - s40 * 10.f - s50 * 2.f + s60;
+        ep1 = s21 * 9.f - s11 * 18.f + s31 * 20.f - s41 * 10.f - s51 * 2.f + s61;
+        ep2 = s22 * 9.f - s12 * 18.f + s32 * 20.f - s42 * 10.f - s52 * 2.f + s62;
+        Vec4::save(dstPtr + 4 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 4 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 4 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * 12.f + s20 * 4.f - s30 * 15.f - s40 * 5.f + s50 * 3.f + s60;
+        ep1 = s11 * 12.f + s21 * 4.f - s31 * 15.f - s41 * 5.f + s51 * 3.f + s61;
+        ep2 = s12 * 12.f + s22 * 4.f - s32 * 15.f - s42 * 5.f + s52 * 3.f + s62;
+        Vec4::save(dstPtr + 5 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 5 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 5 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s20 * 4.f - s10 * 12.f + s30 * 15.f - s40 * 5.f - s50 * 3.f + s60;
+        ep1 = s21 * 4.f - s11 * 12.f + s31 * 15.f - s41 * 5.f - s51 * 3.f + s61;
+        ep2 = s22 * 4.f - s12 * 12.f + s32 * 15.f - s42 * 5.f - s52 * 3.f + s62;
+        Vec4::save(dstPtr + 6 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 6 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 6 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s30 * 49.f - s10 * 36.f - s50 * 14.f + s70;
+        ep1 = s31 * 49.f - s11 * 36.f - s51 * 14.f + s71;
+        ep2 = s32 * 49.f - s12 * 36.f - s52 * 14.f + s72;
+        Vec4::save(dstPtr + 7 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 7 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 7 * dstStep + 2 * packCUnit, ep2);
+        srcPtr += ePack;
+        dstPtr += ePack;
+    }
 }
 
 static void _destTransformUnit8x2(const float* srcBlock, float* dstStart, size_t srcStep, size_t dstStep) {
@@ -277,6 +548,128 @@ static WinogradFunction::TransformFunc gProcUnit8[] = {
 };
 
 
+static void _sourceTransformUnit6x6Pack12(float* srcBlock, float* dstStart, size_t dstStep) {
+
+    // source transform D * B. register number : (srcUnit + 1) * EPack/4 = 21
+    constexpr int Nh = 6; // srcUnit
+    constexpr int ePack = 12;
+    constexpr size_t packCUnit = 4;
+    const size_t loadTransposeStride = packCUnit * ePack;
+    float* srcPtr = srcBlock;
+    for (int iNh = 0; iNh < Nh; ++iNh)
+    {
+        // transpose 12x4 to 4x12
+        // register number : EPack
+        Vec4 s0 = Vec4::load(srcPtr + 0 * packCUnit);
+        Vec4 s3 = Vec4::load(srcPtr + 1 * packCUnit);
+        Vec4 s6 = Vec4::load(srcPtr + 2 * packCUnit);
+        Vec4 s9 = Vec4::load(srcPtr + 3 * packCUnit);
+        Vec4 s1 = Vec4::load(srcPtr + 4 * packCUnit);
+        Vec4 s4 = Vec4::load(srcPtr + 5 * packCUnit);
+        Vec4 s7 = Vec4::load(srcPtr + 6 * packCUnit);
+        Vec4 s10 = Vec4::load(srcPtr + 7 * packCUnit);
+        Vec4 s2 = Vec4::load(srcPtr + 8 * packCUnit);
+        Vec4 s5 = Vec4::load(srcPtr + 9 * packCUnit);
+        Vec4 s8 = Vec4::load(srcPtr + 10 * packCUnit);
+        Vec4 s11 = Vec4::load(srcPtr + 11 * packCUnit);
+        Vec4::transpose4(s0, s3, s6, s9);
+        Vec4::transpose4(s1, s4, s7, s10);
+        Vec4::transpose4(s2, s5, s8, s11);
+
+        // to-optimize: interleave load and save in loop
+        // deal with pack when packCUnit is 8
+        Vec4::save(srcPtr + 0 * packCUnit, s0);
+        Vec4::save(srcPtr + 1 * packCUnit, s1);
+        Vec4::save(srcPtr + 2 * packCUnit, s2);
+        Vec4::save(srcPtr + 3 * packCUnit, s3);
+        Vec4::save(srcPtr + 4 * packCUnit, s4);
+        Vec4::save(srcPtr + 5 * packCUnit, s5);
+        Vec4::save(srcPtr + 6 * packCUnit, s6);
+        Vec4::save(srcPtr + 7 * packCUnit, s7);
+        Vec4::save(srcPtr + 8 * packCUnit, s8);
+        Vec4::save(srcPtr + 9 * packCUnit, s9);
+        Vec4::save(srcPtr + 10 * packCUnit, s10);
+        Vec4::save(srcPtr + 11 * packCUnit, s11);
+        srcPtr += loadTransposeStride;
+    }
+
+    srcPtr = srcBlock;
+    float* dstPtr = dstStart;
+    for (int i4c = 0; i4c < packCUnit; ++i4c)
+    {
+
+        Vec4 s00 = Vec4::load(srcPtr + 0 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s01 = Vec4::load(srcPtr + 0 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s02 = Vec4::load(srcPtr + 0 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s10 = Vec4::load(srcPtr + 1 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s11 = Vec4::load(srcPtr + 1 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s12 = Vec4::load(srcPtr + 1 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s20 = Vec4::load(srcPtr + 2 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s21 = Vec4::load(srcPtr + 2 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s22 = Vec4::load(srcPtr + 2 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s30 = Vec4::load(srcPtr + 3 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s31 = Vec4::load(srcPtr + 3 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s32 = Vec4::load(srcPtr + 3 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s40 = Vec4::load(srcPtr + 4 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s41 = Vec4::load(srcPtr + 4 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s42 = Vec4::load(srcPtr + 4 * loadTransposeStride + 2 * packCUnit);
+
+        Vec4 s50 = Vec4::load(srcPtr + 5 * loadTransposeStride + 0 * packCUnit);
+        Vec4 s51 = Vec4::load(srcPtr + 5 * loadTransposeStride + 1 * packCUnit);
+        Vec4 s52 = Vec4::load(srcPtr + 5 * loadTransposeStride + 2 * packCUnit);
+
+        // to-try: reorder
+        auto ep0 = s00 * 4.f - s20 * 5.f + s40;
+        auto ep1 = s01 * 4.f - s21 * 5.f + s41;
+        auto ep2 = s02 * 4.f - s22 * 5.f + s42;
+        Vec4::save(dstPtr + 0 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 0 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 0 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = (s10 + s20) * (-4.f) + s30 + s40;
+        ep1 = (s11 + s21) * (-4.f) + s31 + s41;
+        ep2 = (s12 + s22) * (-4.f) + s32 + s42;
+        Vec4::save(dstPtr + 1 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 1 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 1 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = (s10 - s20) * (4.f) + s40 - s30;
+        ep1 = (s11 - s21) * (4.f) + s41 - s31;
+        ep2 = (s12 - s22) * (4.f) + s42 - s32;
+        Vec4::save(dstPtr + 2 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 2 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 2 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * (-2.f) - s20 + s30 * 2.f + s40;
+        ep1 = s11 * (-2.f) - s21 + s31 * 2.f + s41;
+        ep2 = s12 * (-2.f) - s22 + s32 * 2.f + s42;
+        Vec4::save(dstPtr + 3 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 3 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 3 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * 2.f - s20 - s30 * 2.f + s40;
+        ep1 = s11 * 2.f - s21 - s31 * 2.f + s41;
+        ep2 = s12 * 2.f - s22 - s32 * 2.f + s42;
+        Vec4::save(dstPtr + 4 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 4 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 4 * dstStep + 2 * packCUnit, ep2);
+
+        ep0 = s10 * 4.f - s30 * 5.f + s50;
+        ep1 = s11 * 4.f - s31 * 5.f + s51;
+        ep2 = s12 * 4.f - s32 * 5.f + s52;
+        Vec4::save(dstPtr + 5 * dstStep + 0 * packCUnit, ep0);
+        Vec4::save(dstPtr + 5 * dstStep + 1 * packCUnit, ep1);
+        Vec4::save(dstPtr + 5 * dstStep + 2 * packCUnit, ep2);
+
+        srcPtr += ePack;
+        dstPtr += ePack;
+    }
+}
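+// For reference, the arithmetic above is the F(4x4,3x3) input transform B^T (interpolation
+// points 0, +-1, +-2), applied per column after the transpose stage:
+//   [ 4  0 -5  0  1  0 ]
+//   [ 0 -4 -4  1  1  0 ]
+//   [ 0  4 -4 -1  1  0 ]
+//   [ 0 -2 -1  2  1  0 ]
+//   [ 0  2 -1 -2  1  0 ]
+//   [ 0  4  0 -5  0  1 ]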
+
 #define LOAD6                                     \
 Vec4 s0 = Vec4::load(srcBlock + 0 * srcStep); \
 Vec4 s1 = Vec4::load(srcBlock + 1 * srcStep); \
@@ -305,6 +698,7 @@ static void _sourceTransformUnit6x6(const float* srcBlock, float* dstStart, size
     Vec4::save(dstStart + 5 * dstStep, m5);
 }
 
+
 static void _destTransformUnit6x5(const float* srcBlock, float* dstStart, size_t srcStep, size_t dstStep) {
     Vec4 s0 = Vec4::load(srcBlock + 0 * srcStep);
     Vec4 s1 = Vec4::load(srcBlock + 1 * srcStep);
@@ -402,6 +796,24 @@ WinogradFunction::TransformFunc WinogradFunction::chooseSourceTransform(int k, i
     return nullptr;
 }
 
+WinogradFunction::TransformPackFunc WinogradFunction::chooseWinoSourceTransformPack(int k, int w, int ePack, int lPack, int packCUnit) {
+    if (ePack == 12 && lPack == 1 && packCUnit == 4) {
+        if (k == 4 && w == 4) {
+            return _sourceTransformUnit4x4Pack12;
+        }
+        if (k == 6 && w == 6) {
+            return _sourceTransformUnit6x6Pack12;
+        }
+        if (k == 8 && w == 8) {
+            return _sourceTransformUnit8x8Pack12;
+        }
+        // other packing size
+    }
+    MNN_ERROR("Cannot find packed source transform for k:%d, w:%d, ePack:%d, lPack:%d, packCUnit:%d\n", k, w, ePack, lPack, packCUnit);
+    MNN_ASSERT(false);
+    return nullptr;
+}
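+// A minimal sketch of how a caller is expected to pick the fused transform (assuming the usual
+// CoreFunctions accessors; the exact wiring inside ConvolutionWinograd may differ):
+//   int eP, lP, hP;
+//   core->MNNGetMatMulPackMode(&eP, &lP, &hP);
+//   auto packFunc = WinogradFunction::chooseWinoSourceTransformPack(alpha, alpha, eP, lP, core->pack);
+//   // currently non-null only for ePack == 12, lPack == 1, packCUnit == 4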
+
 WinogradFunction::TransformFunc WinogradFunction::chooseDestTransform(int k, int h) {
     if (8 == k) {
         if (h <= 1 || h > 7) {

+ 2 - 0
source/backend/cpu/compute/WinogradOptFunction.hpp

@@ -21,10 +21,12 @@ public:
     static int getPreferNumber();
 
     typedef void (*TransformFunc)(const float* srcBlock, float* dstStart, size_t srcStep, size_t dstStep);
+    typedef void (*TransformPackFunc)(float* srcBlock, float* dstStart, size_t dstStep);
 
     /*Use the generator with interp 0.5*/
     static TransformFunc chooseSourceTransform(int k, int w);
     static TransformFunc chooseDestTransform(int k, int h);
+    static TransformPackFunc chooseWinoSourceTransformPack(int k, int h, int ePack, int lPack, int packCUnit);
 };
 } // namespace MNN
 

+ 196 - 51
source/backend/cpu/x86_x64/AVX2Backend.cpp

@@ -23,6 +23,10 @@
 #include "backend/cpu/CPUTensorConvert.hpp"
 #include "core/OpCommonUtils.hpp"
 #include "backend/cpu/CPUCast.hpp"
+extern "C" {
+void MNNInt8ToUInt8(void* ptr, int count);
+void MNNUInt8ToInt8(void* ptr, int count);
+}
 
 namespace MNN {
 bool AVX2Backend::isValid() {
@@ -31,16 +35,173 @@ bool AVX2Backend::isValid() {
 
 AVX2Backend::AVX2Backend(const CPURuntime* runtime, size_t flags) : CPUBackend(runtime, BackendConfig::Precision_Low, MNN_FORWARD_CPU_EXTENSION, flags) {
     mCoreFunctions = AVX2Functions::get();
+    mInt8CoreFunctions = AVX2Functions::getInt8();
 }
 
 AVX2Backend::~AVX2Backend() {
     // nothing to do
 }
-
 // TODO: Move to functions
-static void _CopyC4ToC8(void* dstPtr, const void* srcPtr, int channelC4, int area) {
-    float* dst = static_cast<float*>(dstPtr);
-    const float* src = static_cast<const float*>(srcPtr);
+
+static void _CopyC16ToC4_int8(float* dstO, const float* srcO, int channelC4, int area) {
+    auto dst = (int32_t*)dstO;
+    auto src = (int32_t*)srcO;
+    int c8 = channelC4 / 4;
+    int cR = channelC4 % 4;
+    for (int z=0; z<c8; ++z) {
+        auto s0 = dst + 4 * z * area;
+        auto s1 = dst + (4 * z + 1) * area;
+        auto s2 = dst + (4 * z + 2) * area;
+        auto s3 = dst + (4 * z + 3) * area;
+        auto d = src + z * area * 4;
+        for (int x=0; x<area; ++x) {
+            *s0 = d[0];
+            *s1 = d[1];
+            *s2 = d[2];
+            *s3 = d[3];
+            s0++;
+            s1++;
+            s2++;
+            s3++;
+            d+=4;
+        }
+    }
+    if (cR > 0) {
+        auto s0 = dst + 4 * c8 * area;
+        auto d = src + c8 * area * 4;
+        for (int x=0; x<area; ++x) {
+            for (int v=0; v<cR; ++v) {
+                s0[v * area] = d[v];
+            }
+            s0++;
+            d+=4;
+        }
+    }
+}
+
+
+static void _CopyC4ToC16_int8(float* dstO, const float* srcO, int channelC4, int area) {
+    auto dst = (int32_t*)dstO;
+    auto src = (int32_t*)srcO;
+    int c8 = channelC4 / 4;
+    int cR = channelC4 % 4;
+    for (int z=0; z<c8; ++z) {
+        auto s0 = src + 4 * z * area;
+        auto s1 = src + (4 * z + 1) * area;
+        auto s2 = src + (4 * z + 2) * area;
+        auto s3 = src + (4 * z + 3) * area;
+        auto d = dst + z * area * 4;
+        for (int x=0; x<area; ++x) {
+            d[0] = *s0;
+            d[1] = *s1;
+            d[2] = *s2;
+            d[3] = *s3;
+            s0 ++;
+            s1 ++;
+            s2 ++;
+            s3 ++;
+            d += 4;
+        }
+    }
+    if (cR > 0) {
+        auto s0 = src + 4 * c8 * area;
+        auto d = dst + c8 * area * 4;
+        for (int x=0; x<area; ++x) {
+            for (int v=0; v<cR; ++v) {
+                d[v] = s0[v * area];
+            }
+            for (int v=cR; v<4; ++v) {
+                d[v] = 0;
+            }
+            s0 += 4;
+            d += 16;
+        }
+    }
+}
+
+static void _CopyC4ToC16(float* dst, const float* src, int channelC4, int area) {
+    int c8 = channelC4 / 4;
+    int cR = channelC4 % 4;
+    for (int z=0; z<c8; ++z) {
+        auto s0 = src + 4 * z * area * 4;
+        auto s1 = src + (4 * z + 1) * area * 4;
+        auto s2 = src + (4 * z + 2) * area * 4;
+        auto s3 = src + (4 * z + 3) * area * 4;
+        auto d = dst + z * area * 16;
+        for (int x=0; x<area; ++x) {
+            auto v0 = _mm_loadu_ps(s0);
+            auto v1 = _mm_loadu_ps(s1);
+            auto v2 = _mm_loadu_ps(s2);
+            auto v3 = _mm_loadu_ps(s3);
+            _mm_storeu_ps(d + 0, v0);
+            _mm_storeu_ps(d + 4, v1);
+            _mm_storeu_ps(d + 8, v2);
+            _mm_storeu_ps(d + 12, v3);
+            s0 += 4;
+            s1 += 4;
+            s2 += 4;
+            s3 += 4;
+            d += 16;
+        }
+    }
+    if (cR > 0) {
+        auto s0 = src + 4 * c8 * area * 4;
+        auto d = dst + c8 * area * 16;
+        auto v1 = _mm_setzero_ps();
+        for (int x=0; x<area; ++x) {
+            for (int v=0; v<cR; ++v) {
+                auto v0 = _mm_loadu_ps(s0 + v * area * 4);
+                _mm_storeu_ps(d + 4 * v, v0);
+            }
+            for (int v=cR; v<4; ++v) {
+                _mm_storeu_ps(d + 4 * v, v1);
+            }
+            s0 += 4;
+            d += 16;
+        }
+    }
+}
+
+static void _CopyC16ToC4(float* dst, const float* src, int channelC4, int area) {
+    int c8 = channelC4 / 4;
+    int cR = channelC4 % 4;
+    for (int z=0; z<c8; ++z) {
+        auto s0 = dst + 4 * z * area * 4;
+        auto s1 = dst + (4 * z + 1) * area * 4;
+        auto s2 = dst + (4 * z + 2) * area * 4;
+        auto s3 = dst + (4 * z + 3) * area * 4;
+        auto d = src + z * area * 16;
+        for (int x=0; x<area; ++x) {
+            auto v0 = _mm_loadu_ps(d);
+            auto v1 = _mm_loadu_ps(d + 4);
+            auto v2 = _mm_loadu_ps(d + 8);
+            auto v3 = _mm_loadu_ps(d + 12);
+            _mm_storeu_ps(s0, v0);
+            _mm_storeu_ps(s1, v1);
+            _mm_storeu_ps(s2, v2);
+            _mm_storeu_ps(s3, v3);
+            s0 += 4;
+            s1 += 4;
+            s2 += 4;
+            s3 += 4;
+            d+= 16;
+        }
+    }
+    if (cR > 0) {
+        auto s0 = dst + 4 * c8 * area * 4;
+        auto d = src + c8 * area * 16;
+        for (int x=0; x<area; ++x) {
+            for (int v=0; v<cR; ++v) {
+                auto v0 = _mm_loadu_ps(d + v * 4);
+                _mm_storeu_ps(s0 + 4 * v * area, v0);
+            }
+            s0 += 4;
+            d+= 16;
+        }
+    }
+}
+
+static void _CopyC4ToC8(float* dst, const float* src, int channelC4, int area) {
     int c8 = channelC4 / 2;
     int cR = channelC4 % 2;
     for (int z=0; z<c8; ++z) {
@@ -71,9 +232,7 @@ static void _CopyC4ToC8(void* dstPtr, const void* srcPtr, int channelC4, int are
     }
 }
 
-static void _CopyC8ToC4(void* dstPtr, const void* srcPtr, int channelC4, int area) {
-    float* dst = static_cast<float*>(dstPtr);
-    const float* src = static_cast<const float*>(srcPtr);
+static void _CopyC8ToC4(float* dst, const float* src, int channelC4, int area) {
     int c8 = channelC4 / 2;
     int cR = channelC4 % 2;
     for (int z=0; z<c8; ++z) {
@@ -102,9 +261,9 @@ static void _CopyC8ToC4(void* dstPtr, const void* srcPtr, int channelC4, int are
     }
 }
 
-static void _CopyC4ToC8_int8(void* dstPtr, const void* srcPtr, int channelC4, int area) {
-    int8_t* dst = static_cast<int8_t*>(dstPtr);
-    const int8_t* src = static_cast<const int8_t*>(srcPtr);
+static void _CopyC4ToC8_int8(float* dstPtr, const float* srcPtr, int channelC4, int area) {
+    int8_t* dst = (int8_t*)(dstPtr);
+    const int8_t* src = (const int8_t*)(srcPtr);
     int c8 = channelC4 / 2;
     int cR = channelC4 % 2;
     for (int z=0; z<c8; ++z) {
@@ -131,9 +290,9 @@ static void _CopyC4ToC8_int8(void* dstPtr, const void* srcPtr, int channelC4, in
     }
 }
 
-static void _CopyC8ToC4_int8(void* dstPtr, const void* srcPtr, int channelC4, int area) {
-    int8_t* dst = static_cast<int8_t*>(dstPtr);
-    const int8_t* src = static_cast<const int8_t*>(srcPtr);
+static void _CopyC8ToC4_int8(float* dstPtr, const float* srcPtr, int channelC4, int area) {
+    int8_t* dst = (int8_t*)(dstPtr);
+    const int8_t* src = (const int8_t*)(srcPtr);
     int c8 = channelC4 / 2;
     int cR = channelC4 % 2;
     for (int z=0; z<c8; ++z) {
@@ -162,30 +321,15 @@ static void _CopyC8ToC4_int8(void* dstPtr, const void* srcPtr, int channelC4, in
 Execution* AVX2Backend::onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                   const MNN::Op* op) {
     for (auto t : outputs) {
-        if (t->getType().code != halide_type_float) {
+        if (t->getType().code != halide_type_float && t->getType().bits != 8) {
             return nullptr;
         }
-    }
-    auto inputQuantInfo = OpCommonUtils::getQuantInfo(inputs);
-    auto ouputQuantInfo = OpCommonUtils::getQuantInfo(outputs);
-    halide_type_t quantType = halide_type_of<float>();
-    if (inputQuantInfo.first) {
-        if (!ouputQuantInfo.first && !outputs.empty()) {
-            quantType = outputs[0]->getType();
-        } else {
-            quantType = TensorUtils::DataTypeToHalideType(inputQuantInfo.second);
+        if (t->getType().code == halide_type_uint) {
+            return nullptr;
         }
     }
-    auto originType = outputs.empty() ? halide_type_of<float>() : outputs[0]->getType();
-    auto runType = getRunType(op, quantType, originType);
-    if (runType == halide_type_of<int8_t>()) {
-        return nullptr;
-    }
-    if (op->type() == OpType_Raster) {
-        return new CPURaster(this);
-    }
     bool originCreate = OpCommonUtils::opCompabilityForLowp(op);
-    if (originCreate || op->type() == OpType_Softmax || op->type() == OpType_Reduction) {
+    if (originCreate || op->type() == OpType_Softmax || op->type() == OpType_Reduction || op->type() == OpType_ConvInt8 || op->type() == OpType_DepthwiseConvInt8 || op->type() == OpType_FloatToInt8 || op->type() == OpType_Int8ToFloat) {
         return CPUBackend::onCreate(inputs, outputs, op);
     }
     return nullptr;
@@ -227,8 +371,10 @@ void AVX2Backend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor)
             default:
                 break;
         }
-        wrapTensor.reset(Tensor::create(srcTensor->shape(), dstTensor->getType(), nullptr, dimType));
-        auto code = CPUCastCreator::cast(srcTensor, wrapTensor.get());
+        wrapTensor.reset(Tensor::createDevice(srcTensor->shape(), dstTensor->getType(), dimType));
+        wrapTensor->buffer().host = (uint8_t*)MNNMemoryAllocAlign(getTensorSize(wrapTensor.get()) * wrapTensor->getType().bytes(), MNN_MEMORY_ALIGN_DEFAULT);
+        TensorUtils::getDescribe(wrapTensor.get())->memoryType = Tensor::InsideDescribe::MEMORY_HOST;
+        auto code = CPUCastCreator::cast(srcTensor, wrapTensor.get(), this);
         if (NO_ERROR != code) {
             MNN_ERROR("Error in CPUBackend::onCopyBuffer:cast\n");
         }
@@ -257,6 +403,20 @@ void AVX2Backend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor)
         return;
     }
     if (source == MNN_DATA_FORMAT_NC4HW4 && dest == MNN_DATA_FORMAT_NC4HW4) {
+        auto outF = _CopyC8ToC4;
+        auto inF = _CopyC4ToC8;
+        if (ob.type.bytes() == 1) {
+            outF = _CopyC8ToC4_int8;
+            inF = _CopyC4ToC8_int8;
+        }
+        if (mCoreFunctions->pack == 16) {
+            outF = _CopyC16ToC4;
+            inF = _CopyC4ToC16;
+            if (ob.type.bytes() == 1) {
+                outF = _CopyC16ToC4_int8;
+                inF = _CopyC4ToC16_int8;
+            }
+        }
         // NC4HW4 <-> NC8HW8
         if (1 == srcTensor->dimensions()) {
             ::memcpy(dstTensor->host<void>(), srcTensor->host<void>(), srcTensor->length(0) * srcTensor->getType().bytes());
@@ -266,25 +426,10 @@ void AVX2Backend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor)
         int area = std::get<1>(dims) * std::get<0>(dims);
         int channel = std::get<2>(dims);
         auto c4 = UP_DIV(channel, 4);
-        auto c8 = UP_DIV(channel, mCoreFunctions->pack);
-        auto c8toc4 = _CopyC8ToC4, c4toc8 = _CopyC4ToC8;
-        switch (ob.type.bytes()) {
-            case 1:
-                c8toc4 = _CopyC8ToC4_int8;
-                c4toc8 = _CopyC4ToC8_int8;
-                break;
-            case 4:
-                c8toc4 = _CopyC8ToC4;
-                c4toc8 = _CopyC4ToC8;
-                break;
-            default:
-                MNN_ASSERT(false);
-                break;
-        }
         if (srcType == MNN_FORWARD_CPU_EXTENSION) {
-            c8toc4(dstTensor->host<void>(), srcTensor->host<void>(), c4, area);
+            outF(dstTensor->host<float>(), srcTensor->host<float>(), c4, area);
         } else {
-            c4toc8(dstTensor->host<void>(), srcTensor->host<void>(), c4, area);
+            inF(dstTensor->host<float>(), srcTensor->host<float>(), c4, area);
         }
         return;
     }
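
As a rough scalar reference (not part of the diff; the helper name is illustrative) for the index mapping that the new _CopyC4ToC16/_CopyC16ToC4 helpers implement with SSE loads: NC4HW4 groups channels 4 per spatial point, while the AVX-512 backend's native layout groups them 16 at a time.

#include <cstring>

// dst is NC16HW16, src is NC4HW4; tail channels of the last 16-group are zero-filled here.
static void copyC4ToC16Reference(float* dst, const float* src, int channel, int area) {
    int c16 = (channel + 15) / 16;
    std::memset(dst, 0, sizeof(float) * c16 * area * 16);
    for (int c = 0; c < channel; ++c) {
        for (int x = 0; x < area; ++x) {
            dst[(c / 16) * area * 16 + x * 16 + (c % 16)] =
                src[(c / 4) * area * 4 + x * 4 + (c % 4)];
        }
    }
}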

+ 1 - 0
source/backend/cpu/x86_x64/AVX2Backend.hpp

@@ -5,6 +5,7 @@
 //  Created by MNN on 2021/05/16.
 //  Copyright © 2018, Alibaba Group Holding Limited
 //
+
 #ifndef AVX2Backend_hpp
 #define AVX2Backend_hpp
 

+ 31 - 23
source/backend/cpu/x86_x64/AVX2Functions.cpp

@@ -1,3 +1,11 @@
+//
+//  AVX2Functions.cpp
+//  MNN
+//
+//  Created by MNN on 2021/05/17.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
 #include "AVX2Functions.hpp"
 #include "AVX2Backend.hpp"
 #include "avx/FunctionSummary.hpp"
@@ -12,6 +20,7 @@ struct MatMulPackParam {
 
 static MatMulPackParam gPackInfo;
 static CoreFunctions* gAVX2CoreFunctions = nullptr;
+static CoreInt8Functions* gAVX2CoreInt8Functions = nullptr;
 static void _MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
     *eP = gPackInfo.eP;
     *lP = gPackInfo.lP;
@@ -21,42 +30,27 @@ static void _MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
 bool AVX2Functions::init(int cpuFlags) {
     gAVX2CoreFunctions = new CoreFunctions;
     auto coreFunction = gAVX2CoreFunctions;
+    gAVX2CoreInt8Functions = new CoreInt8Functions;
     // Init default functions
     *coreFunction = *MNNGetCoreFunctions();
-
+    *gAVX2CoreInt8Functions = *MNNGetInt8CoreFunctions();
+    _AVX_MNNInt8FunctionInit(gAVX2CoreInt8Functions);
     // Init AVX2
     coreFunction->MNNGetMatMulPackMode = _MNNGetMatMulPackMode;
     gPackInfo.eP                    = 24;
     gPackInfo.lP                    = 1;
     gPackInfo.hP                    = 4;
-    coreFunction->pack = 8;
-    coreFunction->MNNPackCUnit = _AVX_MNNPackCUnit;
-    coreFunction->MNNUnpackCUnit = _AVX_MNNUnpackCUnit;
-    coreFunction->MNNPackCUnitTranspose = _AVX_MNNPackCUnitTranspose;
-    coreFunction->MNNUnpackCUnitTranspose = _AVX_MNNUnpackCUnitTranspose;
-    coreFunction->MNNCopyC4WithStride = _AVX_MNNCopyC4WithStride;
-    coreFunction->MNNAddC4WithStride = _AVX_MNNAddC4WithStride;
-    coreFunction->MNNScaleAndAddBias = _AVX_MNNScaleAndAddBias;
-    coreFunction->MNNMatrixAdd          = _AVX_MNNMatrixAdd;
-    coreFunction->MNNMatrixSub          = _AVX_MNNMatrixSub;
+    _AVX_ReorderInit(coreFunction);
+
     coreFunction->MNNPackedMatMul       = _AVX_MNNPackedMatMul;
     coreFunction->MNNPackedMatMulRemain = _AVX_MNNPackedMatMulRemain;
     coreFunction->MNNPackC4ForMatMul_A  = _AVX_MNNPackC4ForMatMul_A;
     coreFunction->MNNPackForMatMul_B    = _AVX_MNNPackForMatMul_B;
-    coreFunction->MNNConvRunForUnitDepthWise = _AVX_MNNConvRunForUnitDepthWise;
-    coreFunction->MNNConvRunForLineDepthwise = _AVX_MNNConvRunForLineDepthwise;
-    coreFunction->MNNAxByClampBroadcastUnit = _AVX_MNNAxByClampBroadcastUnit;
     coreFunction->MNNComputeMatMulForE_1 = _AVX_MNNComputeMatMulForE_1;
     coreFunction->MNNComputeMatMulForH_1 = _AVX_MNNComputeMatMulForH_1;
-    coreFunction->MNNStrassenMergeCFunction = _AVX_MNNStrassenMergeCFunction;
-    coreFunction->MNNMultiAndDestTransformCommon23 = _AVX_MNNMultiAndDestTransformCommon23;
-    coreFunction->MNNSourceTransformCommonF23 = _AVX_MNNSourceTransformCommonF23;
-    coreFunction->MNNConvDwF23MulTransUnit = _AVX_MNNConvDwF23MulTransUnit;
-    coreFunction->MNNReluWithSlopeChannel = _AVX_MNNReluWithSlopeChannel;
-    coreFunction->MNNDeconvRunForLineDepthwise = _AVX_MNNDeconvRunForLineDepthwise;
-    coreFunction->MNNDeconvRunForUnitDepthWise = _AVX_MNNDeconvRunForUnitDepthWise;
-    coreFunction->MNNGridSampleInterp = _AVX_MNNGridSampleInterp;
-    // For Pooling / Binary
+
+    // For Packed Functions
+    coreFunction->pack = 8;
     _AVX_ExtraInit(coreFunction);
     // Winograd
     _AVX_WinogradInit(coreFunction);
@@ -65,6 +59,7 @@ bool AVX2Functions::init(int cpuFlags) {
         coreFunction->MNNPackedMatMulRemain = _AVX_MNNPackedMatMulRemainFMA;
         coreFunction->MNNComputeMatMulForE_1 = _AVX_MNNComputeMatMulForE_1FMA;
         coreFunction->MNNComputeMatMulForH_1 = _AVX_MNNComputeMatMulForH_1FMA;
+        _AVX_ExtraInitFMA(coreFunction);
     }
 #ifdef MNN_AVX512
     if ((cpuFlags & libyuv::kCpuHasAVX512VNNI)
@@ -75,6 +70,10 @@ bool AVX2Functions::init(int cpuFlags) {
         || (cpuFlags & libyuv::kCpuHasAVX512VPOPCNTDQ)
         || (cpuFlags & libyuv::kCpuHasAVX512VBMI2)
         ) {
+        coreFunction->pack = 16;
+        _AVX512_ReorderInit(coreFunction);
+        _AVX512_ExtraInit(coreFunction);
+        _AVX512_WinogradInit(coreFunction);
         coreFunction->MNNPackForMatMul_B    = _AVX512_MNNPackForMatMul_B;
         coreFunction->MNNPackC4ForMatMul_A  = _AVX512_MNNPackC8ForMatMul_A;
         coreFunction->MNNPackedMatMul = _AVX512_MNNPackedMatMul;
@@ -83,10 +82,19 @@ bool AVX2Functions::init(int cpuFlags) {
         gPackInfo.hP                    = 8;
         gPackInfo.lP                    = 1;
     }
+#ifdef MNN_AVX512_VNNI
+    if (cpuFlags & libyuv::kCpuHasAVX512VNNI) {
+        _AVX512_MNNInt8FunctionInit(gAVX2CoreInt8Functions);
+    }
+#endif
 #endif
     return true;
 }
 CoreFunctions* AVX2Functions::get() {
     return gAVX2CoreFunctions;
 }
+CoreInt8Functions* AVX2Functions::getInt8() {
+    return gAVX2CoreInt8Functions;
+}
+
 };
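
A small usage sketch (not from the commit; the function name is illustrative) of the two accessors once AVX2Functions::init() has run; the members used are the ones assigned above, and the cast on pack is only for printing:

#include <cstdio>
#include "backend/cpu/x86_x64/AVX2Functions.hpp"

void dumpAvx2PackInfo() {
    auto core     = MNN::AVX2Functions::get();      // float kernels; pack is 8, or 16 when AVX-512 is enabled
    auto int8Core = MNN::AVX2Functions::getInt8();  // int8 GEMM / quantization kernels
    int eP = 0, lP = 0, hP = 0;
    core->MNNGetMatMulPackMode(&eP, &lP, &hP);      // 24 / 1 / 4 on plain AVX2
    printf("pack=%d eP=%d lP=%d hP=%d int8=%p\n", (int)core->pack, eP, lP, hP, (void*)int8Core);
}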

+ 10 - 0
source/backend/cpu/x86_x64/AVX2Functions.hpp

@@ -1,3 +1,11 @@
+//
+//  AVX2Functions.hpp
+//  MNN
+//
+//  Created by MNN on 2021/05/17.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
 #ifndef AVX2Functions_hpp
 #define AVX2Functions_hpp
 #include <stdint.h>
@@ -5,6 +13,7 @@
 #include <string.h>
 #include "core/Macro.h"
 #include "backend/cpu/compute/CommonOptFunction.h"
+#include "backend/cpu/compute/Int8FunctionsOpt.h"
 #include "cpu_id.h"
 
 namespace MNN {
@@ -12,6 +21,7 @@ class AVX2Functions {
 public:
     static bool init(int flags);
     static CoreFunctions* get();
+    static CoreInt8Functions* getInt8();
 };
 };
 

+ 1 - 1
source/backend/cpu/x86_x64/CMakeLists.txt

@@ -13,7 +13,7 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64)|(X86_64)|(x64)|(X64)|(amd64)|(AMD64)
         if (MNN_AVX512)
             FILE(GLOB MNN_AVX512_SRC ${CMAKE_CURRENT_LIST_DIR}/avx512/*)
             add_library(MNNAVX512 OBJECT ${MNN_AVX512_SRC})
-            target_compile_options(MNNAVX512 PRIVATE -DMNN_USE_SSE)
+            target_compile_options(MNNAVX512 PRIVATE -DMNN_USE_SSE -DMNN_X86_USE_ASM)
             if (MNN_AVX512_VNNI)
                 target_compile_options(MNNAVX512 PRIVATE -m64 -mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma -mavx512vnni -DMNN_AVX512_VNNI)
             else()

+ 10 - 39
source/backend/cpu/x86_x64/FunctionDispatcher.cpp

@@ -23,22 +23,13 @@
 #include <x86intrin.h>
 #endif
 
-bool MNNReorder4x4ByPlatform(float* dst, size_t number) {
-    return _SSE_MNNReorder4x4ByPlatform(dst, number);
-}
-
 struct FunctionGroup {
     int tileNumber                                                                               = 8;
     int eP                                                                                       = 12;
     int lP                                                                                       = 1;
     int hP                                                                                       = 4;
-    void (*MNNExpC8)(float* dest, const float* source, const float* parameters, size_t countC8) = _SSE_MNNExpC8;
+    void (*MNNExpC8)(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8) = _SSE_MNNExpC8;
     void (*MNNSoftmax)(float* dest, const float* source, size_t size) = _SSE_MNNSoftmax;
-    void (*MNNFloat2Int8)(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue,
-                       ssize_t maxValue, ssize_t zeroPoint) = _SSE_MNNFloat2Int8;
-    void (*MNNInt8ScaleToFloat)(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint) = _SSE_MNNInt8ScaleToFloat;
-    void (*MNNComputeMatMulForE_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId) = _SSE_MNNComputeMatMulForE_1;
-    void (*MNNReluWithSlopeChannel)(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) = _SSE_MNNReluWithSlopeChannel;
     void (*MNNReluInt8)(int8_t* dst, const int8_t* src, size_t size) = _SSE_MNNReluInt8;
     void (*MNNHardSwish)(float* dst, const float* src, size_t size) = _SSE_MNNHardSwish;
     void (*MNNGelu)(float* dst, const float* src, size_t size) = _SSE_MNNGelu;
@@ -57,24 +48,21 @@ void MNNFunctionInit() {
     auto coreFunction = MNN::MNNGetCoreFunctions();
     if (cpuFlags & libyuv::kCpuHasSSSE3) {
         coreFunction->MNNGetMatMulPackMode = _SSEMNNGetMatMulPackMode;
-        coreFunction->MNNMatrixAdd          = _SSE_MNNMatrixAdd;
-        coreFunction->MNNMatrixSub          = _SSE_MNNMatrixSub;
         coreFunction->MNNPackedMatMul       = _SSE_MNNPackedMatMul;
         coreFunction->MNNPackedMatMulRemain = _SSE_MNNPackedMatMulRemain;
         coreFunction->MNNPackC4ForMatMul_A  = _SSE_MNNPackC4ForMatMul_A;
         coreFunction->MNNPackForMatMul_B    = _SSE_MNNPackForMatMul_B;
-        coreFunction->MNNConvRunForLineDepthwise = _SSE_MNNConvRunForLineDepthwise;
-        coreFunction->MNNAxByClampBroadcastUnit = _SSE_MNNAxByClampBroadcastUnit;
-        coreFunction->MNNComputeMatMulForE_1 = _SSE_MNNComputeMatMulForE_1;
     }
     if (cpuFlags & libyuv::kCpuHasAVX2) {
         MNN::AVX2Functions::init(cpuFlags);
         gFunc.MNNExpC8 = _AVX_MNNExpC8;
         gFunc.MNNSoftmax = _AVX_MNNSoftmax;
         gFunc.MNNGelu = _AVX_MNNGelu;
+        if (cpuFlags & libyuv::kCpuHasFMA3) {
+            gFunc.MNNGelu = _AVX_MNNGeluFMA;
+            gFunc.MNNExpC8 = _AVX_MNNExpC8FMA;
+        }
         gFunc.MNNNorm = _AVX_MNNNorm;
-        gFunc.MNNFloat2Int8 = _AVX_MNNFloat2Int8;
-        gFunc.MNNInt8ScaleToFloat = _AVX_MNNInt8ScaleToFloat;
     }
 }
 
@@ -82,21 +70,12 @@ void MNNInt8FunctionInit() {
     auto cpuFlags = libyuv::InitCpuFlags();
     auto core = MNN::MNNGetInt8CoreFunctions();
     if (cpuFlags & libyuv::kCpuHasSSSE3) {
+        core->MNNFloat2Int8 = _SSE_MNNFloat2Int8;
+        core->MNNInt8ScaleToFloat = _SSE_MNNInt8ScaleToFloat;
         core->Int8GemmKernel = _SSE_MNNGemmInt8AddBiasScale_16x4_Unit;
         core->Int8GemmKernelFast = _SSE_MNNGemmInt8AddBiasScale_16x4_Unit;
         core->ConvDepthwiseLineInt8 = _SSE_MNNLineDepthWiseInt8AddBiasScaleUnit;
     }
-    if (cpuFlags & libyuv::kCpuHasAVX2) {
-        core->Int8GemmKernel = _AVX_MNNGemmInt8AddBiasScale_16x4_Unit;
-        core->Int8GemmKernelFast = _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast;
-        core->ConvDepthwiseLineInt8 = _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit;
-    }
-#ifdef MNN_AVX512_VNNI
-    if (cpuFlags & libyuv::kCpuHasAVX512VNNI) {
-        core->Int8GemmKernel = _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit;
-        core->Int8GemmKernelFast = _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit;
-    }
-#endif
 }
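
Sketch only: after this change the quantize/dequantize entry points live on the int8 core-function table instead of the removed gFunc members, so a caller reaches them roughly as below (assuming the table members keep the signatures listed in FunctionSummary.hpp; the wrapper name is illustrative):

#include <cstdint>
#include "backend/cpu/compute/Int8FunctionsOpt.h"

// Quantize one packed group of floats with clamp [-127, 127] and zero point 0;
// sizeQuad counts packed groups, 1 here for a single group.
void quantizeOneQuad(const float* src, int8_t* dst, const float* scale) {
    auto core = MNN::MNNGetInt8CoreFunctions();
    core->MNNFloat2Int8(src, dst, 1, scale, -127, 127, 0);
}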
 
 // ========= CommonOptFunction.cpp ===========
@@ -110,7 +89,7 @@ void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size
 }
 
 void MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) {
-    return gFunc.MNNReluWithSlopeChannel(dst, src, slope, sizeQuad, depthQuad);
+    return _SSE_MNNReluWithSlopeChannel(dst, src, slope, sizeQuad, depthQuad);
 }
 
 void MNNReluInt8(int8_t* dst, const int8_t* src, size_t size) {
@@ -125,16 +104,8 @@ void MNNGelu(float* dst, const float* src, size_t size) {
     return gFunc.MNNGelu(dst, src, size);
 }
 
-void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue,
-                   ssize_t maxValue, ssize_t zeroPoint) {
-    return gFunc.MNNFloat2Int8(src, dst, sizeQuad, scalep, minValue, maxValue, zeroPoint);
-}
-void MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint) {
-    return gFunc.MNNInt8ScaleToFloat(dst, src, scale, size, zeroPoint);
-}
-
-void MNNExpC8(float* dest, const float* source, const float* parameters, size_t countC8) {
-    gFunc.MNNExpC8(dest, source, parameters, countC8);
+void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8) {
+    gFunc.MNNExpC8(dest, source, offset, parameters, countC8);
 }
 
 void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count) {

File diff suppressed because it is too large
+ 0 - 1085
source/backend/cpu/x86_x64/avx/CommonOptFunction.cpp


+ 6 - 37
source/backend/cpu/x86_x64/avx/FunctionSummary.hpp

@@ -33,71 +33,40 @@
 
 // ========= CommonOptFunction.cpp ===========
 extern "C" {
-
-void _AVX_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters);
-void _AVX_MNNGemmFloatCommon_4(float* dst, const float* src, const float* weight, size_t src_depth_quad,
-                               size_t dst_step, size_t dst_depth_quad, size_t width, size_t weight_depth_offset);
-void _AVX_MNNGemmFloatUnit_4(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad,
-                             size_t dst_step, size_t dst_depth_quad, size_t weight_depth_offset);
-void _AVX_MNNMatrixAdd(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
-                       size_t bStride, size_t height);
-void _AVX_MNNMatrixSub(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
-                       size_t bStride, size_t height);
-void _AVX_MNNStrassenMergeCFunction(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride,
-                                    size_t length, size_t hSub);
-
 void _AVX_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter,
                           const float* postParameters, const float* bias);
 void _AVX_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter,
                                 const float* postParameters, const float* bias);
 void _AVX_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
 
-void _AVX_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
-                                size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height,
-                                     size_t srcHStep, size_t dstHStep);
-void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst);
-void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst);
-
-void _AVX_MNNExpC8(float* dest, const float* source, const float* parameters, size_t countC8);
+void _AVX_MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8);
 void _AVX_MNNSoftmax(float* dest, const float* source, size_t size);
 void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, ssize_t zeroPoint);
 void _AVX_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t sizeQuad, ssize_t zeroPoint);
 void _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* srcO, const int8_t* weightO, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step);
 void _AVX_MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId);
-
 void _AVX_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
 
 void _AVX_MNNGetMatMulPackMode_BF16(int* eP, int *lP, int* hP);
 void _AVX_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose);
 void _AVX_MNNPackedSparseMatMul(float* C, const float* A, const float* B, unsigned int* NNZMap, int* dataOffsetMap, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias);
-void _AVX_MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad);
 void _AVX_MNNComputeMatMulForH_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId);
 
+void _AVX_ReorderInit(void* functions);
+void _AVX_MNNInt8FunctionInit(void* functions);
 void _AVX_MNNPackCUnit(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
 void _AVX_MNNUnpackCUnit(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
 void _AVX_MNNPackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
 void _AVX_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
 void _AVX_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose);
-void _AVX_MNNStrassenMergeCFunction(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride, size_t eSub, size_t hSub);
-void _AVX_MNNConvRunForUnitDepthWise(float* dst, const float* src, const float* weight, size_t fw, size_t fh,
-                                     size_t weight_y_step, size_t dilateX_step, size_t dilateY_step);
-void _AVX_MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weigth, float *dest, int cacheLineSize, int ow, const float* bias, const float* parameter);
-void _AVX_MNNSourceTransformCommonF23(const float *source, float *dest, int unit, int iw, int pad, int su, int eu);
-void _AVX_MNNConvDwF23MulTransUnit(float **cacheLine, const float *weigth, float *dest, size_t ow, const float* bias, const float* parameter);
 
 void _AVX_ExtraInit(void* functions);
 void _AVX_WinogradInit(void* functions);
-void _AVX_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count);
-void _AVX_MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count);
-void _AVX_MNNScaleAndAddBias(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, size_t biasNumber);
-void _AVX_MNNDeconvRunForUnitDepthWise(const float* dst, float* src, const float* weight, size_t fw, size_t fh,
-                                       size_t weight_y_step, size_t dilateX_step, size_t dilateY_step);
-void _AVX_MNNDeconvRunForLineDepthwise(const float* dst, float* src, const float* weight, size_t width, size_t src_w_setup,
-                                       size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step);
 
 void _AVX_MNNGelu(float *dst, const float *src, size_t size);
 void _AVX_MNNNorm(float *dst, const float *src, const float *gamma, const float *beta, float epsilon, size_t size);
 
-void _AVX_MNNGridSampleInterp(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode);
-
+void _AVX_MNNGetSparseMatMulPackMode(int* eP, int *lP, int* hP);
+void _AVX_MNNPackedSparseMatMulEpx1EFMA(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap);
+void _AVX_MNNPackedSparseMatMulEpx4EFMA(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap);
 }
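
For orientation, a hedged sketch (the helper name and arguments are illustrative) of the size_t parameter[] that the two sparse kernels expect, following the layout documented in GemmSparse.cpp below:

#include <cstddef>

// eP = 24, lP = 1, hP = 4 per _AVX_MNNGetSparseMatMulPackMode; the output unit is fixed at 8.
void fillSparseMatMulParams(size_t parameter[5], size_t l, size_t h, size_t e) {
    parameter[0] = 24 * sizeof(float);     // eP * bytes
    parameter[1] = l;                      // reduce dimension
    parameter[2] = h;                      // output channels
    parameter[3] = e * 8 * sizeof(float);  // stride of one h/unit block of mat_c
    parameter[4] = 8;                      // unit
}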

File diff suppressed because it is too large
+ 702 - 947
source/backend/cpu/x86_x64/avx/GemmInt8.cpp


+ 776 - 0
source/backend/cpu/x86_x64/avx/GemmSparse.cpp

@@ -0,0 +1,776 @@
+//
+//  GemmSparse.cpp
+//  MNN
+//
+//  Created by MNN on 2021/07/28.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+
+#include "GemmCommon.hpp"
+#include "FunctionSummary.hpp"
+#include "Vec8.hpp"
+#include "core/Macro.h"
+
+#ifdef MNN_X86_USE_ASM
+extern "C" {
+void _AVX_MNNPackedSparseMatMulEpx4EFMA_ASM(SparseMatMulParas* temp, const float* bias, const size_t* parameter, const float* postParameters);
+void _AVX_MNNPackedSparseMatMulEpx1EFMA_ASM(SparseMatMulParas* temp, const float* bias, const size_t* parameter, const float* postParameters);
+}
+#endif
+
+void _AVX_MNNGetSparseMatMulPackMode(int* eP, int *lP, int* hP){
+    *eP = 24;
+    *lP = 1;
+    *hP = 4;
+    // hP corresponds to the sparse block size along the right matrix column dimension; in random sparse, it is 1.
+    return;
+}
+
+#define EMULATED_AVX2_FMA(dst, src0, src1) dst = _mm256_add_ps(dst, _mm256_mul_ps(src0, src1));
+
+#define MIN_MAX_VEC(cVec)               \
+    cVec = _mm256_max_ps(cVec, minVec); \
+    cVec = _mm256_min_ps(cVec, maxVec);
+
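+// Scatters one output channel's 24 e-values (held in three __m256 registers) into the
+// [h/unit, e, unit] mat_c layout with unit = 8: consecutive e entries are 8 floats apart.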
+#define ONE_H_STORE_E24(cTilePtr)   \
+    cTilePtr[8 * 0] = c0VecPtr[0];  \
+    cTilePtr[8 * 1] = c0VecPtr[1];  \
+    cTilePtr[8 * 2] = c0VecPtr[2];  \
+    cTilePtr[8 * 3] = c0VecPtr[3];  \
+    cTilePtr[8 * 4] = c0VecPtr[4];  \
+    cTilePtr[8 * 5] = c0VecPtr[5];  \
+    cTilePtr[8 * 6] = c0VecPtr[6];  \
+    cTilePtr[8 * 7] = c0VecPtr[7];  \
+                                    \
+    cTilePtr[8 * 8]  = c1VecPtr[0]; \
+    cTilePtr[8 * 9]  = c1VecPtr[1]; \
+    cTilePtr[8 * 10] = c1VecPtr[2]; \
+    cTilePtr[8 * 11] = c1VecPtr[3]; \
+    cTilePtr[8 * 12] = c1VecPtr[4]; \
+    cTilePtr[8 * 13] = c1VecPtr[5]; \
+    cTilePtr[8 * 14] = c1VecPtr[6]; \
+    cTilePtr[8 * 15] = c1VecPtr[7]; \
+                                    \
+    cTilePtr[8 * 16] = c2VecPtr[0]; \
+    cTilePtr[8 * 17] = c2VecPtr[1]; \
+    cTilePtr[8 * 18] = c2VecPtr[2]; \
+    cTilePtr[8 * 19] = c2VecPtr[3]; \
+    cTilePtr[8 * 20] = c2VecPtr[4]; \
+    cTilePtr[8 * 21] = c2VecPtr[5]; \
+    cTilePtr[8 * 22] = c2VecPtr[6]; \
+    cTilePtr[8 * 23] = c2VecPtr[7];
+
+#define TRANSPOSE_4x4_WITH_STORE(rowIdx, offset, cVec0, cVec1, cVec2, cVec3, cTilePtr)     \
+    {                                                                                      \
+        transposeTemp0 = _mm256_extractf128_ps(cVec0, offset);                             \
+        transposeTemp1 = _mm256_extractf128_ps(cVec1, offset);                             \
+        transposeTemp2 = _mm256_extractf128_ps(cVec2, offset);                             \
+        transposeTemp3 = _mm256_extractf128_ps(cVec3, offset);                             \
+        _MM_TRANSPOSE4_PS(transposeTemp0, transposeTemp1, transposeTemp2, transposeTemp3); \
+        _mm_store_ps(cTilePtr + (rowIdx + 0) * unit, transposeTemp0);                      \
+        _mm_store_ps(cTilePtr + (rowIdx + 1) * unit, transposeTemp1);                      \
+        _mm_store_ps(cTilePtr + (rowIdx + 2) * unit, transposeTemp2);                      \
+        _mm_store_ps(cTilePtr + (rowIdx + 3) * unit, transposeTemp3);                      \
+    }
+
+#define TRANSPOSE_4x24_WITH_STORE(cTilePtr, unit)                               \
+    {                                                                           \
+        __m128 transposeTemp0;                                                  \
+        __m128 transposeTemp1;                                                  \
+        __m128 transposeTemp2;                                                  \
+        __m128 transposeTemp3;                                                  \
+        TRANSPOSE_4x4_WITH_STORE(0, 0, c0Vec, c3Vec, c6Vec, c9Vec, cTilePtr);   \
+        TRANSPOSE_4x4_WITH_STORE(4, 1, c0Vec, c3Vec, c6Vec, c9Vec, cTilePtr);   \
+        TRANSPOSE_4x4_WITH_STORE(8, 0, c1Vec, c4Vec, c7Vec, c10Vec, cTilePtr);  \
+        TRANSPOSE_4x4_WITH_STORE(12, 1, c1Vec, c4Vec, c7Vec, c10Vec, cTilePtr); \
+        TRANSPOSE_4x4_WITH_STORE(16, 0, c2Vec, c5Vec, c8Vec, c11Vec, cTilePtr); \
+        TRANSPOSE_4x4_WITH_STORE(20, 1, c2Vec, c5Vec, c8Vec, c11Vec, cTilePtr); \
+    }
+
+#define REMAIN_TRANSPOSE_4x24_WITH_STORE(cTilePtr, unit)                                       \
+    {                                                                                          \
+        __m128 transposeTemp0;                                                                 \
+        __m128 transposeTemp1;                                                                 \
+        __m128 transposeTemp2;                                                                 \
+        __m128 transposeTemp3;                                                                 \
+        int tailE  = eSize % 4;                                                                \
+        int eFull4 = eSize / 4;                                                                \
+        switch (eFull4) {                                                                      \
+            case 5:                                                                            \
+                TRANSPOSE_4x4_WITH_STORE(16, 0, c2Vec, c5Vec, c8Vec, c11Vec, cTilePtr);        \
+            case 4:                                                                            \
+                TRANSPOSE_4x4_WITH_STORE(12, 1, c1Vec, c4Vec, c7Vec, c10Vec, cTilePtr);        \
+            case 3:                                                                            \
+                TRANSPOSE_4x4_WITH_STORE(8, 0, c1Vec, c4Vec, c7Vec, c10Vec, cTilePtr);         \
+            case 2:                                                                            \
+                TRANSPOSE_4x4_WITH_STORE(4, 1, c0Vec, c3Vec, c6Vec, c9Vec, cTilePtr);          \
+            case 1:                                                                            \
+                TRANSPOSE_4x4_WITH_STORE(0, 0, c0Vec, c3Vec, c6Vec, c9Vec, cTilePtr);          \
+            default:                                                                           \
+                break;                                                                         \
+        }                                                                                      \
+        if (tailE) {                                                                           \
+            if (eFull4 == 5) {                                                                 \
+                transposeTemp0 = _mm256_extractf128_ps(c2Vec, 1);                              \
+                transposeTemp1 = _mm256_extractf128_ps(c5Vec, 1);                              \
+                transposeTemp2 = _mm256_extractf128_ps(c8Vec, 1);                              \
+                transposeTemp3 = _mm256_extractf128_ps(c11Vec, 1);                             \
+            } else if (eFull4 == 4) {                                                          \
+                transposeTemp0 = _mm256_extractf128_ps(c2Vec, 0);                              \
+                transposeTemp1 = _mm256_extractf128_ps(c5Vec, 0);                              \
+                transposeTemp2 = _mm256_extractf128_ps(c8Vec, 0);                              \
+                transposeTemp3 = _mm256_extractf128_ps(c11Vec, 0);                             \
+            } else if (eFull4 == 3) {                                                          \
+                transposeTemp0 = _mm256_extractf128_ps(c1Vec, 1);                              \
+                transposeTemp1 = _mm256_extractf128_ps(c4Vec, 1);                              \
+                transposeTemp2 = _mm256_extractf128_ps(c7Vec, 1);                              \
+                transposeTemp3 = _mm256_extractf128_ps(c10Vec, 1);                             \
+            } else if (eFull4 == 2) {                                                          \
+                transposeTemp0 = _mm256_extractf128_ps(c1Vec, 0);                              \
+                transposeTemp1 = _mm256_extractf128_ps(c4Vec, 0);                              \
+                transposeTemp2 = _mm256_extractf128_ps(c7Vec, 0);                              \
+                transposeTemp3 = _mm256_extractf128_ps(c10Vec, 0);                             \
+            } else if (eFull4 == 1) {                                                          \
+                transposeTemp0 = _mm256_extractf128_ps(c0Vec, 1);                              \
+                transposeTemp1 = _mm256_extractf128_ps(c3Vec, 1);                              \
+                transposeTemp2 = _mm256_extractf128_ps(c6Vec, 1);                              \
+                transposeTemp3 = _mm256_extractf128_ps(c9Vec, 1);                              \
+            }                                                                                  \
+            else {\
+                transposeTemp0 = _mm256_extractf128_ps(c0Vec, 0);                              \
+                transposeTemp1 = _mm256_extractf128_ps(c3Vec, 0);                              \
+                transposeTemp2 = _mm256_extractf128_ps(c6Vec, 0);                              \
+                transposeTemp3 = _mm256_extractf128_ps(c9Vec, 0);                              \
+            }\
+            _MM_TRANSPOSE4_PS(transposeTemp0, transposeTemp1, transposeTemp2, transposeTemp3); \
+            int offset = 4 * eFull4;                                                           \
+            switch (tailE) {                                                                   \
+                case 3:                                                                        \
+                    _mm_storeu_ps(cTilePtr + (offset + 2) * unit, transposeTemp2);             \
+                case 2:                                                                        \
+                    _mm_storeu_ps(cTilePtr + (offset + 1) * unit, transposeTemp1);             \
+                case 1:                                                                        \
+                    _mm_storeu_ps(cTilePtr + (offset + 0) * unit, transposeTemp0);             \
+                default:                                                                       \
+                    break;                                                                     \
+            }                                                                                  \
+        }                                                                                      \
+    }
+
+#define FP32_BYTES      4
+#define AVX2_SPARSE_EP  24
+#define AVX2_SP_BLOCK4  4
+
+void _AVX_MNNPackedSparseMatMulEpx1EFMA(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter,
+                                    const float* postParameters, const float* bias, unsigned int* NNZMap,
+                                    int* dataOffsetMap) {
+    /*
+    mat_a: [eSize/eP, l, eP]
+    mat_c: [h/unit, e, unit]
+    bias: [h, ]
+    parameter[0]: eP * bytes
+    parameter[1]: l
+    parameter[2]: h
+    parameter[3]: h/unit stride, equal to e * unit * sizeof(dataType)
+    parameter[4]: unit
+    eSize: this tile's real e size, which can be greater or less than eP!
+    postParameters[2]: min_val of output
+    postParameters[3]: max_val of output
+    */
+
+    /*
+    This func performs the sparse matmul with bias add and post process of min/max threshold.
+    The basic process of the dense version of func is:
+    batch_matmul([l, eP], [h/hP, l, hP]) --> [h/hP, eP, hP].
+    However, when mat_b is sparse encoded, this func changes accordingly.
+    First, divide the whole process into two parts: the full hP part and the remaining part.
+    The full hP part means that, in each iteration, mat_b's col (or row, actually) is processed hP at a time,
+    and the non-zero values are encoded as hP contiguous entries.
+    The remaining part means that, in each iteration, mat_b's col (or row, actually) is processed one at a time,
+    and the non-zero values are encoded one by one.
+    (Although this func is specialized for hP = 1)
+
+    ***********************************************
+    Specialization description:
+    1. eP = 24, hP = 1, lP = 1;
+    2. mat_a stores in [eSize/eP, l, eP] format;
+    3. mat_c stores in [h/unit, e, unit] format;
+    4. data type is fixed as float32, which means the bytes = 4;
+    5. unit is fixed as 8;
+    ***********************************************
+
+    Note that the function keeps the aStride handling for a mat_a that contains more than one l * eP tile.
+    But for now, limit the eSize <= eP!
+    */
+#ifdef MNN_X86_USE_ASM
+   if (eSize == AVX2_SPARSE_EP && parameter[2] % 4 == 0){
+        // use the asm function when eSize == 24 and h == 4x
+        SparseMatMulParas temp = {C, A, B, NNZMap, dataOffsetMap};
+        SparseMatMulParas* tempPtr = &temp;
+        _AVX_MNNPackedSparseMatMulEpx1EFMA_ASM(tempPtr, bias, parameter, postParameters);
+        return;
+    }
+#endif
+    const size_t aStride = parameter[0] / FP32_BYTES;
+    const size_t l       = parameter[1];
+    const size_t h       = parameter[2];
+    const size_t cStride = parameter[3] / FP32_BYTES; // intrinsics do not need the byte stride.
+    const size_t unit    = 8;
+
+    MNN_ASSERT(eSize <= aStride);
+
+    auto minVec = _mm256_broadcast_ss(postParameters + 2);
+    auto maxVec = _mm256_broadcast_ss(postParameters + 3);
+
+    // full [l, eP] X [h/unit, e, unit]
+    for (int matALoopIdx = 0; matALoopIdx < eSize / aStride; matALoopIdx++) {
+        const float* aTilePtrSt  = A + l * aStride * matALoopIdx;
+        const int* aRowOffsetPtr = dataOffsetMap;
+        const float* weightPtr   = B;
+
+        // as this func is specialized for hP = 1,
+        // iteration along the h axis always uses the full-hP path.
+        __m256 c0Vec;
+        __m256 c1Vec;
+        __m256 c2Vec;
+        auto c0VecPtr = (float*)&c0Vec;
+        auto c1VecPtr = (float*)&c1Vec;
+        auto c2VecPtr = (float*)&c2Vec;
+
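+        // NNZMap stores one non-zero count per output channel of mat_b; dataOffsetMap
+        // (read via aRowOffsetPtr) stores the float offset to advance inside the packed
+        // [l, eP] tile of mat_a between consecutive non-zero entries of that channel.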
+        for (int hLoopIdx = 0; hLoopIdx < h; hLoopIdx++) {
+            float* cTilePtrSt = C + (unit * aStride * matALoopIdx) + (hLoopIdx / unit * cStride) + (hLoopIdx % unit);
+            size_t nonZeroCnt = *NNZMap;
+            NNZMap++;
+
+            // initialize mat_c tile with bias if it exists.
+            // [eP, hP] bias initialize.
+
+            if (bias != nullptr) {
+                c0Vec = _mm256_broadcast_ss(bias + hLoopIdx);
+                c1Vec = c0Vec;
+                c2Vec = c0Vec;
+            } else {
+                c0Vec = _mm256_setzero_ps();
+                c1Vec = _mm256_setzero_ps();
+                c2Vec = _mm256_setzero_ps();
+            }
+
+            for (int lLoopIdx = 0; lLoopIdx < nonZeroCnt; lLoopIdx++) {
+                aTilePtrSt += aRowOffsetPtr[0];
+                aRowOffsetPtr++;
+                auto a0Vec = _mm256_loadu_ps(aTilePtrSt + 0);
+                auto a1Vec = _mm256_loadu_ps(aTilePtrSt + 8);
+                auto a2Vec = _mm256_loadu_ps(aTilePtrSt + 16);
+                auto b0Vec = _mm256_broadcast_ss(weightPtr);
+                weightPtr++;
+                c0Vec = EMULATED_AVX2_FMA(c0Vec, a0Vec, b0Vec);
+                c1Vec = EMULATED_AVX2_FMA(c1Vec, a1Vec, b0Vec);
+                c2Vec = EMULATED_AVX2_FMA(c2Vec, a2Vec, b0Vec);
+            }
+
+            // min-max post process and store process.
+            MIN_MAX_VEC(c0Vec);
+            MIN_MAX_VEC(c1Vec);
+            MIN_MAX_VEC(c2Vec);
+
+            ONE_H_STORE_E24(cTilePtrSt);
+        }
+        NNZMap -= h;
+    }
+
+    // remained [l, eSize%eP] X [h/unit, e, unit]
+    A += (eSize / aStride) * aStride * l;
+    C += (eSize / aStride) * aStride * unit;
+    eSize = eSize % aStride; // eSize % 24
+
+    // remained eSize part
+    if (eSize) {
+        // as this func is specialized for hP = 1,
+        // iteration along the h axis always uses the full-hP path.
+        __m256 c0Vec;
+        __m256 c1Vec;
+        __m256 c2Vec;
+        auto c0VecPtr  = (float*)&c0Vec;
+        auto c1VecPtr  = (float*)&c1Vec;
+        auto c2VecPtr  = (float*)&c2Vec;
+        for (int hLoopIdx = 0; hLoopIdx < h; hLoopIdx++) {
+            float* cTilePtrSt = C + (hLoopIdx / unit * cStride) + (hLoopIdx % unit);
+            size_t nonZeroCnt = *NNZMap;
+            NNZMap++;
+
+            // initialize mat_c tile with bias if it exists.
+            // [eP, hP] bias initialize.
+
+            if (bias != nullptr) {
+                c0Vec = _mm256_broadcast_ss(bias + hLoopIdx);
+                c1Vec = c0Vec;
+                c2Vec = c0Vec;
+            } else {
+                c0Vec = _mm256_setzero_ps();
+                c1Vec = _mm256_setzero_ps();
+                c2Vec = _mm256_setzero_ps();
+            }
+
+            for (int lLoopIdx = 0; lLoopIdx < nonZeroCnt; lLoopIdx++) {
+                A += dataOffsetMap[0];
+                dataOffsetMap++;
+                auto a0Vec      = _mm256_loadu_ps(A + 0);
+                auto a1Vec      = _mm256_loadu_ps(A + 8);
+                auto a2Vec      = _mm256_loadu_ps(A + 16);
+                auto b0Vec = _mm256_broadcast_ss(B);
+                B++;
+                c0Vec = EMULATED_AVX2_FMA(c0Vec, a0Vec, b0Vec);
+                c1Vec = EMULATED_AVX2_FMA(c1Vec, a1Vec, b0Vec);
+                c2Vec = EMULATED_AVX2_FMA(c2Vec, a2Vec, b0Vec);
+            }
+
+            // min-max post process and store process.
+            MIN_MAX_VEC(c0Vec);
+            MIN_MAX_VEC(c1Vec);
+            MIN_MAX_VEC(c2Vec);
+
+            auto CStorePtr = cTilePtrSt;
+            auto cxVecPtr  = c0VecPtr;
+            if (eSize >= 8) {
+                CStorePtr[8 * 0] = cxVecPtr[0];
+                CStorePtr[8 * 1] = cxVecPtr[1];
+                CStorePtr[8 * 2] = cxVecPtr[2];
+                CStorePtr[8 * 3] = cxVecPtr[3];
+                CStorePtr[8 * 4] = cxVecPtr[4];
+                CStorePtr[8 * 5] = cxVecPtr[5];
+                CStorePtr[8 * 6] = cxVecPtr[6];
+                CStorePtr[8 * 7] = cxVecPtr[7];
+                CStorePtr += 8 * unit;
+                cxVecPtr = c1VecPtr;
+            }
+            if (eSize >= 16){
+                CStorePtr[8 * 0] = cxVecPtr[0];
+                CStorePtr[8 * 1] = cxVecPtr[1];
+                CStorePtr[8 * 2] = cxVecPtr[2];
+                CStorePtr[8 * 3] = cxVecPtr[3];
+                CStorePtr[8 * 4] = cxVecPtr[4];
+                CStorePtr[8 * 5] = cxVecPtr[5];
+                CStorePtr[8 * 6] = cxVecPtr[6];
+                CStorePtr[8 * 7] = cxVecPtr[7];
+                CStorePtr += 8 * unit;
+                cxVecPtr = c2VecPtr;
+            }
+            for (int i = 0; i < eSize % 8; i++) {
+                CStorePtr[8 * i] = cxVecPtr[i];
+            }
+        }
+        NNZMap -= h;
+    }
+    return;
+}
+
+void _AVX_MNNPackedSparseMatMulEpx4EFMA(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter,
+                                    const float* postParameters, const float* bias, unsigned int* NNZMap,
+                                    int* dataOffsetMap) {
+    /*
+    mat_a: [eSize/eP, l, eP]
+    mat_c: [h/unit, e, unit]
+    bias: [h, ]
+    parameter[0]: eP * bytes
+    parameter[1]: l
+    parameter[2]: h
+    parameter[3]: h/unit stride, equal to e * unit * sizeof(dataType)
+    parameter[4]: unit
+    eSize: this tile's real e size, which can be greater or less than eP!
+    postParameters[2]: min_val of output
+    postParameters[3]: max_val of output
+    */
+
+    /*
+    This func performs the sparse matmul with bias add and post process of min/max threshold.
+    The basic process of the dense version of func is:
+    batch_matmul([l, eP], [h/hP, l, hP]) --> [h/hP, eP, hP].
+    However, when mat_b is sparse encoded, this func changes accordingly.
+    First, divide the whole process into two parts: the full hP part and the remaining part.
+    The full hP part means that, in each iteration, mat_b's col (or row, actually) is processed hP at a time,
+    and the non-zero values are encoded as hP contiguous entries.
+    The remaining part means that, in each iteration, mat_b's col (or row, actually) is processed one at a time,
+    and the non-zero values are encoded one by one.
+
+    ***********************************************
+    Specialization description:
+    1. eP = 24, hP = 4, lP = 1;
+    2. mat_a stores in [eSize/eP, l, eP] format;
+    3. mat_c stores in [h/unit, e, unit] format;
+    4. data type is fixed as float32, which means the bytes = 4;
+    5. unit is fixed as 8;
+    ***********************************************
+
+    Note that the function keeps the aStride handling for a mat_a that contains more than one l * eP tile.
+    But for now, limit the eSize <= eP!
+    */
+#define ONE_LP_ACT_E24(cVecFirst, cVecSecond, cVecThird)       \
+    b0Vec = _mm256_broadcast_ss(weightPtr);                    \
+    weightPtr++;                                               \
+    cVecFirst  = EMULATED_AVX2_FMA(cVecFirst, a0Vec, b0Vec);  \
+    cVecSecond = EMULATED_AVX2_FMA(cVecSecond, a1Vec, b0Vec); \
+    cVecThird  = EMULATED_AVX2_FMA(cVecThird, a2Vec, b0Vec);
+
+#define REMAIN_E_ONE_LP_ACT_E24(cVecFirst, cVecSecond, cVecThird) \
+    b0Vec = _mm256_broadcast_ss(B);                               \
+    B++;                                                          \
+    cVecFirst  = EMULATED_AVX2_FMA(cVecFirst, a0Vec, b0Vec);     \
+    cVecSecond = EMULATED_AVX2_FMA(cVecSecond, a1Vec, b0Vec);    \
+    cVecThird  = EMULATED_AVX2_FMA(cVecThird, a2Vec, b0Vec);
+
+#ifdef MNN_X86_USE_ASM
+   if (eSize == AVX2_SPARSE_EP && parameter[2] % 4 == 0){
+        // use the asm function when eSize == eP(24) and h == 4x
+        SparseMatMulParas temp = {C, A, B, NNZMap, dataOffsetMap};
+        SparseMatMulParas* tempPtr = &temp;
+        _AVX_MNNPackedSparseMatMulEpx4EFMA_ASM(tempPtr, bias, parameter, postParameters);
+        return;
+    }
+#endif
+    const size_t aStride = parameter[0] / FP32_BYTES; // intrinsics do not need the byte stride.
+    const size_t l       = parameter[1];
+    const size_t h       = parameter[2];
+    const size_t cStride = parameter[3] / FP32_BYTES; // intrinsics do not need the byte stride.
+    const size_t unit    = 8;
+
+    MNN_ASSERT(eSize <= aStride);
+
+    const float minVal = postParameters[2];
+    const float maxVal = postParameters[3];
+    const int fullHCnt = h / AVX2_SP_BLOCK4 * AVX2_SP_BLOCK4;
+
+    // full [l, eP] X [h/unit, e, unit]
+    for (int matALoopIdx = 0; matALoopIdx < eSize / aStride; matALoopIdx++) {
+        const float* aTilePtrSt  = A + l * aStride * matALoopIdx;
+        const int* aRowOffsetPtr = dataOffsetMap;
+        const float* weightPtr   = B;
+        int hLoopIdx             = 0;
+
+        // full hP method!
+        for (; hLoopIdx < fullHCnt; hLoopIdx += AVX2_SP_BLOCK4) {
+            float* cTilePtrSt = C + (unit * aStride * matALoopIdx) + (hLoopIdx / unit * cStride) + (hLoopIdx % unit);
+            size_t nonZeroCnt = *NNZMap;
+            NNZMap++;
+
+            __m256 c0Vec;
+            __m256 c1Vec;
+            __m256 c2Vec;
+            __m256 c3Vec;
+            __m256 c4Vec;
+            __m256 c5Vec;
+            __m256 c6Vec;
+            __m256 c7Vec;
+            __m256 c8Vec;
+            __m256 c9Vec;
+            __m256 c10Vec;
+            __m256 c11Vec;
+            if (bias != nullptr) {
+                c0Vec = _mm256_broadcast_ss(bias + hLoopIdx);
+                c3Vec = _mm256_broadcast_ss(bias + hLoopIdx + 1);
+                c6Vec = _mm256_broadcast_ss(bias + hLoopIdx + 2);
+                c9Vec = _mm256_broadcast_ss(bias + hLoopIdx + 3);
+                c1Vec = c0Vec;
+                c2Vec = c0Vec;
+                c4Vec = c3Vec;
+                c5Vec = c3Vec;
+                c7Vec = c6Vec;
+                c8Vec = c6Vec;
+                c10Vec = c9Vec;
+                c11Vec = c9Vec;
+
+            } else {
+                // [intrinsic bug] zeroall will not work after the first iteration!
+                c0Vec = _mm256_setzero_ps();
+                c3Vec = _mm256_setzero_ps();
+                c6Vec = _mm256_setzero_ps();
+                c9Vec = _mm256_setzero_ps();
+                c1Vec = _mm256_setzero_ps();
+                c2Vec = _mm256_setzero_ps();
+                c4Vec = _mm256_setzero_ps();
+                c5Vec = _mm256_setzero_ps();
+                c7Vec = _mm256_setzero_ps();
+                c8Vec = _mm256_setzero_ps();
+                c10Vec = _mm256_setzero_ps();
+                c11Vec = _mm256_setzero_ps();
+            }
+
+            {
+                __m256 a0Vec;
+                __m256 a1Vec;
+                __m256 a2Vec;
+                __m256 b0Vec;
+
+                for (int lLoopIdx = 0; lLoopIdx < nonZeroCnt; lLoopIdx++) {
+                    //printf("aRowOffset: %d\t", *aRowOffsetPtr);
+                    aTilePtrSt += *aRowOffsetPtr;
+                    aRowOffsetPtr++;
+                    a0Vec = _mm256_loadu_ps(aTilePtrSt + 0);
+                    a1Vec = _mm256_loadu_ps(aTilePtrSt + 8);
+                    a2Vec = _mm256_loadu_ps(aTilePtrSt + 16);
+                    ONE_LP_ACT_E24(c0Vec, c1Vec, c2Vec);
+                    ONE_LP_ACT_E24(c3Vec, c4Vec, c5Vec);
+                    ONE_LP_ACT_E24(c6Vec, c7Vec, c8Vec);
+                    ONE_LP_ACT_E24(c9Vec, c10Vec, c11Vec);
+                }
+            }
+            {
+                auto minVec = _mm256_set1_ps(minVal);
+                auto maxVec = _mm256_set1_ps(maxVal);
+
+                MIN_MAX_VEC(c0Vec);
+                MIN_MAX_VEC(c1Vec);
+                MIN_MAX_VEC(c2Vec);
+                MIN_MAX_VEC(c3Vec);
+                MIN_MAX_VEC(c4Vec);
+                MIN_MAX_VEC(c5Vec);
+                MIN_MAX_VEC(c6Vec);
+                MIN_MAX_VEC(c7Vec);
+                MIN_MAX_VEC(c8Vec);
+                MIN_MAX_VEC(c9Vec);
+                MIN_MAX_VEC(c10Vec);
+                MIN_MAX_VEC(c11Vec);
+            }
+            TRANSPOSE_4x24_WITH_STORE(cTilePtrSt, unit);
+        }
+
+        // remain hP method!
+        __m256 c0Vec;
+        __m256 c1Vec;
+        __m256 c2Vec;
+        auto minVec   = _mm256_set1_ps(minVal);
+        auto maxVec   = _mm256_set1_ps(maxVal);
+        auto c0VecPtr = (float*)&c0Vec;
+        auto c1VecPtr = (float*)&c1Vec;
+        auto c2VecPtr = (float*)&c2Vec;
+
+        for (; hLoopIdx < h; hLoopIdx++) {
+            float* cTilePtrSt = C + (unit * aStride * matALoopIdx) + (hLoopIdx / unit * cStride) + (hLoopIdx % unit);
+            size_t nonZeroCnt = *NNZMap;
+            NNZMap++;
+
+            // initialize mat_c tile with bias if it exists.
+            // [eP, hP] bias initialize.
+
+            if (bias != nullptr) {
+                c0Vec = _mm256_broadcast_ss(bias + hLoopIdx);
+                c1Vec = c0Vec;
+                c2Vec = c0Vec;
+            } else {
+                c0Vec = _mm256_setzero_ps();
+                c1Vec = _mm256_setzero_ps();
+                c2Vec = _mm256_setzero_ps();
+            }
+
+            for (int lLoopIdx = 0; lLoopIdx < nonZeroCnt; lLoopIdx++) {
+                aTilePtrSt += aRowOffsetPtr[0];
+                aRowOffsetPtr++;
+                auto a0Vec = _mm256_loadu_ps(aTilePtrSt + 0);
+                auto a1Vec = _mm256_loadu_ps(aTilePtrSt + 8);
+                auto a2Vec = _mm256_loadu_ps(aTilePtrSt + 16);
+                auto b0Vec = _mm256_broadcast_ss(weightPtr);
+                weightPtr++;
+                c0Vec = EMULATED_AVX2_FMA(c0Vec, a0Vec, b0Vec);
+                c1Vec = EMULATED_AVX2_FMA(c1Vec, a1Vec, b0Vec);
+                c2Vec = EMULATED_AVX2_FMA(c2Vec, a2Vec, b0Vec);
+            }
+
+            // min-max post-process and store.
+            MIN_MAX_VEC(c0Vec);
+            MIN_MAX_VEC(c1Vec);
+            MIN_MAX_VEC(c2Vec);
+
+            ONE_H_STORE_E24(cTilePtrSt);
+        }
+        NNZMap -= fullHCnt / AVX2_SP_BLOCK4 + h - fullHCnt;
+    }
+
+    // remaining [l, eSize%eP] x [h/unit, e, unit] part
+    A += (eSize / aStride) * aStride * l;
+    C += (eSize / aStride) * aStride * unit;
+    eSize = eSize % aStride; // eSize % 24
+
+    // remaining eSize part
+    if (eSize) {
+        int hLoopIdx   = 0;
+        for (; hLoopIdx < fullHCnt; hLoopIdx += AVX2_SP_BLOCK4) {
+            float* cTilePtrSt = C + (hLoopIdx / unit * cStride) + (hLoopIdx % unit);
+            size_t nonZeroCnt = *NNZMap;
+            NNZMap++;
+
+            __m256 c0Vec;
+            __m256 c1Vec;
+            __m256 c2Vec;
+            __m256 c3Vec;
+            __m256 c4Vec;
+            __m256 c5Vec;
+            __m256 c6Vec;
+            __m256 c7Vec;
+            __m256 c8Vec;
+            __m256 c9Vec;
+            __m256 c10Vec;
+            __m256 c11Vec;
+            if (bias != nullptr) {
+                c0Vec = _mm256_broadcast_ss(bias + hLoopIdx);
+                c3Vec = _mm256_broadcast_ss(bias + hLoopIdx + 1);
+                c6Vec = _mm256_broadcast_ss(bias + hLoopIdx + 2);
+                c9Vec = _mm256_broadcast_ss(bias + hLoopIdx + 3);
+                c1Vec = c0Vec;
+                c2Vec = c0Vec;
+                c4Vec = c3Vec;
+                c5Vec = c3Vec;
+                c7Vec = c6Vec;
+                c8Vec = c6Vec;
+                c10Vec = c9Vec;
+                c11Vec = c9Vec;
+
+            } else {
+                // [intrinsic bug] _mm256_zeroall does not take effect after the first iteration, so zero each accumulator explicitly.
+                c0Vec = _mm256_setzero_ps();
+                c3Vec = _mm256_setzero_ps();
+                c6Vec = _mm256_setzero_ps();
+                c9Vec = _mm256_setzero_ps();
+                c1Vec = _mm256_setzero_ps();
+                c2Vec = _mm256_setzero_ps();
+                c4Vec = _mm256_setzero_ps();
+                c5Vec = _mm256_setzero_ps();
+                c7Vec = _mm256_setzero_ps();
+                c8Vec = _mm256_setzero_ps();
+                c10Vec = _mm256_setzero_ps();
+                c11Vec = _mm256_setzero_ps();
+            }
+
+            {
+                __m256 a0Vec;
+                __m256 a1Vec;
+                __m256 a2Vec;
+                __m256 b0Vec;
+                
+                for (int lLoopIdx = 0; lLoopIdx < nonZeroCnt; lLoopIdx++) {
+                    A += *dataOffsetMap;
+                    dataOffsetMap++;
+                    a0Vec = _mm256_loadu_ps(A + 0);
+                    a1Vec = _mm256_loadu_ps(A + 8);
+                    a2Vec = _mm256_loadu_ps(A + 16);
+
+                    REMAIN_E_ONE_LP_ACT_E24(c0Vec, c1Vec, c2Vec);
+                    REMAIN_E_ONE_LP_ACT_E24(c3Vec, c4Vec, c5Vec);
+                    REMAIN_E_ONE_LP_ACT_E24(c6Vec, c7Vec, c8Vec);
+                    REMAIN_E_ONE_LP_ACT_E24(c9Vec, c10Vec, c11Vec);
+                }
+            }
+            {
+
+                auto minVec = _mm256_set1_ps(minVal);
+                auto maxVec = _mm256_set1_ps(maxVal);
+                MIN_MAX_VEC(c0Vec);
+                MIN_MAX_VEC(c1Vec);
+                MIN_MAX_VEC(c2Vec);
+                MIN_MAX_VEC(c3Vec);
+                MIN_MAX_VEC(c4Vec);
+                MIN_MAX_VEC(c5Vec);
+                MIN_MAX_VEC(c6Vec);
+                MIN_MAX_VEC(c7Vec);
+                MIN_MAX_VEC(c8Vec);
+                MIN_MAX_VEC(c9Vec);
+                MIN_MAX_VEC(c10Vec);
+                MIN_MAX_VEC(c11Vec);
+            }
+            REMAIN_TRANSPOSE_4x24_WITH_STORE(cTilePtrSt, unit);
+        }
+
+        // remaining h part
+        __m256 c0Vec;
+        __m256 c1Vec;
+        __m256 c2Vec;
+        auto c0VecPtr = (float*)&c0Vec;
+        auto c1VecPtr = (float*)&c1Vec;
+        auto c2VecPtr = (float*)&c2Vec;
+        auto minVec   = _mm256_set1_ps(minVal);
+        auto maxVec   = _mm256_set1_ps(maxVal);
+
+        for (; hLoopIdx < h; hLoopIdx++) {
+            float* cTilePtrSt = C + (hLoopIdx / unit * cStride) + (hLoopIdx % unit);
+            size_t nonZeroCnt = *NNZMap;
+            NNZMap++;
+
+            // initialize the [eP, hP] mat_c tile with bias if present.
+
+            if (bias != nullptr) {
+                c0Vec = _mm256_broadcast_ss(bias + hLoopIdx);
+                c1Vec = c0Vec;
+                c2Vec = c0Vec;
+            } else {
+                c0Vec = _mm256_setzero_ps();
+                c1Vec = _mm256_setzero_ps();
+                c2Vec = _mm256_setzero_ps();
+            }
+            __m256 a0Vec;
+            __m256 a1Vec;
+            __m256 a2Vec;
+            for (int lLoopIdx = 0; lLoopIdx < nonZeroCnt; lLoopIdx++) {
+                A += *dataOffsetMap;
+                dataOffsetMap++;
+                a0Vec      = _mm256_loadu_ps(A + 0);
+                a1Vec      = _mm256_loadu_ps(A + 8);
+                a2Vec      = _mm256_loadu_ps(A + 16);
+
+                auto b0Vec = _mm256_broadcast_ss(B);
+                B++;
+                EMULATED_AVX2_FMA(c0Vec, a0Vec, b0Vec);
+                EMULATED_AVX2_FMA(c1Vec, a1Vec, b0Vec);
+                EMULATED_AVX2_FMA(c2Vec, a2Vec, b0Vec);
+            }
+
+            // min-max post-process and store.
+            MIN_MAX_VEC(c0Vec);
+            MIN_MAX_VEC(c1Vec);
+            MIN_MAX_VEC(c2Vec);
+
+            auto CStorePtr = cTilePtrSt;
+            auto cxVecPtr  = c0VecPtr;
+            if (eSize >= 8) {
+                CStorePtr[8 * 0] = cxVecPtr[0];
+                CStorePtr[8 * 1] = cxVecPtr[1];
+                CStorePtr[8 * 2] = cxVecPtr[2];
+                CStorePtr[8 * 3] = cxVecPtr[3];
+                CStorePtr[8 * 4] = cxVecPtr[4];
+                CStorePtr[8 * 5] = cxVecPtr[5];
+                CStorePtr[8 * 6] = cxVecPtr[6];
+                CStorePtr[8 * 7] = cxVecPtr[7];
+                CStorePtr += 8 * unit;
+                cxVecPtr = c1VecPtr;
+            }
+            if (eSize >= 16){
+                CStorePtr[8 * 0] = cxVecPtr[0];
+                CStorePtr[8 * 1] = cxVecPtr[1];
+                CStorePtr[8 * 2] = cxVecPtr[2];
+                CStorePtr[8 * 3] = cxVecPtr[3];
+                CStorePtr[8 * 4] = cxVecPtr[4];
+                CStorePtr[8 * 5] = cxVecPtr[5];
+                CStorePtr[8 * 6] = cxVecPtr[6];
+                CStorePtr[8 * 7] = cxVecPtr[7];
+                CStorePtr += 8 * unit;
+                cxVecPtr = c2VecPtr;
+            }
+            for (int i = 0; i < eSize % 8; i++) {
+                CStorePtr[8 * i] = cxVecPtr[i];
+            }
+        }
+        NNZMap -= h;
+    }
+    return;
+#undef REMAIN_E_ONE_LP_ACT_E24
+#undef ONE_LP_ACT_E24
+}
+
+#undef AVX2_SP_BLOCK4
+#undef AVX2_SPARSE_EP
+#undef FP32_BYTES
+#undef EMULATED_AVX2_FMA
+#undef MIN_MAX_VEC
+#undef ONE_H_STORE_E24
+#undef TRANSPOSE_4x4_WITH_STORE
+#undef TRANSPOSE_4x24_WITH_STORE
+#undef REMAIN_TRANSPOSE_4x24_WITH_STORE

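For reference, a minimal scalar sketch of the traversal that the AVX2 sparse kernel above vectorizes (illustration only, not part of the diff): the weight matrix is stored row-compressed, NNZMap gives the nonzero count per output row, and each nonzero carries a relative offset into the [l, eP]-packed input tile. The names below (nnzMap, aOffset, weight, eP) are assumptions, and C is kept plain row-major for clarity; the real kernel additionally blocks over hP = 4 output rows and stores into the packed [h/unit, e, unit] layout via the transpose macros.

```cpp
#include <algorithm>

// Scalar sketch of the sparse traversal (assumed names; see note above).
static void sparseMatMulScalarSketch(float* C, const float* A, const float* weight,
                                     const int* nnzMap, const int* aOffset,
                                     int h, int eP, const float* bias,
                                     float minVal, float maxVal) {
    const float* aPtr = A;
    for (int y = 0; y < h; ++y) {
        float* cRow      = C + y * eP;
        const float init = (bias != nullptr) ? bias[y] : 0.0f; // bias or zero initialization
        for (int x = 0; x < eP; ++x) {
            cRow[x] = init;
        }
        const int nnz = *nnzMap++;         // nonzero weights feeding this output row
        for (int k = 0; k < nnz; ++k) {
            aPtr += *aOffset++;            // relative jump to the A row of this nonzero
            const float b = *weight++;     // the nonzero weight value
            for (int x = 0; x < eP; ++x) {
                cRow[x] += aPtr[x] * b;    // the (emulated) FMA over the eP-wide tile
            }
        }
        for (int x = 0; x < eP; ++x) {     // fused min/max post-process
            cRow[x] = std::min(maxVal, std::max(minVal, cRow[x]));
        }
    }
}
```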
+ 0 - 46
source/backend/cpu/x86_x64/avx/MNNMatrixAdd.cpp

@@ -1,46 +0,0 @@
-//
-//  MNNMatrixAdd.cpp
-//  MNN
-//
-//  Created by MNN on 2019/08/25.
-//  Copyright © 2018, Alibaba Group Holding Limited
-//
-
-#include "FunctionSummary.hpp"
-void _AVX_MNNMatrixAdd(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
-                       size_t bStride, size_t height) {
-    for (int y = 0; y < height; ++y) {
-        auto a = A + aStride * y;
-        auto b = B + bStride * y;
-        auto c = C + cStride * y;
-        for (int x = 0; x < widthC4; ++x) {
-            _mm256_storeu_ps(c + 8 * x, _mm256_add_ps(_mm256_loadu_ps(b + 8 * x), _mm256_loadu_ps(a + 8 * x)));
-        }
-    }
-}
-
-void _AVX_MNNStrassenMergeCFunction(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride, size_t eSub, size_t hSub) {
-    const int unit = 8;
-    for (int y=0; y<hSub; ++y) {
-        auto c11Y = c11 + y * cStride;
-        auto c12Y = c12 + y * cStride;
-        auto c22Y = c22 + y * cStride;
-        auto c21Y = c21 + y * cStride;
-        auto xY = xAddr + y * eSub * unit;
-        for (int x=0; x<eSub; ++x) {
-            auto xv = _mm256_loadu_ps(xY + unit*x);
-            auto c21v = _mm256_loadu_ps(c21Y + unit*x);
-            auto c11v = _mm256_loadu_ps(c11Y + unit*x);
-            auto c22v = _mm256_loadu_ps(c22Y + unit*x);
-            auto c12v = _mm256_loadu_ps(c12Y + unit*x);
-            c12v = _mm256_add_ps(c12v, xv);
-            c21v = _mm256_add_ps(c12v, c21v);
-            c12v = _mm256_add_ps(c22v, c12v);
-            c22v = _mm256_add_ps(c22v, c21v);
-            c12v = _mm256_add_ps(c11v, c12v);
-            _mm256_storeu_ps(c12Y + unit*x, c12v);
-            _mm256_storeu_ps(c22Y + unit*x, c22v);
-            _mm256_storeu_ps(c21Y + unit*x, c21v);
-        }
-    }
-}

+ 0 - 21
source/backend/cpu/x86_x64/avx/MNNMatrixSub.cpp

@@ -1,21 +0,0 @@
-//
-//  MNNMatrixSub.cpp
-//  MNN
-//
-//  Created by MNN on 2019/08/25.
-//  Copyright © 2018, Alibaba Group Holding Limited
-//
-
-#include "FunctionSummary.hpp"
-
-void _AVX_MNNMatrixSub(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
-                       size_t bStride, size_t height) {
-    for (int y = 0; y < height; ++y) {
-        auto a = A + aStride * y;
-        auto b = B + bStride * y;
-        auto c = C + cStride * y;
-        for (int x = 0; x < widthC4; ++x) {
-            _mm256_storeu_ps(c + 8 * x, _mm256_sub_ps(_mm256_loadu_ps(a + 8 * x), _mm256_loadu_ps(b + 8 * x)));
-        }
-    }
-}

+ 265 - 0
source/backend/cpu/x86_x64/avx/MathFunctions.cpp

@@ -0,0 +1,265 @@
+//
+//  MathFunctions.cpp
+//  MNN
+//
+//  Created by MNN on 2021/07/05.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#include "FunctionSummary.hpp"
+#include <math.h>
+
+void _AVX_MNNGelu(float *dst, const float *src, size_t size) {
+    auto var1 = _mm256_set1_ps(0.044715f);
+    auto var2 = _mm256_set1_ps(0.79788458f);
+    auto var3 = _mm256_set1_ps(378.f);
+    auto var4 = _mm256_set1_ps(17325.f);
+    auto var5 = _mm256_set1_ps(135135.f);
+    auto var6 = _mm256_set1_ps(28.f);
+    auto var7 = _mm256_set1_ps(3150.f);
+    auto var8 = _mm256_set1_ps(62370.f);
+    auto var9 = _mm256_set1_ps(135135.f);
+    auto var10 = _mm256_set1_ps(0.5);
+    auto varOne = _mm256_set1_ps(1.f);
+    auto varNegOne = _mm256_set1_ps(-1.f);
+    for (int i = 0; i < size; i++) {
+        auto x = _mm256_loadu_ps(src + i * 8);
+        auto y = _mm256_mul_ps(x, x);
+        y = _mm256_mul_ps(y, x);
+        y = _mm256_mul_ps(y, var1);
+        y = _mm256_add_ps(y, x);
+        y = _mm256_mul_ps(y, var2);
+        // y = tanh(y)
+        {
+            auto y2 = _mm256_mul_ps(y, y);
+            auto w = _mm256_add_ps(y2, var3);
+            w = _mm256_mul_ps(w, y2);
+            w = _mm256_add_ps(w, var4);
+            w = _mm256_mul_ps(w, y2);
+            w = _mm256_add_ps(w, var5);
+            w = _mm256_mul_ps(w, y);
+            auto z = _mm256_mul_ps(y2, var6);
+            z = _mm256_add_ps(z, var7);
+            z = _mm256_mul_ps(z, y2);
+            z = _mm256_add_ps(z, var8);
+            z = _mm256_mul_ps(z, y2);
+            z = _mm256_add_ps(z, var9);
+            z = _mm256_div_ps(w, z);
+            z = _mm256_max_ps(z, varNegOne);
+            y = _mm256_min_ps(z, varOne);
+        }
+        y = _mm256_add_ps(y, varOne);
+        y = _mm256_mul_ps(y, x);
+        y = _mm256_mul_ps(y, var10);
+        _mm256_storeu_ps(dst + i * 8, y);
+    }
+}
+void _AVX_MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8) {
+    auto count = countC8;
+    auto A     = _mm256_broadcast_ss(offset + 0);
+    auto B     = _mm256_broadcast_ss(offset + 1);
+    auto p0    = _mm256_set1_ps(parameters[0]);
+    auto p1    = _mm256_set1_ps(parameters[1]);
+    auto p2    = _mm256_set1_ps(parameters[2]);
+    auto p3    = _mm256_set1_ps(parameters[3]);
+    auto p4    = _mm256_set1_ps(parameters[4]);
+    auto p5    = _mm256_set1_ps(parameters[5]);
+    auto p6    = _mm256_set1_ps(parameters[6]);
+    auto p7    = _mm256_set1_ps(parameters[7]);
+    auto xMax  = _mm256_set1_ps(87);
+    auto xMin  = _mm256_set1_ps(-87);
+    auto basic = _mm256_set1_epi32(1 << 23);
+    auto temp127 = _mm256_set1_epi32(127);
+    auto negZero = _mm256_set1_ps(-0.f);
+    for (int i = 0; i < count; ++i) {
+        auto x            = _mm256_mul_ps(_mm256_loadu_ps(source + i * 8), A);
+        x                 = _mm256_max_ps(x, xMin);
+        x                 = _mm256_min_ps(x, xMax);
+        auto div          = _mm256_mul_ps(x, p1);
+        auto divInt       = _mm256_cvtps_epi32(div);
+        div               = _mm256_cvtepi32_ps(divInt);
+        auto div2         = _mm256_add_epi32(divInt, temp127);
+        div2 = _mm256_mullo_epi32(div2, basic);
+        auto expBasic  = _mm256_castsi256_ps(div2);
+        auto xReamin   = _mm256_sub_ps(x, _mm256_mul_ps(div, p0));
+        auto t         = xReamin;
+        auto c0        = _mm256_mul_ps(p7, t);
+        auto c1        = _mm256_add_ps(c0, p6);
+        auto c2        = _mm256_mul_ps(c1, t);
+        auto c3        = _mm256_add_ps(c2, p5);
+        auto c4        = _mm256_mul_ps(c3, t);
+        auto c5        = _mm256_add_ps(c4, p4);
+        auto c6        = _mm256_mul_ps(c5, t);
+        auto c7        = _mm256_add_ps(c6, p3);
+        auto c8        = _mm256_mul_ps(c7, t);
+        auto c9        = _mm256_add_ps(c8, p2);
+        auto expRemain = c9;
+        _mm256_storeu_ps(dest + 8 * i, _mm256_add_ps(_mm256_mul_ps(expBasic, expRemain), B));
+    }
+}
+
+
+void _AVX_MNNSoftmax(float* dest, const float* source, size_t size) {
+    float tmpfloat8[8];
+    int count  = size / 8;
+    int remain = count * 8;
+    // step 1: get maxValue
+    float maxValue = 0.f;
+    if (count > 0) {
+        auto maxVal = _mm256_loadu_ps(source);
+        for (int i = 1; i < count; i++) {
+            maxVal = _mm256_max_ps(maxVal, _mm256_loadu_ps(source + i * 8));
+        }
+        _mm256_storeu_ps(tmpfloat8, maxVal);
+        maxValue = tmpfloat8[0] > tmpfloat8[1] ? tmpfloat8[0] : tmpfloat8[1];
+        for (int i = 2; i < 8; i++) {
+            maxValue = maxValue > tmpfloat8[i] ? maxValue : tmpfloat8[i];
+        }
+    }
+    for (int i = remain; i < size; i++) {
+        maxValue = maxValue > source[i] ? maxValue : source[i];
+    }
+
+    // step 2: get exp(x - maxValue) and sum(exp(x - maxValue))
+    float sumValue = 0.f;
+    if (count > 0) {
+        auto sumVal = _mm256_set1_ps(0.f);
+        auto p0    = _mm256_set1_ps(0.6931471805599453);
+        auto p1    = _mm256_set1_ps(1.4426950408889634);
+        auto p2    = _mm256_set1_ps(1.f);
+        auto p3    = _mm256_set1_ps(1.f);
+        auto p4    = _mm256_set1_ps(0.5);
+        auto p5    = _mm256_set1_ps(0.1666666666666666);
+        auto p6    = _mm256_set1_ps(0.041666666666666664);
+        auto p7    = _mm256_set1_ps(0.008333333333333333);
+        auto xMax  = _mm256_set1_ps(87);
+        auto xMin  = _mm256_set1_ps(-87);
+        auto basic = _mm256_set1_epi32(1 << 23);
+        auto temp127 = _mm256_set1_epi32(127);
+        for (int i = 0; i < count; ++i) {
+            auto x            = _mm256_sub_ps(_mm256_loadu_ps(source + i * 8), _mm256_set1_ps(maxValue));
+            x                 = _mm256_max_ps(x, xMin);
+            x                 = _mm256_min_ps(x, xMax);
+            auto div          = _mm256_mul_ps(x, p1);
+            auto divInt       = _mm256_cvtps_epi32(div);
+            div               = _mm256_cvtepi32_ps(divInt);
+            auto div2         = _mm256_add_epi32(divInt, temp127);
+            div2 = _mm256_mullo_epi32(div2, basic);
+            auto expBasic  = _mm256_castsi256_ps(div2);
+            auto xReamin   = _mm256_sub_ps(x, _mm256_mul_ps(div, p0));
+            auto t         = xReamin;
+            auto c0        = _mm256_mul_ps(p7, t);
+            auto c1        = _mm256_add_ps(c0, p6);
+            auto c2        = _mm256_mul_ps(c1, t);
+            auto c3        = _mm256_add_ps(c2, p5);
+            auto c4        = _mm256_mul_ps(c3, t);
+            auto c5        = _mm256_add_ps(c4, p4);
+            auto c6        = _mm256_mul_ps(c5, t);
+            auto c7        = _mm256_add_ps(c6, p3);
+            auto c8        = _mm256_mul_ps(c7, t);
+            auto c9        = _mm256_add_ps(c8, p2);
+            auto expRemain = c9;
+            auto expRes    = _mm256_mul_ps(expBasic, expRemain);
+            sumVal         = _mm256_add_ps(expRes, sumVal);
+            _mm256_storeu_ps(dest + 8 * i, expRes);
+        }
+        _mm256_storeu_ps(tmpfloat8, sumVal);
+        for (int i = 0; i < 8; i++) {
+            sumValue += tmpfloat8[i];
+        }
+    }
+    auto param = 0.6931471805599453;
+    float xLimit = 87;
+    for (int i = remain; i < size; i++) {
+        auto x         = source[i] - maxValue;
+        x = x > -xLimit ? x : -xLimit;
+        x = x < xLimit ? x : xLimit;
+
+        int div        = (x / param);
+        int div2       = (div + 127) << 23;
+        auto xReamin   = x - div * param;
+        float expBasic = *(float*)(&div2);
+
+        auto t         = xReamin;
+        auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f;
+        dest[i]  = expBasic * expRemain;
+        sumValue += dest[i];
+    }
+    // step 3: get x / sum and store
+    for (int i = 0; i < count; ++i) {
+        // use 1 / ((1 / x) * sum) instead of x * (1 / sum) or x / sum to work around bugs on some Intel CPUs
+        auto x = _mm256_rcp_ps(_mm256_loadu_ps(dest + 8 * i));
+        auto y = _mm256_set1_ps(sumValue);
+        auto z = _mm256_rcp_ps(_mm256_mul_ps(x, y));
+        _mm256_storeu_ps(dest + 8 * i, z);
+    }
+    sumValue = 1.f / sumValue;
+    for (int i = remain; i < size; i++) {
+        dest[i] *= sumValue;
+    }
+}
+
+void _AVX_MNNNorm(float *dst, const float *src, const float *gamma, const float *beta, float epsilon, size_t size) {
+    float tmpfloat8[8];
+    int count  = size / 8;
+    int remain = count * 8;
+    // step 1: get sum
+    float sum = 0.f;
+    if (count > 0) {
+        auto sumVal = _mm256_set1_ps(0.f);
+        for (int i = 0; i < count; i++) {
+            sumVal = _mm256_add_ps(sumVal, _mm256_loadu_ps(src + i * 8));
+        }
+        _mm256_storeu_ps(tmpfloat8, sumVal);
+        for (int i = 0; i < 8; i++) {
+            sum += tmpfloat8[i];
+        }
+    }
+    for (int i = remain; i < size; i++) {
+        sum += src[i];
+    }
+    // step 2: get square_sum
+    float mean = sum / size;
+    float square_sum = 0.f;
+    auto meanVal = _mm256_set1_ps(mean);
+    if (count > 0) {
+        auto sumVal = _mm256_set1_ps(0.f);
+        for (int i = 0; i < count; i++) {
+            auto x = _mm256_sub_ps(_mm256_loadu_ps(src + i * 8), meanVal);
+            sumVal = _mm256_add_ps(sumVal, _mm256_mul_ps(x, x));
+        }
+        _mm256_storeu_ps(tmpfloat8, sumVal);
+        for (int i = 0; i < 8; i++) {
+            square_sum += tmpfloat8[i];
+        }
+    }
+    for (int i = remain; i < size; i++) {
+        float x = (src[i] - mean);
+        square_sum += x * x;
+    }
+    // step 3: get result
+    float variable = square_sum / size;
+    variable = 1.f / sqrt(variable + epsilon);
+    auto variableVal = _mm256_set1_ps(variable);
+    if (gamma && beta) {
+        for (int i = 0; i < count; i++) {
+            auto x = _mm256_sub_ps(_mm256_loadu_ps(src + i * 8), meanVal);
+            auto g = _mm256_loadu_ps(gamma + i * 8);
+            auto b = _mm256_loadu_ps(beta + i * 8);
+            auto y = _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(x, g), variableVal), b);
+            _mm256_storeu_ps(dst + i * 8, y);
+        }
+        for (int i = remain; i < size; i++) {
+            dst[i] = (src[i] - mean) * gamma[i] * variable + beta[i] ;
+        }
+    } else {
+        for (int i = 0; i < count; i++) {
+            auto x = _mm256_sub_ps(_mm256_loadu_ps(src + i * 8), meanVal);
+            auto y = _mm256_mul_ps(x, variableVal);
+            _mm256_storeu_ps(dst + i * 8, y);
+        }
+        for (int i = remain; i < size; i++) {
+            dst[i] = (src[i] - mean) * variable;
+        }
+    }
+}

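For reference (illustration only, not part of the diff): _AVX_MNNGelu above evaluates 0.5 * x * (1 + tanh(0.79788458 * (x + 0.044715 * x^3))), replacing tanh(y) with the rational approximation y * (y^6 + 378*y^4 + 17325*y^2 + 135135) / (28*y^6 + 3150*y^4 + 62370*y^2 + 135135), clamped to [-1, 1]. _AVX_MNNExpC8 and _AVX_MNNSoftmax share the same exp decomposition: split x = n*ln2 + r, build 2^n by writing n + 127 into the IEEE-754 exponent field, and approximate exp(r) with the p2..p7 polynomial. A scalar sketch of that exp path, assuming only the constants visible above:

```cpp
#include <algorithm>
#include <cmath>
#include <cstring>

// Scalar sketch of the exp() decomposition used by _AVX_MNNExpC8/_AVX_MNNSoftmax
// (the AVX versions do the same thing 8 lanes at a time).
static inline float expApproxSketch(float x) {
    x = std::min(87.0f, std::max(-87.0f, x));                      // clamp to avoid exponent overflow
    const int n    = (int)std::nearbyint(x * 1.4426950408889634f); // n ~= x / ln(2)
    const float r  = x - n * 0.6931471805599453f;                  // remainder
    const int bits = (n + 127) << 23;                              // 2^n via the IEEE-754 exponent field
    float twoN;
    std::memcpy(&twoN, &bits, sizeof(twoN));
    // polynomial for exp(r); coefficients match p7..p2 = 1/120, 1/24, 1/6, 1/2, 1, 1
    const float poly = ((((r / 120.0f + 1.0f / 24.0f) * r + 1.0f / 6.0f) * r + 0.5f) * r + 1.0f) * r + 1.0f;
    return twoN * poly;
}
```

The vector versions additionally fold in the offset scale and bias (A and B in _AVX_MNNExpC8) and, in the softmax, normalize with the _mm256_rcp_ps reciprocal trick noted in the comment above.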
+ 569 - 0
source/backend/cpu/x86_x64/avx/PackedFunction.cpp

@@ -0,0 +1,569 @@
+//
+//  PackedFunction.cpp
+//  MNN
+//
+//  Created by MNN on 2021/07/05.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#include <float.h>
+#include <string.h>
+#include <algorithm>
+#include <limits>
+#include <vector>
+#include "FunctionSummary.hpp"
+#include "core/Macro.h"
+#include "backend/cpu/CPUPool.hpp"
+#include "backend/cpu/BinaryUtils.hpp"
+#include "Vec8.hpp"
+#define PACK_UNIT 8
+extern "C" {
+void _AVX_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count);
+void _AVX_MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count);
+void _AVX_MNNScaleAndAddBias(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, size_t biasNumber);
+void _AVX_MNNDeconvRunForUnitDepthWise(const float* dst, float* src, const float* weight, size_t fw, size_t fh,
+                                       size_t weight_y_step, size_t dilateX_step, size_t dilateY_step);
+void _AVX_MNNDeconvRunForLineDepthwise(const float* dst, float* src, const float* weight, size_t width, size_t src_w_setup,
+                                       size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step);
+void _AVX_MNNGridSampleInterp(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode);
+void _AVX_MNNStrassenMergeCFunction(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride, size_t eSub, size_t hSub);
+void _AVX_MNNConvRunForUnitDepthWise(float* dst, const float* src, const float* weight, size_t fw, size_t fh,
+                                     size_t weight_y_step, size_t dilateX_step, size_t dilateY_step);
+void _AVX_MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weigth, float *dest, int cacheLineSize, int ow, const float* bias, const float* parameter);
+void _AVX_MNNSourceTransformCommonF23(const float *source, float *dest, int unit, int iw, int pad, int su, int eu);
+void _AVX_MNNConvDwF23MulTransUnit(float **cacheLine, const float *weigth, float *dest, size_t ow, const float* bias, const float* parameter);
+void _AVX_MNNMatrixAdd(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
+                       size_t bStride, size_t height);
+void _AVX_MNNMatrixSub(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
+                       size_t bStride, size_t height);
+void _AVX_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
+                                size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height,
+                                     size_t srcHStep, size_t dstHStep);
+void _AVX_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters);
+}
+
+
+void _AVX_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
+    for (int i = 0; i < count; ++i) {
+        auto s = source + i * srcStride;
+        auto d = dest + i * dstStride;
+        _mm256_storeu_ps(d, _mm256_loadu_ps(s));
+    }
+}
+void _AVX_MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
+    for (int i = 0; i < count; ++i) {
+        auto s = source + i * srcStride;
+        auto d = dest + i * dstStride;
+        _mm256_storeu_ps(d, _mm256_add_ps(_mm256_loadu_ps(s), _mm256_loadu_ps(d)));
+    }
+}
+
+void _AVX_MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) {
+    auto zero = _mm_set1_ps(0.0f);
+    auto zero2 = _mm256_set1_ps(0.0f);
+    int sizeC8 = sizeQuad;
+    for (int j = 0; j < depthQuad; j++) {
+        auto slopeZ       = _mm256_loadu_ps(slope + PACK_UNIT * j);
+        const float* srcZ = src + PACK_UNIT * j * sizeQuad;
+        float* dstZ       = dst + PACK_UNIT * j * sizeQuad;
+        for (int i = 0; i < sizeC8; i++) {
+            auto src   = _mm256_loadu_ps(srcZ);
+            auto mask0 = _mm256_cmp_ps(src, zero2, 0x01);
+            auto mask1 = _mm256_cmp_ps(src, zero2, 0x0D);
+            auto other = _mm256_mul_ps(src, slopeZ);
+            _mm256_storeu_ps(dstZ, _mm256_add_ps(_mm256_and_ps(other, mask0), _mm256_and_ps(src, mask1)));
+            srcZ += PACK_UNIT;
+            dstZ += PACK_UNIT;
+        }
+    }
+}
+
+void _AVX_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) {
+    auto minF = _mm256_broadcast_ss(parameters + 2);
+    auto maxF = _mm256_broadcast_ss(parameters + 3);
+    for (int y = 0; y < height; ++y) {
+        auto a = A + aStride * y;
+        auto b = B + PACK_UNIT * y;
+        auto bv = _mm256_loadu_ps(b);
+        auto c = C + cStride * y;
+        for (int x = 0; x < width; ++x) {
+            auto av = _mm256_loadu_ps(a);
+            auto cv = _mm256_add_ps(av, bv);
+            cv = _mm256_min_ps(cv, maxF);
+            cv = _mm256_max_ps(cv, minF);
+            _mm256_storeu_ps(c, cv);
+            a += PACK_UNIT;
+            c += PACK_UNIT;
+        }
+    }
+}
+
+void _AVX_MNNConvRunForUnitDepthWise(float* dst, const float* src, const float* weight, size_t fw, size_t fh,
+                                  size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) {
+    int fx, fy;
+    __m256 dstValue = _mm256_setzero_ps();
+    const float* src_z    = src;
+    const float* weight_z = weight;
+    for (fy = 0; fy < fh; ++fy) {
+        const float* src_y    = src_z + fy * dilateY_step;
+        const float* weight_y = weight_z + fy * weight_y_step;
+        for (fx = 0; fx < fw; ++fx) {
+            const float* weight_x = weight_y + PACK_UNIT * fx;
+            const float* src_x    = src_y + fx * dilateX_step;
+            dstValue = _mm256_add_ps(dstValue, _mm256_mul_ps(_mm256_loadu_ps(src_x), _mm256_loadu_ps(weight_x)));
+        }
+    }
+    _mm256_storeu_ps(dst, dstValue);
+}
+
+void _AVX_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
+                                size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height,
+                                     size_t srcHStep, size_t dstHStep) {
+    int dx, fx, fy;
+    const int unit = 4;
+    int widthUnit = width / unit;
+    int widthRemain = width - widthUnit * unit;
+    const float* weight_z = weight;
+    for (int y = 0; y < height; ++y) {
+        auto srcY = src + y * srcHStep;
+        auto dstY = dst + y * dstHStep;
+        for (dx = 0; dx < widthUnit; ++dx) {
+            auto dstValue0 = _mm256_setzero_ps();
+            auto dstValue1 = _mm256_setzero_ps();
+            auto dstValue2 = _mm256_setzero_ps();
+            auto dstValue3 = _mm256_setzero_ps();
+            for (fy = 0; fy < fh; ++fy) {
+                const float* src_y    = srcY + fy * dilateY_step;
+                const float* weight_y = weight_z + fy * fw * PACK_UNIT;
+                for (fx = 0; fx < fw; ++fx) {
+                    const float* src_x    = src_y + fx * dilateX_step;
+                    const float* weight_x = weight_y + PACK_UNIT * fx;
+                    auto weightValue = _mm256_loadu_ps(weight_x);
+                    dstValue0 = _mm256_add_ps(dstValue0, _mm256_mul_ps(_mm256_loadu_ps(src_x + 0 * src_w_setup), weightValue));
+                    dstValue1 = _mm256_add_ps(dstValue1, _mm256_mul_ps(_mm256_loadu_ps(src_x + 1 * src_w_setup), weightValue));
+                    dstValue2 = _mm256_add_ps(dstValue2, _mm256_mul_ps(_mm256_loadu_ps(src_x + 2 * src_w_setup), weightValue));
+                    dstValue3 = _mm256_add_ps(dstValue3, _mm256_mul_ps(_mm256_loadu_ps(src_x + 3 * src_w_setup), weightValue));
+                }
+            }
+            _mm256_storeu_ps(dstY + PACK_UNIT * 0, dstValue0);
+            _mm256_storeu_ps(dstY + PACK_UNIT * 1, dstValue1);
+            _mm256_storeu_ps(dstY + PACK_UNIT * 2, dstValue2);
+            _mm256_storeu_ps(dstY + PACK_UNIT * 3, dstValue3);
+            dstY += PACK_UNIT * unit;
+            srcY += unit * src_w_setup;
+        }
+        for (dx = 0; dx < widthRemain; ++dx) {
+            float* dst_x          = dstY + dx * PACK_UNIT;
+            auto dstValue = _mm256_setzero_ps();
+            const float* src_z    = srcY + src_w_setup * dx;
+            const float* weight_z = weight;
+            for (fy = 0; fy < fh; ++fy) {
+                const float* src_y    = src_z + fy * dilateY_step;
+                const float* weight_y = weight_z + fy * fw * PACK_UNIT;
+                for (fx = 0; fx < fw; ++fx) {
+                    const float* weight_x = weight_y + PACK_UNIT * fx;
+                    const float* src_x    = src_y + fx * dilateX_step;
+                    dstValue = _mm256_add_ps(dstValue, _mm256_mul_ps(_mm256_loadu_ps(src_x), _mm256_loadu_ps(weight_x)));
+                }
+            }
+            _mm256_storeu_ps(dst_x, dstValue);
+        }
+    }
+}
+
+static MNNBinaryExecute _AVX2_MNNSelectBinaryFunctionForFloat(int opType) {
+    auto vecF = MNN::selectVector<Vec8, 8>(opType);
+    if (nullptr != vecF) {
+        return vecF;
+    }
+    return MNN::MNNGetCoreFunctions()->MNNSelectBinaryFunctionForFloat(opType);
+}
+
+static void _8BitcopyWithStrideC4(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds) {
+    auto src = (float*)srcO;
+    auto dst = (float*)dstO;
+    for (int i=0; i<size; ++i) {
+        _mm256_storeu_ps(dst, _mm256_loadu_ps(src));
+        src+= (8 * stride);
+        dst+= (8 * ds);
+    }
+}
+static MNNCopyWithStride _selectBlit(int bytesC4) {
+    if (32 == bytesC4) {
+        return _8BitcopyWithStrideC4;
+    }
+    return nullptr;
+}
+
+
+
+void _AVX_MNNScaleAndAddBias(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber,
+                        size_t biasNumber) {
+    for (int z = 0; z < biasNumber; ++z) {
+        float* dstZ         = dst + planeNumber * PACK_UNIT * z;
+        const float* srcZ   = src + planeNumber * PACK_UNIT * z;
+        auto biasZ = Vec8::load(bias + PACK_UNIT * z);
+        auto alphaZ = Vec8::load(alpha + PACK_UNIT * z);
+        for (int p = 0; p < planeNumber; ++p) {
+            float* dstX       = dstZ + PACK_UNIT * p;
+            const float* srcX = srcZ + PACK_UNIT * p;
+            Vec8::save(dstX, (Vec8::load(srcX) * alphaZ) + biasZ);
+        }
+    }
+}
+
+void _AVX_MNNDeconvRunForUnitDepthWise(const float* dst, float* src, const float* weight, size_t fw, size_t fh,
+                                  size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) {
+    int fx, fy;
+    float* src_z          = src;
+    const float* weight_z = weight;
+    Vec8 dstV             = Vec8::load(dst);
+    for (fy = 0; fy < fh; ++fy) {
+        float* src_y          = src_z + fy * dilateY_step;
+        const float* weight_y = weight_z + fy * weight_y_step;
+        for (fx = 0; fx < fw; ++fx) {
+            Vec8 weight_x = Vec8::load(weight_y + PACK_UNIT * fx);
+            Vec8 src_x    = Vec8::load(src_y + fx * dilateX_step);
+            Vec8::save(src_y + fx * dilateX_step, src_x + weight_x * dstV);
+        }
+    }
+}
+void _AVX_MNNDeconvRunForLineDepthwise(const float* dst, float* src, const float* weight, size_t width, size_t src_w_setup,
+                                  size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) {
+    int dx;
+    for (dx = 0; dx < width; ++dx) {
+        const float* dst_x = dst + dx * PACK_UNIT;
+        float* src_dx      = src + src_w_setup * dx;
+        _AVX_MNNDeconvRunForUnitDepthWise(dst_x, src_dx, weight, fw, fh, fw * PACK_UNIT, dilateX_step, dilateY_step);
+    }
+}
+
+static __m256 MNNGridSampleLoadSample(int h, int w, const float *buffer, int height, int width, bool padMode) {
+    if (h < 0 || h >= height || w < 0 || w >= width) {
+        if(padMode == true) { //padMode == BorderMode_ZEROS
+            return _mm256_setzero_ps();
+        }
+        // Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER
+        // For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1),
+        // the leftover reflections degrade to GridSamplePaddingMode_BORDER
+        h = h < 0 ? 0 : ( h > (height - 1) ? (height - 1) : h);
+        w = w < 0 ? 0 : ( w > (width - 1) ? (width - 1) : w);
+    }
+
+    return _mm256_loadu_ps(buffer + h * width * PACK_UNIT + w * PACK_UNIT);
+}
+void _AVX_MNNGridSampleInterp(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode) {
+    for (auto ow = 0; ow < outW; ++ow) {
+        auto w = cordPtr[2 * ow + 0];
+        auto h = cordPtr[2 * ow + 1];
+        __m256 interp;
+
+        if (sampleMode == true) { //sampleMode == SampleMode_NEAREST
+            int nh = ::floor(h + 0.5f);
+            int nw = ::floor(w + 0.5f);
+            interp = MNNGridSampleLoadSample(nh, nw, inputPtr, inH, inW, padMode);
+        } else { //sampleMode == GridSampleMode_BILINEAR
+            int w0_h = ::floor(h);
+            int w0_w = ::floor(w);
+            int w1_h = ::ceil(h);
+            int w1_w = ::ceil(w);
+            auto oneV = _mm256_set1_ps(1.0f);
+
+            __m256 i00 = MNNGridSampleLoadSample(w0_h, w0_w, inputPtr, inH, inW, padMode);
+            __m256 i01 = MNNGridSampleLoadSample(w0_h, w1_w, inputPtr, inH, inW, padMode);
+            __m256 i10 = MNNGridSampleLoadSample(w1_h, w0_w, inputPtr, inH, inW, padMode);
+            __m256 i11 = MNNGridSampleLoadSample(w1_h, w1_w, inputPtr, inH, inW, padMode);
+            auto f0 = _mm256_set1_ps((float)w1_w - w);
+            auto f1 = _mm256_sub_ps(oneV, f0);
+            auto h0 = _mm256_set1_ps((float)w1_h - h);
+            auto h1 = _mm256_sub_ps(oneV, h0);
+
+            __m256 i0 = _mm256_add_ps(_mm256_mul_ps(i00, f0), _mm256_mul_ps(i01, f1));
+            __m256 i1 = _mm256_add_ps(_mm256_mul_ps(i10, f0), _mm256_mul_ps(i11, f1));
+            interp = _mm256_add_ps(_mm256_mul_ps(i0, h0), _mm256_mul_ps(i1, h1));
+        }
+
+        _mm256_storeu_ps(outputPtr + PACK_UNIT * ow, interp);
+    }
+}
+
+void _AVX_MNNMatrixAdd(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
+                       size_t bStride, size_t height) {
+    for (int y = 0; y < height; ++y) {
+        auto a = A + aStride * y;
+        auto b = B + bStride * y;
+        auto c = C + cStride * y;
+        for (int x = 0; x < widthC4; ++x) {
+            _mm256_storeu_ps(c + PACK_UNIT * x, _mm256_add_ps(_mm256_loadu_ps(b + PACK_UNIT * x), _mm256_loadu_ps(a + PACK_UNIT * x)));
+        }
+    }
+}
+
+void _AVX_MNNStrassenMergeCFunction(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride, size_t eSub, size_t hSub) {
+    const int unit = PACK_UNIT;
+    for (int y=0; y<hSub; ++y) {
+        auto c11Y = c11 + y * cStride;
+        auto c12Y = c12 + y * cStride;
+        auto c22Y = c22 + y * cStride;
+        auto c21Y = c21 + y * cStride;
+        auto xY = xAddr + y * eSub * unit;
+        for (int x=0; x<eSub; ++x) {
+            auto xv = _mm256_loadu_ps(xY + unit*x);
+            auto c21v = _mm256_loadu_ps(c21Y + unit*x);
+            auto c11v = _mm256_loadu_ps(c11Y + unit*x);
+            auto c22v = _mm256_loadu_ps(c22Y + unit*x);
+            auto c12v = _mm256_loadu_ps(c12Y + unit*x);
+            c12v = _mm256_add_ps(c12v, xv);
+            c21v = _mm256_add_ps(c12v, c21v);
+            c12v = _mm256_add_ps(c22v, c12v);
+            c22v = _mm256_add_ps(c22v, c21v);
+            c12v = _mm256_add_ps(c11v, c12v);
+            _mm256_storeu_ps(c12Y + unit*x, c12v);
+            _mm256_storeu_ps(c22Y + unit*x, c22v);
+            _mm256_storeu_ps(c21Y + unit*x, c21v);
+        }
+    }
+}
+
+void _AVX_MNNMatrixSub(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
+                       size_t bStride, size_t height) {
+    for (int y = 0; y < height; ++y) {
+        auto a = A + aStride * y;
+        auto b = B + bStride * y;
+        auto c = C + cStride * y;
+        for (int x = 0; x < widthC4; ++x) {
+            _mm256_storeu_ps(c + PACK_UNIT * x, _mm256_sub_ps(_mm256_loadu_ps(a + PACK_UNIT * x), _mm256_loadu_ps(b + PACK_UNIT * x)));
+        }
+    }
+}
+
+void _AVX_MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weigth, float *dest, int cacheLineSize, int ow, const float* bias, const float* parameter) {
+    int unit = ow / 2;
+    MNN_ASSERT(cacheLineSize >= 1);
+    auto biasF = Vec8::load(bias);
+    auto minF = Vec8(parameter[2]);
+    auto maxF = Vec8(parameter[3]);
+    auto SRC_TILE_UNIT = 4 * PACK_UNIT;
+    auto DST_TILE_UNIT = 2 * PACK_UNIT;
+    for (int x = 0; x < unit; ++x) {
+        auto offset = SRC_TILE_UNIT * x;
+        int i = 0;
+        Vec8 m0     = Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 0) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 0);
+        Vec8 m1     = Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 1) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 1);
+        Vec8 m2     = Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 2) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 2);
+        Vec8 m3     = Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 3) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 3);
+
+        for (i = 1; i < cacheLineSize; ++i) {
+            m0 = m0 + Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 0) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 0);
+            m1 = m1 + Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 1) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 1);
+            m2 = m2 + Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 2) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 2);
+            m3 = m3 + Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 3) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 3);
+        }
+        auto o0 = m0 + m1 + m2 + biasF;
+        auto o1 = m1 - m2 + m3 + biasF;
+        o0 = Vec8::min(maxF, o0);
+        o1 = Vec8::min(maxF, o1);
+        o0 = Vec8::max(minF, o0);
+        o1 = Vec8::max(minF, o1);
+
+        Vec8::save(dest + DST_TILE_UNIT * x + 0 * PACK_UNIT, o0);
+        Vec8::save(dest + DST_TILE_UNIT * x + 1 * PACK_UNIT, o1);
+    }
+    if (unit * 2 < ow) {
+        auto offset = SRC_TILE_UNIT * unit;
+        int i = 0;
+        Vec8 m0     = Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 0) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 0);
+        Vec8 m1     = Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 1) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 1);
+        Vec8 m2     = Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 2) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 2);
+
+        for (i = 1; i < cacheLineSize; ++i) {
+            m0 = m0 + Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 0) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 0);
+            m1 = m1 + Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 1) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 1);
+            m2 = m2 + Vec8::load(weigth + i * SRC_TILE_UNIT + PACK_UNIT * 2) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 2);
+        }
+        auto o0 = m0 + m1 + m2 + biasF;
+        o0 = Vec8::min(maxF, o0);
+        o0 = Vec8::max(minF, o0);
+        Vec8::save(dest + DST_TILE_UNIT * unit, o0);
+    }
+}
+static void _AVX_MNNConvDwF23SourceTransUnit(const float *source, float *dest, size_t unit) {
+    if (unit <= 0) {
+        return;
+    }
+    Vec8 v0 = Vec8::load(source + PACK_UNIT * 0);
+    Vec8 v1 = Vec8::load(source + PACK_UNIT * 1);
+    Vec8 v2;
+    Vec8 v3;
+    source += 2 * PACK_UNIT;
+
+    for (int x = 0; x < unit; ++x) {
+        v2 = Vec8::load(source + 0 * PACK_UNIT);
+        v3 = Vec8::load(source + 1 * PACK_UNIT);
+        auto m0 = v0 - v2;
+        auto m1 = v1 + v2;
+        auto m2 = v2 - v1;
+        auto m3 = v3 - v1;
+
+        Vec8::save(dest + PACK_UNIT * 0, m0);
+        Vec8::save(dest + PACK_UNIT * 1, m1);
+        Vec8::save(dest + PACK_UNIT * 2, m2);
+        Vec8::save(dest + PACK_UNIT * 3, m3);
+
+        source += (2 * PACK_UNIT);
+        dest += (4 * PACK_UNIT);
+
+        v0 = v2;
+        v1 = v3;
+    }
+}
+
+void _AVX_MNNSourceTransformCommonF23(const float *source, float *dest, int unit, int iw, int pad, int su, int eu) {
+    for (int x = 0; x < su; ++x) {
+        auto dstX = dest + 4 * PACK_UNIT * x;
+        auto sx   = x * 2 - (int)pad;
+        auto ex   = sx + 4;
+
+        auto clampSx = std::max(sx, 0);
+        auto clampEx = std::min(ex, (int)iw);
+
+        Vec8 v[4] = {0.0f, 0.0f, 0.0f, 0.0f};
+        for (int i = clampSx; i < clampEx; ++i) {
+            v[i - sx] = Vec8::load(source + 8 * i);
+        }
+        auto m0 = v[0] - v[2];
+        auto m1 = v[1] + v[2];
+        auto m2 = v[2] - v[1];
+        auto m3 = v[3] - v[1];
+
+        Vec8::save(dstX + PACK_UNIT * 0, m0);
+        Vec8::save(dstX + PACK_UNIT * 1, m1);
+        Vec8::save(dstX + PACK_UNIT * 2, m2);
+        Vec8::save(dstX + PACK_UNIT * 3, m3);
+    }
+    _AVX_MNNConvDwF23SourceTransUnit(source + PACK_UNIT * (su * 2 - pad), dest + PACK_UNIT * 4 * su, eu - su);
+
+    for (int x = eu; x < unit; ++x) {
+        auto dstX = dest + PACK_UNIT * 4 * x;
+        auto sx   = x * 2 - (int)pad;
+        auto ex   = sx + 4;
+
+        auto clampSx = std::max(sx, 0);
+        auto clampEx = std::min(ex, (int)iw);
+
+        Vec8 v[4] = {0.0f, 0.0f, 0.0f, 0.0f};
+        for (int i = clampSx; i < clampEx; ++i) {
+            v[i - sx] = Vec8::load(source + PACK_UNIT * i);
+        }
+        auto m0 = v[0] - v[2];
+        auto m1 = v[1] + v[2];
+        auto m2 = v[2] - v[1];
+        auto m3 = v[3] - v[1];
+
+        Vec8::save(dstX + PACK_UNIT * 0, m0);
+        Vec8::save(dstX + PACK_UNIT * 1, m1);
+        Vec8::save(dstX + PACK_UNIT * 2, m2);
+        Vec8::save(dstX + PACK_UNIT * 3, m3);
+    }
+}
+
+void _AVX_MNNConvDwF23MulTransUnit(float **cacheLine, const float *weigth, float *dest, size_t ow, const float* bias, const float* parameter) {
+    int unit = ow / 2;
+    auto SRC_TILE_UNIT = 4 * PACK_UNIT;
+    auto DST_TILE_UNIT = 2 * PACK_UNIT;
+
+    auto w00 = Vec8::load(weigth + 0 * SRC_TILE_UNIT + PACK_UNIT * 0);
+    auto w01 = Vec8::load(weigth + 0 * SRC_TILE_UNIT + PACK_UNIT * 1);
+    auto w02 = Vec8::load(weigth + 0 * SRC_TILE_UNIT + PACK_UNIT * 2);
+    auto w03 = Vec8::load(weigth + 0 * SRC_TILE_UNIT + PACK_UNIT * 3);
+    auto w10 = Vec8::load(weigth + 1 * SRC_TILE_UNIT + PACK_UNIT * 0);
+    auto w11 = Vec8::load(weigth + 1 * SRC_TILE_UNIT + PACK_UNIT * 1);
+    auto w12 = Vec8::load(weigth + 1 * SRC_TILE_UNIT + PACK_UNIT * 2);
+    auto w13 = Vec8::load(weigth + 1 * SRC_TILE_UNIT + PACK_UNIT * 3);
+    auto w20 = Vec8::load(weigth + 2 * SRC_TILE_UNIT + PACK_UNIT * 0);
+    auto w21 = Vec8::load(weigth + 2 * SRC_TILE_UNIT + PACK_UNIT * 1);
+    auto w22 = Vec8::load(weigth + 2 * SRC_TILE_UNIT + PACK_UNIT * 2);
+    auto w23 = Vec8::load(weigth + 2 * SRC_TILE_UNIT + PACK_UNIT * 3);
+    auto biasF = Vec8::load(bias);
+    auto minF = Vec8(parameter[2]);
+    auto maxF = Vec8(parameter[3]);
+
+    for (int x = 0; x < unit; ++x) {
+        auto offset = PACK_UNIT * 4 * x;
+        int i = 0;
+        Vec8 m0     = w00 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 0);
+        Vec8 m1     = w01 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 1);
+        Vec8 m2     = w02 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 2);
+        Vec8 m3     = w03 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 3);
+
+        m0 = m0 + w10 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 0);
+        m1 = m1 + w11 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 1);
+        m2 = m2 + w12 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 2);
+        m3 = m3 + w13 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 3);
+
+        m0 = m0 + w20 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 0);
+        m1 = m1 + w21 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 1);
+        m2 = m2 + w22 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 2);
+        m3 = m3 + w23 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 3);
+
+        auto o0 = m0 + m1 + m2 + biasF;
+        auto o1 = m1 - m2 + m3 + biasF;
+        o0 = Vec8::min(maxF, o0);
+        o1 = Vec8::min(maxF, o1);
+        o0 = Vec8::max(minF, o0);
+        o1 = Vec8::max(minF, o1);
+        Vec8::save(dest + DST_TILE_UNIT * x + 0 * PACK_UNIT, o0);
+        Vec8::save(dest + DST_TILE_UNIT * x + 1 * PACK_UNIT, o1);
+    }
+    if (unit * 2 < ow) {
+        auto offset = PACK_UNIT * 4 * unit;
+        Vec8 m0     = w00 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 0);
+        Vec8 m1     = w01 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 1);
+        Vec8 m2     = w02 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 2);
+
+        m0 = m0 + w10 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 0);
+        m1 = m1 + w11 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 1);
+        m2 = m2 + w12 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 2);
+
+        m0 = m0 + w20 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 0);
+        m1 = m1 + w21 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 1);
+        m2 = m2 + w22 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 2);
+        auto o0 = m0 + m1 + m2 + biasF;
+        o0 = Vec8::min(maxF, o0);
+        o0 = Vec8::max(minF, o0);
+        Vec8::save(dest + DST_TILE_UNIT * unit, o0);
+    }
+}
+
+void _AVX_ExtraInit(void* functions) {
+    auto coreFunction = static_cast<MNN::CoreFunctions*>(functions);
+    coreFunction->MNNSelectBlitFunction = _selectBlit;
+    coreFunction->MNNPoolingAvg = (decltype(coreFunction->MNNPoolingAvg))(MNN::poolingAvg<float, Vec8, 8>);
+    // Set min value as -(1 << 24)
+    coreFunction->MNNPoolingMax = (decltype(coreFunction->MNNPoolingMax))(MNN::poolingMax<float, Vec8, 8, -16777216>);
+    coreFunction->MNNSelectBinaryFunctionForFloat = _AVX2_MNNSelectBinaryFunctionForFloat;
+    coreFunction->MNNCopyC4WithStride = _AVX_MNNCopyC4WithStride;
+    coreFunction->MNNAddC4WithStride = _AVX_MNNAddC4WithStride;
+    coreFunction->MNNScaleAndAddBias = _AVX_MNNScaleAndAddBias;
+    coreFunction->MNNMatrixAdd          = _AVX_MNNMatrixAdd;
+    coreFunction->MNNMatrixSub          = _AVX_MNNMatrixSub;
+
+    coreFunction->MNNConvRunForUnitDepthWise = _AVX_MNNConvRunForUnitDepthWise;
+    coreFunction->MNNConvRunForLineDepthwise = _AVX_MNNConvRunForLineDepthwise;
+    coreFunction->MNNAxByClampBroadcastUnit = _AVX_MNNAxByClampBroadcastUnit;
+    coreFunction->MNNStrassenMergeCFunction = _AVX_MNNStrassenMergeCFunction;
+    coreFunction->MNNMultiAndDestTransformCommon23 = _AVX_MNNMultiAndDestTransformCommon23;
+    coreFunction->MNNSourceTransformCommonF23 = _AVX_MNNSourceTransformCommonF23;
+    coreFunction->MNNConvDwF23MulTransUnit = _AVX_MNNConvDwF23MulTransUnit;
+    coreFunction->MNNReluWithSlopeChannel = _AVX_MNNReluWithSlopeChannel;
+    coreFunction->MNNDeconvRunForLineDepthwise = _AVX_MNNDeconvRunForLineDepthwise;
+    coreFunction->MNNDeconvRunForUnitDepthWise = _AVX_MNNDeconvRunForUnitDepthWise;
+    coreFunction->MNNGridSampleInterp = _AVX_MNNGridSampleInterp;
+
+    // sparse conv funcs
+    coreFunction->MNNGetSparseMatMulPackMode = _AVX_MNNGetSparseMatMulPackMode;
+    coreFunction->MNNPackedSparseMatMulEpx1 = _AVX_MNNPackedSparseMatMulEpx1EFMA;
+    coreFunction->MNNPackedSparseMatMulEpx4 = _AVX_MNNPackedSparseMatMulEpx4EFMA;
+}

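For reference (illustration only, not part of the diff): the depthwise F(2,3) helpers in PackedFunction.cpp above apply the standard Winograd transforms per 8-float channel lane. _AVX_MNNSourceTransformCommonF23 and _AVX_MNNConvDwF23SourceTransUnit compute the input transform m = B^T * d, i.e. m0 = d0 - d2, m1 = d1 + d2, m2 = d2 - d1, m3 = d3 - d1, while _AVX_MNNConvDwF23MulTransUnit and _AVX_MNNMultiAndDestTransformCommon23 fuse the per-lane multiply-accumulate over the cached rows with the output transform o = A^T * m, i.e. o0 = m0 + m1 + m2 and o1 = m1 - m2 + m3, followed by the bias add and the min/max clamp taken from parameter[2] and parameter[3]. Each overlapping 4-wide input window thus yields two output pixels, which is why the loops advance the destination by 2 * PACK_UNIT per step.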
+ 486 - 0
source/backend/cpu/x86_x64/avx/ReorderFunctions.cpp

@@ -0,0 +1,486 @@
+//
+//  ReorderFunctions.cpp
+//  MNN
+//
+//  Created by MNN on 2019/08/25.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#include <float.h>
+#include <string.h>
+#include <algorithm>
+#include <limits>
+#include <vector>
+#include "FunctionSummary.hpp"
+#include "core/Macro.h"
+#include "Vec8.hpp"
+#define PACK_UNIT 8
+
+void _AVX_MNNPackCUnit(float* dst, const float* src, size_t area, size_t depth, int* areaOffset) {
+    auto areaC4  = area / PACK_UNIT;
+    auto depthC4 = depth / PACK_UNIT;
+    auto srcAreaOffset = areaOffset[0];
+    auto dstAreaOffset = areaOffset[1];
+    __m256 t0, t1, t2, t3, t4, t5, t6, t7;
+    for (int z = 0; z < depthC4; ++z) {
+        auto dstPlane = dst + z * dstAreaOffset * PACK_UNIT;
+        auto srcPlane = src + z * srcAreaOffset * PACK_UNIT;
+        for (int x = 0; x < areaC4; ++x) {
+            auto s  = srcPlane + PACK_UNIT * x;
+            auto d  = dstPlane + PACK_UNIT * PACK_UNIT * x;
+            auto r0 = _mm256_loadu_ps(s + 0 * srcAreaOffset);
+            auto r1 = _mm256_loadu_ps(s + 1 * srcAreaOffset);
+            auto r2 = _mm256_loadu_ps(s + 2 * srcAreaOffset);
+            auto r3 = _mm256_loadu_ps(s + 3 * srcAreaOffset);
+            auto r4 = _mm256_loadu_ps(s + 4 * srcAreaOffset);
+            auto r5 = _mm256_loadu_ps(s + 5 * srcAreaOffset);
+            auto r6 = _mm256_loadu_ps(s + 6 * srcAreaOffset);
+            auto r7 = _mm256_loadu_ps(s + 7 * srcAreaOffset);
+            
+            TRANSPOSE_8x8;
+
+            _mm256_storeu_ps(d + PACK_UNIT * 0, t0);
+            _mm256_storeu_ps(d + PACK_UNIT * 1, t1);
+            _mm256_storeu_ps(d + PACK_UNIT * 2, t2);
+            _mm256_storeu_ps(d + PACK_UNIT * 3, t3);
+            _mm256_storeu_ps(d + PACK_UNIT * 4, t4);
+            _mm256_storeu_ps(d + PACK_UNIT * 5, t5);
+            _mm256_storeu_ps(d + PACK_UNIT * 6, t6);
+            _mm256_storeu_ps(d + PACK_UNIT * 7, t7);
+        }
+    }
+    auto areaRemain  = areaC4 * PACK_UNIT;
+    auto depthRemain = depthC4 * PACK_UNIT;
+    // Down: handle the remaining depth channels (depth % PACK_UNIT)
+    int remain = depth - depthRemain;
+    if (remain > 0) {
+        float* dstPlane       = depthC4 * dstAreaOffset * PACK_UNIT + dst;
+        const float* srcPlane = src + depthC4 * srcAreaOffset * PACK_UNIT;
+        {
+            for (int x = 0; x < areaC4; ++x) {
+                auto s  = srcPlane + PACK_UNIT * x;
+                auto d  = dstPlane + PACK_UNIT * PACK_UNIT * x;
+                auto r0 = _mm256_loadu_ps(s + 0 * srcAreaOffset);
+                auto r1 = _mm256_setzero_ps();
+                auto r2 = _mm256_setzero_ps();
+                auto r3 = _mm256_setzero_ps();
+                auto r4 = _mm256_setzero_ps();
+                auto r5 = _mm256_setzero_ps();
+                auto r6 = _mm256_setzero_ps();
+                auto r7 = _mm256_setzero_ps();
+                switch (remain) {
+                    case 7:
+                        r6 = _mm256_loadu_ps(s + 6 * srcAreaOffset);
+                    case 6:
+                        r5 = _mm256_loadu_ps(s + 5 * srcAreaOffset);
+                    case 5:
+                        r4 = _mm256_loadu_ps(s + 4 * srcAreaOffset);
+                    case 4:
+                        r3 = _mm256_loadu_ps(s + 3 * srcAreaOffset);
+                    case 3:
+                        r2 = _mm256_loadu_ps(s + 2 * srcAreaOffset);
+                    case 2:
+                        r1 = _mm256_loadu_ps(s + 1 * srcAreaOffset);
+                    default:
+                        break;
+                }
+
+                TRANSPOSE_8x8;
+
+                _mm256_storeu_ps(d + PACK_UNIT * 7, t7);
+                _mm256_storeu_ps(d + PACK_UNIT * 6, t6);
+                _mm256_storeu_ps(d + PACK_UNIT * 5, t5);
+                _mm256_storeu_ps(d + PACK_UNIT * 4, t4);
+                _mm256_storeu_ps(d + PACK_UNIT * 3, t3);
+                _mm256_storeu_ps(d + PACK_UNIT * 2, t2);
+                _mm256_storeu_ps(d + PACK_UNIT * 1, t1);
+                _mm256_storeu_ps(d + PACK_UNIT * 0, t0);
+            }
+        }
+        for (int x = areaRemain; x < area; ++x) {
+            for (int y = 0; y < remain; y++) {
+                dstPlane[PACK_UNIT * x + y] = srcPlane[y * srcAreaOffset + x];
+            }
+            for (int y = remain; y < PACK_UNIT; y++) {
+                dstPlane[PACK_UNIT * x + y] = 0;
+            }
+        }
+    }
+    // Right: handle the remaining area elements (area % PACK_UNIT)
+    for (int z = 0; z < depthC4; ++z) {
+        float* dstPlane       = z * dstAreaOffset * PACK_UNIT + dst;
+        const float* srcPlane = src + z * srcAreaOffset * PACK_UNIT;
+        for (int x = areaRemain; x < area; ++x) {
+            float s0 = srcPlane[x];
+            float s1 = srcPlane[x + srcAreaOffset];
+            float s2 = srcPlane[x + srcAreaOffset * 2];
+            float s3 = srcPlane[x + srcAreaOffset * 3];
+            float s4 = srcPlane[x + srcAreaOffset * 4];
+            float s5 = srcPlane[x + srcAreaOffset * 5];
+            float s6 = srcPlane[x + srcAreaOffset * 6];
+            float s7 = srcPlane[x + srcAreaOffset * 7];
+            _mm256_storeu_ps(dstPlane + PACK_UNIT * x, _mm256_set_ps(s7, s6, s5, s4, s3, s2, s1, s0));
+        }
+    }
+}
+void _AVX_MNNUnpackCUnit(float* dst, const float* src, size_t area, size_t depth, int* areaOffset) {
+    auto areaC4  = area / PACK_UNIT;
+    auto depthC4 = depth / PACK_UNIT;
+    auto srcAreaOffset = areaOffset[0];
+    auto dstAreaOffset = areaOffset[1];
+    __m256 t0, t1, t2, t3, t4, t5, t6, t7;
+    for (int z = 0; z < depthC4; ++z) {
+        auto dstPlane = dst + z * dstAreaOffset * PACK_UNIT;
+        auto srcPlane = src + z * srcAreaOffset * PACK_UNIT;
+        for (int x = 0; x < areaC4; ++x) {
+            auto s  = srcPlane + PACK_UNIT * PACK_UNIT * x;
+            auto d  = dstPlane + PACK_UNIT * x;
+            auto r0 = _mm256_loadu_ps(s + 0 * PACK_UNIT);
+            auto r1 = _mm256_loadu_ps(s + 1 * PACK_UNIT);
+            auto r2 = _mm256_loadu_ps(s + 2 * PACK_UNIT);
+            auto r3 = _mm256_loadu_ps(s + 3 * PACK_UNIT);
+            auto r4 = _mm256_loadu_ps(s + 4 * PACK_UNIT);
+            auto r5 = _mm256_loadu_ps(s + 5 * PACK_UNIT);
+            auto r6 = _mm256_loadu_ps(s + 6 * PACK_UNIT);
+            auto r7 = _mm256_loadu_ps(s + 7 * PACK_UNIT);
+
+            TRANSPOSE_8x8;
+
+            _mm256_storeu_ps(d + 0 * dstAreaOffset, t0);
+            _mm256_storeu_ps(d + 1 * dstAreaOffset, t1);
+            _mm256_storeu_ps(d + 2 * dstAreaOffset, t2);
+            _mm256_storeu_ps(d + 3 * dstAreaOffset, t3);
+            _mm256_storeu_ps(d + 4 * dstAreaOffset, t4);
+            _mm256_storeu_ps(d + 5 * dstAreaOffset, t5);
+            _mm256_storeu_ps(d + 6 * dstAreaOffset, t6);
+            _mm256_storeu_ps(d + 7 * dstAreaOffset, t7);
+        }
+    }
+    auto areaRemain  = areaC4 * PACK_UNIT;
+    auto depthRemain = depthC4 * PACK_UNIT;
+    // Down: handle the remaining depth channels (depth % PACK_UNIT)
+    int remain = depth - depthRemain;
+    if (remain > 0) {
+        float* dstPlane       = depthC4 * dstAreaOffset * PACK_UNIT + dst;
+        const float* srcPlane = src + depthC4 * srcAreaOffset * PACK_UNIT;
+        for (int x = 0; x < areaC4; ++x) {
+            auto s  = srcPlane + PACK_UNIT * PACK_UNIT * x;
+            auto d  = dstPlane + PACK_UNIT * x;
+            auto r0 = _mm256_loadu_ps(s + 0 * PACK_UNIT);
+            auto r1 = _mm256_loadu_ps(s + 1 * PACK_UNIT);
+            auto r2 = _mm256_loadu_ps(s + 2 * PACK_UNIT);
+            auto r3 = _mm256_loadu_ps(s + 3 * PACK_UNIT);
+            auto r4 = _mm256_loadu_ps(s + 4 * PACK_UNIT);
+            auto r5 = _mm256_loadu_ps(s + 5 * PACK_UNIT);
+            auto r6 = _mm256_loadu_ps(s + 6 * PACK_UNIT);
+            auto r7 = _mm256_loadu_ps(s + 7 * PACK_UNIT);
+
+            TRANSPOSE_8x8;
+
+            switch (remain) {
+                case 7:
+                    _mm256_storeu_ps(d + 6 * dstAreaOffset, t6);
+                case 6:
+                    _mm256_storeu_ps(d + 5 * dstAreaOffset, t5);
+                case 5:
+                    _mm256_storeu_ps(d + 4 * dstAreaOffset, t4);
+                case 4:
+                    _mm256_storeu_ps(d + 3 * dstAreaOffset, t3);
+                case 3:
+                    _mm256_storeu_ps(d + 2 * dstAreaOffset, t2);
+                case 2:
+                    _mm256_storeu_ps(d + 1 * dstAreaOffset, t1);
+                case 1:
+                    _mm256_storeu_ps(d + 0 * dstAreaOffset, t0);
+                default:
+                    break;
+            }
+        }
+        for (int x = areaRemain; x < area; ++x) {
+            for (int y = 0; y < remain; y++) {
+                dstPlane[y * dstAreaOffset + x] = srcPlane[PACK_UNIT * x + y];
+            }
+        }
+    }
+    // Right
+    for (int z = 0; z < depthC4; ++z) {
+        const float* srcPlane = z * srcAreaOffset * PACK_UNIT + src;
+        float* dstPlane       = dst + z * dstAreaOffset * PACK_UNIT;
+        for (int x = areaRemain; x < area; ++x) {
+            for (int y = 0; y < PACK_UNIT; y++) {
+                dstPlane[y * dstAreaOffset + x] = srcPlane[PACK_UNIT * x + y];
+            }
+        }
+    }
+}
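+// Pack channel-last data (all depth values contiguous per pixel) into
+// PACK_UNIT-channel blocks; the tail block for depth % PACK_UNIT is zero-filled
+// before the remaining channels are copied in.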
+void _AVX_MNNPackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset) {
+    int c      = (int)depth;
+    int cDiv4  = c / PACK_UNIT;
+    int cAlign = cDiv4 * PACK_UNIT;
+    auto srcAreaOffset = areaOffset[0];
+    auto dstAreaOffset = areaOffset[1];
+    for (int hi = 0; hi < area; ++hi) {
+        const float* srcHeight = src + hi * c;
+        float* dstHeight       = dst + hi * PACK_UNIT;
+        for (int ci = 0; ci < cDiv4; ++ci) {
+            _mm256_storeu_ps(dstHeight + PACK_UNIT * ci * dstAreaOffset, _mm256_loadu_ps(srcHeight + PACK_UNIT * ci));
+        }
+    }
+
+    if (cAlign == c) {
+        return;
+    }
+
+    int cRemain   = c - cAlign;
+    auto srcAlign = src + cAlign;
+    auto dstAlign = dst + dstAreaOffset * cAlign;
+
+    for (int hi = 0; hi < area; ++hi) {
+        const float* srcHeight = srcAlign + hi * c;
+        float* dstHeight       = dstAlign + hi * PACK_UNIT;
+        for (int i = 0; i < PACK_UNIT; ++i) {
+            dstHeight[i] = 0;
+        }
+        for (int ci = 0; ci < cRemain; ++ci) {
+            dstHeight[ci] = srcHeight[ci];
+        }
+    }
+
+}
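+// Inverse of _AVX_MNNPackCUnitTranspose: restore the channel-last layout,
+// copying only the valid depth % PACK_UNIT values out of the padded tail block.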
+void _AVX_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset) {
+    int c      = (int)depth;
+    int cDiv4  = c / PACK_UNIT;
+    int cAlign = cDiv4 * PACK_UNIT;
+    auto srcAreaOffset = areaOffset[0];
+    auto dstAreaOffset = areaOffset[1];
+    for (int hi = 0; hi < area; ++hi) {
+        const float* srcHeight = src + hi * PACK_UNIT;
+        float* dstHeight       = dst + hi * c;
+        for (int ci = 0; ci < cDiv4; ++ci) {
+            _mm256_storeu_ps(dstHeight + PACK_UNIT * ci, _mm256_loadu_ps(srcHeight + PACK_UNIT * ci * srcAreaOffset));
+        }
+    }
+
+    if (cAlign == c) {
+        return;
+    }
+
+    int cRemain   = c - cAlign;
+    auto srcAlign = src + srcAreaOffset * cAlign;
+    auto dstAlign = dst + cAlign;
+
+    for (int hi = 0; hi < area; ++hi) {
+        const float* srcHeight = srcAlign + hi * PACK_UNIT;
+        float* dstHeight       = dstAlign + hi * c;
+
+        for (int ci = 0; ci < cRemain; ++ci) {
+            dstHeight[ci] = srcHeight[ci];
+        }
+    }
+}
+
+
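+// Int8 variant of the channel-major -> PACK_UNIT-block packing; uses scalar
+// element copies instead of the AVX transpose and zero-pads the channel tail.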
+void _AVX_MNNPackCUnitInt8(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset) {
+    auto areaC4  = area / PACK_UNIT;
+    auto depthC4 = depth / PACK_UNIT;
+    auto srcAreaOffset = areaOffset[0];
+    auto dstAreaOffset = areaOffset[1];
+    __m256 t0, t1, t2, t3, t4, t5, t6, t7;
+    for (int z = 0; z < depthC4; ++z) {
+        auto dstPlane = dst + z * dstAreaOffset * PACK_UNIT;
+        auto srcPlane = src + z * srcAreaOffset * PACK_UNIT;
+        for (int x = 0; x < areaC4; ++x) {
+            auto s  = srcPlane + PACK_UNIT * x;
+            auto d  = dstPlane + PACK_UNIT * PACK_UNIT * x;
+            for (int i=0; i<PACK_UNIT; ++i) {
+                for (int j=0; j<PACK_UNIT; ++j) {
+                    d[PACK_UNIT*i +j] = s[i + j * srcAreaOffset];
+                }
+            }
+        }
+    }
+    auto areaRemain  = areaC4 * PACK_UNIT;
+    auto depthRemain = depthC4 * PACK_UNIT;
+    // Down
+    int remain = depth - depthRemain;
+    if (remain > 0) {
+        auto dstPlane       = depthC4 * dstAreaOffset * PACK_UNIT + dst;
+        const auto srcPlane = src + depthC4 * srcAreaOffset * PACK_UNIT;
+        {
+            for (int x = 0; x < areaC4; ++x) {
+                auto s  = srcPlane + PACK_UNIT * x;
+                auto d  = dstPlane + PACK_UNIT * PACK_UNIT * x;
+                ::memset(d, 0, PACK_UNIT * PACK_UNIT * sizeof(int8_t));
+                for (int i=0; i<PACK_UNIT; ++i) {
+                    for (int j=0; j<remain; ++j) {
+                        d[PACK_UNIT*i +j] = s[i + j * srcAreaOffset];
+                    }
+                }
+            }
+        }
+        for (int x = areaRemain; x < area; ++x) {
+            for (int y = 0; y < remain; y++) {
+                dstPlane[PACK_UNIT * x + y] = srcPlane[y * srcAreaOffset + x];
+            }
+            for (int y = remain; y < PACK_UNIT; y++) {
+                dstPlane[PACK_UNIT * x + y] = 0;
+            }
+        }
+    }
+    // Right
+    for (int z = 0; z < depthC4; ++z) {
+        auto dstPlane       = z * dstAreaOffset * PACK_UNIT + dst;
+        auto srcPlane = src + z * srcAreaOffset * PACK_UNIT;
+        for (int x = areaRemain; x < area; ++x) {
+            for (int j=0; j<PACK_UNIT; ++j) {
+                dstPlane[PACK_UNIT * x + j] = srcPlane[x + srcAreaOffset * j];
+            }
+        }
+    }
+}
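+// Int8 variant of the unpack: scalar copies from PACK_UNIT-channel blocks back
+// into channel-major planes, writing only the valid depth remainder.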
+void _AVX_MNNUnpackCUnitInt8(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset) {
+    auto areaC4  = area / PACK_UNIT;
+    auto depthC4 = depth / PACK_UNIT;
+    auto srcAreaOffset = areaOffset[0];
+    auto dstAreaOffset = areaOffset[1];
+    for (int z = 0; z < depthC4; ++z) {
+        auto dstPlane = dst + z * dstAreaOffset * PACK_UNIT;
+        auto srcPlane = src + z * srcAreaOffset * PACK_UNIT;
+        for (int x = 0; x < areaC4; ++x) {
+            auto s  = srcPlane + PACK_UNIT * PACK_UNIT * x;
+            auto d  = dstPlane + PACK_UNIT * x;
+            for (int i=0; i<PACK_UNIT; ++i) {
+                for (int j=0; j<PACK_UNIT; ++j) {
+                    d[i+j*dstAreaOffset] = s[PACK_UNIT*i+j];
+                }
+            }
+        }
+    }
+    auto areaRemain  = areaC4 * PACK_UNIT;
+    auto depthRemain = depthC4 * PACK_UNIT;
+    // Down
+    int remain = depth - depthRemain;
+    if (remain > 0) {
+        auto dstPlane       = dst + depthC4 * dstAreaOffset * PACK_UNIT;
+        const auto srcPlane = src + depthC4 * srcAreaOffset * PACK_UNIT;
+        {
+            for (int x = 0; x < areaC4; ++x) {
+                auto s  = srcPlane + PACK_UNIT * PACK_UNIT * x;
+                auto d  = dstPlane + PACK_UNIT * x;
+                for (int i=0; i<PACK_UNIT; ++i) {
+                    for (int j=0; j<remain; ++j) {
+                        d[i + j * dstAreaOffset] = s[PACK_UNIT*i +j];
+                    }
+                }
+            }
+        }
+        for (int x = areaRemain; x < area; ++x) {
+            for (int y = 0; y < remain; y++) {
+                 dstPlane[y * dstAreaOffset + x] = srcPlane[PACK_UNIT * x + y];
+            }
+        }
+    }
+    // Right
+    for (int z = 0; z < depthC4; ++z) {
+        auto dstPlane = dst + z * dstAreaOffset * PACK_UNIT;
+        auto srcPlane = src + z * srcAreaOffset * PACK_UNIT;
+        for (int x = areaRemain; x < area; ++x) {
+            for (int j=0; j<PACK_UNIT; ++j) {
+                dstPlane[x + dstAreaOffset * j] = srcPlane[PACK_UNIT * x + j];
+            }
+        }
+    }
+}
+
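+// Int8 channel-last -> PACK_UNIT-block packing. When depth is a multiple of
+// PACK_UNIT each group of PACK_UNIT bytes is copied via an int64 load/store;
+// otherwise a scalar path copies full blocks and zero-fills the tail block.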
+void _AVX_MNNPackCUnitTransposeInt8(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset) {
+    int c      = (int)depth;
+    int cDiv4  = c / PACK_UNIT;
+    int cAlign = cDiv4 * PACK_UNIT;
+    auto srcAreaOffset = areaOffset[0];
+    auto dstAreaOffset = areaOffset[1];
+    if (cAlign == c) {
+        for (int hi = 0; hi < area; ++hi) {
+            const int8_t* srcHeight = src + hi * c;
+            int8_t* dstHeight       = dst + hi * PACK_UNIT;
+            for (int ci = 0; ci < cDiv4; ++ci) {
+                *(int64_t*)(dstHeight + PACK_UNIT * ci * dstAreaOffset) = *(int64_t*)(srcHeight + PACK_UNIT * ci);
+            }
+        }
+        return;
+    }
+    for (int hi = 0; hi < area; ++hi) {
+        const int8_t* srcHeight = src + hi * c;
+        int8_t* dstHeight       = dst + hi * PACK_UNIT;
+        for (int ci = 0; ci < cDiv4; ++ci) {
+            for (int k=0; k<PACK_UNIT; ++k) {
+                dstHeight[PACK_UNIT * ci * dstAreaOffset + k] = srcHeight[PACK_UNIT * ci + k];
+            }
+        }
+    }
+    int cRemain   = c - cAlign;
+    auto srcAlign = src + cAlign;
+    auto dstAlign = dst + dstAreaOffset * cAlign;
+
+    for (int hi = 0; hi < area; ++hi) {
+        const int8_t* srcHeight = srcAlign + hi * c;
+        int8_t* dstHeight       = dstAlign + hi * PACK_UNIT;
+        for (int i = 0; i < PACK_UNIT; ++i) {
+            dstHeight[i] = 0;
+        }
+        for (int ci = 0; ci < cRemain; ++ci) {
+            dstHeight[ci] = srcHeight[ci];
+        }
+    }
+
+}
+
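+// Inverse of _AVX_MNNPackCUnitTransposeInt8, with the same int64 fast path when
+// depth is a multiple of PACK_UNIT.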
+void _AVX_MNNUnpackCUnitTransposeInt8(int8_t* dst, const int8_t* src, size_t area, size_t depth, int* areaOffset) {
+    int c      = (int)depth;
+    int cDiv4  = c / PACK_UNIT;
+    int cAlign = cDiv4 * PACK_UNIT;
+    auto srcAreaOffset = areaOffset[0];
+    auto dstAreaOffset = areaOffset[1];
+    if (cAlign == c) {
+        for (int hi = 0; hi < area; ++hi) {
+            const int8_t* srcHeight = src + hi * PACK_UNIT;
+            int8_t* dstHeight       = dst + hi * c;
+            for (int ci = 0; ci < cDiv4; ++ci) {
+                *(int64_t*)(dstHeight + PACK_UNIT * ci) = *(int64_t*)(srcHeight + PACK_UNIT * ci * srcAreaOffset);
+            }
+        }
+        return;
+    }
+    for (int hi = 0; hi < area; ++hi) {
+        const int8_t* srcHeight = src + hi * PACK_UNIT;
+        int8_t* dstHeight       = dst + hi * c;
+        for (int ci = 0; ci < cDiv4; ++ci) {
+            for (int k=0; k<PACK_UNIT; ++k) {
+                dstHeight[PACK_UNIT * ci + k] = srcHeight[PACK_UNIT * ci * srcAreaOffset + k];
+            }
+        }
+    }
+    int cRemain   = c - cAlign;
+    auto srcAlign = src + srcAreaOffset * cAlign;
+    auto dstAlign = dst + cAlign;
+
+    for (int hi = 0; hi < area; ++hi) {
+        const int8_t* srcHeight = srcAlign + hi * PACK_UNIT;
+        int8_t* dstHeight       = dstAlign + hi * c;
+
+        for (int ci = 0; ci < cRemain; ++ci) {
+            dstHeight[ci] = srcHeight[ci];
+        }
+    }
+}
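+// Register the AVX reorder kernels defined above into the CoreFunctions table.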
+void _AVX_ReorderInit(void* functions) {
+    auto coreFunction = static_cast<MNN::CoreFunctions*>(functions);
+    coreFunction->MNNPackCUnit = _AVX_MNNPackCUnit;
+    coreFunction->MNNUnpackCUnit = _AVX_MNNUnpackCUnit;
+    coreFunction->MNNPackCUnitTranspose = _AVX_MNNPackCUnitTranspose;
+    coreFunction->MNNUnpackCUnitTranspose = _AVX_MNNUnpackCUnitTranspose;
+
+    coreFunction->MNNUnpackCUnitTransposeInt8 = _AVX_MNNUnpackCUnitTransposeInt8;
+    coreFunction->MNNPackCUnitInt8 = _AVX_MNNPackCUnitInt8;
+    coreFunction->MNNUnpackCUnitInt8 = _AVX_MNNUnpackCUnitInt8;
+    coreFunction->MNNPackCUnitTransposeInt8 = _AVX_MNNPackCUnitTransposeInt8;
+}

+ 65 - 63
source/backend/cpu/x86_x64/avx/Vec8.hpp

@@ -2,13 +2,74 @@
 //  Vec8.hpp
 //  MNN
 //
-//  Created by MNN on b'2021/05/16'.
+//  Created by MNN on 2021/05/16.
 //  Copyright © 2018, Alibaba Group Holding Limited
 //
 
 #ifndef Vec8_hpp
 #define Vec8_hpp
 #include "FunctionSummary.hpp"
+
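+// TRANSPOSE_8x8 transposes the 8x8 float tile held in __m256 variables r0..r7
+// and leaves the result in t0..t7 (r0..r7 are clobbered as scratch); all of
+// r0..r7 and t0..t7 must already be in scope. TRANSPOSE_8x8_REPLACE declares
+// its own temporaries and writes the transposed rows back into r0..r7.
+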
+#define TRANSPOSE_8x8 \
+t0 = _mm256_unpacklo_ps(r0, r1);\
+t1 = _mm256_unpackhi_ps(r0, r1);\
+t2 = _mm256_unpacklo_ps(r2, r3);\
+t3 = _mm256_unpackhi_ps(r2, r3);\
+t4 = _mm256_unpacklo_ps(r4, r5);\
+t5 = _mm256_unpackhi_ps(r4, r5);\
+t6 = _mm256_unpacklo_ps(r6, r7);\
+t7 = _mm256_unpackhi_ps(r6, r7);\
+\
+r0 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(1,0,1,0));\
+r1 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(3,2,3,2));\
+r2 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(1,0,1,0));\
+r3 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(3,2,3,2));\
+r4 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(1,0,1,0));\
+r5 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(3,2,3,2));\
+r6 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(1,0,1,0));\
+r7 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(3,2,3,2));\
+\
+t0 = _mm256_permute2f128_ps(r0, r4, 0x20);\
+t1 = _mm256_permute2f128_ps(r1, r5, 0x20);\
+t2 = _mm256_permute2f128_ps(r2, r6, 0x20);\
+t3 = _mm256_permute2f128_ps(r3, r7, 0x20);\
+t4 = _mm256_permute2f128_ps(r0, r4, 0x31);\
+t5 = _mm256_permute2f128_ps(r1, r5, 0x31);\
+t6 = _mm256_permute2f128_ps(r2, r6, 0x31);\
+t7 = _mm256_permute2f128_ps(r3, r7, 0x31);\
+
+#define TRANSPOSE_8x8_REPLACE(r0, r1, r2, r3, r4, r5, r6, r7) \
+{\
+auto t0 = _mm256_unpacklo_ps(r0, r1);\
+auto t1 = _mm256_unpackhi_ps(r0, r1);\
+auto t2 = _mm256_unpacklo_ps(r2, r3);\
+auto t3 = _mm256_unpackhi_ps(r2, r3);\
+auto t4 = _mm256_unpacklo_ps(r4, r5);\
+auto t5 = _mm256_unpackhi_ps(r4, r5);\
+auto t6 = _mm256_unpacklo_ps(r6, r7);\
+auto t7 = _mm256_unpackhi_ps(r6, r7);\
+\
+r0 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(1,0,1,0));\
+r1 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(3,2,3,2));\
+r2 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(1,0,1,0));\
+r3 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(3,2,3,2));\
+r4 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(1,0,1,0));\
+r5 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(3,2,3,2));\
+r6 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(1,0,1,0));\
+r7 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(3,2,3,2));\
+\
+t0 = _mm256_permute2f128_ps(r0, r4, 0x20);\
+t1 = _mm256_permute2f128_ps(r1, r5, 0x20);\
+t2 = _mm256_permute2f128_ps(r2, r6, 0x20);\
+t3 = _mm256_permute2f128_ps(r3, r7, 0x20);\
+t4 = _mm256_permute2f128_ps(r0, r4, 0x31);\
+t5 = _mm256_permute2f128_ps(r1, r5, 0x31);\
+t6 = _mm256_permute2f128_ps(r2, r6, 0x31);\
+t7 = _mm256_permute2f128_ps(r3, r7, 0x31);\
+r0 = t0, r1 = t1, r2 = t2, r3 = t3;\
+r4 = t4, r5 = t5, r6 = t6, r7 = t7;\
+}\
+
 struct Vec8 {
     using VecType = Vec8;
     __m256 value;
@@ -77,68 +138,9 @@ struct Vec8 {
         VecType dst = { _mm256_min_ps(v1.value, v2.value) };
         return dst;
     }
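+    // 8x8 in-place transpose across eight Vec8 rows, built on TRANSPOSE_8x8_REPLACE.
+    // Usage sketch (hypothetical values): with Vec8 r0..r7 holding the rows of a
+    // tile, Vec8::transpose8(r0, r1, r2, r3, r4, r5, r6, r7) leaves column i of
+    // the original tile in ri.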
+    static void transpose8(VecType& v0, VecType& v1, VecType& v2, VecType& v3, VecType& v4, VecType& v5, VecType& v6, VecType& v7) {
+        TRANSPOSE_8x8_REPLACE(v0.value, v1.value, v2.value, v3.value, v4.value, v5.value, v6.value, v7.value);
+    }
 };
 
-#define TRANSPOSE_8x8 \
-t0 = _mm256_unpacklo_ps(r0, r1);\
-t1 = _mm256_unpackhi_ps(r0, r1);\
-t2 = _mm256_unpacklo_ps(r2, r3);\
-t3 = _mm256_unpackhi_ps(r2, r3);\
-t4 = _mm256_unpacklo_ps(r4, r5);\
-t5 = _mm256_unpackhi_ps(r4, r5);\
-t6 = _mm256_unpacklo_ps(r6, r7);\
-t7 = _mm256_unpackhi_ps(r6, r7);\
-\
-r0 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(1,0,1,0));\
-r1 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(3,2,3,2));\
-r2 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(1,0,1,0));\
-r3 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(3,2,3,2));\
-r4 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(1,0,1,0));\
-r5 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(3,2,3,2));\
-r6 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(1,0,1,0));\
-r7 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(3,2,3,2));\
-\
-t0 = _mm256_permute2f128_ps(r0, r4, 0x20);\
-t1 = _mm256_permute2f128_ps(r1, r5, 0x20);\
-t2 = _mm256_permute2f128_ps(r2, r6, 0x20);\
-t3 = _mm256_permute2f128_ps(r3, r7, 0x20);\
-t4 = _mm256_permute2f128_ps(r0, r4, 0x31);\
-t5 = _mm256_permute2f128_ps(r1, r5, 0x31);\
-t6 = _mm256_permute2f128_ps(r2, r6, 0x31);\
-t7 = _mm256_permute2f128_ps(r3, r7, 0x31);\
-
-#define TRANSPOSE_8x8_REPLACE(r0, r1, r2, r3, r4, r5, r6, r7) \
-{\
-auto t0 = _mm256_unpacklo_ps(r0, r1);\
-auto t1 = _mm256_unpackhi_ps(r0, r1);\
-auto t2 = _mm256_unpacklo_ps(r2, r3);\
-auto t3 = _mm256_unpackhi_ps(r2, r3);\
-auto t4 = _mm256_unpacklo_ps(r4, r5);\
-auto t5 = _mm256_unpackhi_ps(r4, r5);\
-auto t6 = _mm256_unpacklo_ps(r6, r7);\
-auto t7 = _mm256_unpackhi_ps(r6, r7);\
-\
-r0 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(1,0,1,0));\
-r1 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(3,2,3,2));\
-r2 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(1,0,1,0));\
-r3 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(3,2,3,2));\
-r4 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(1,0,1,0));\
-r5 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(3,2,3,2));\
-r6 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(1,0,1,0));\
-r7 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(3,2,3,2));\
-\
-t0 = _mm256_permute2f128_ps(r0, r4, 0x20);\
-t1 = _mm256_permute2f128_ps(r1, r5, 0x20);\
-t2 = _mm256_permute2f128_ps(r2, r6, 0x20);\
-t3 = _mm256_permute2f128_ps(r3, r7, 0x20);\
-t4 = _mm256_permute2f128_ps(r0, r4, 0x31);\
-t5 = _mm256_permute2f128_ps(r1, r5, 0x31);\
-t6 = _mm256_permute2f128_ps(r2, r6, 0x31);\
-t7 = _mm256_permute2f128_ps(r3, r7, 0x31);\
-r0 = t0, r1 = t1, r2 = t2, r3 = t3;\
-r4 = t4, r5 = t5, r6 = t6, r7 = t7;\
-}\
-
-
 #endif
-

+ 0 - 0
source/backend/cpu/x86_x64/avx/WinogradAVX2.cpp


Some files were not shown because too many files changed in this diff