Browse source

[MNN:Sync] Sync Internal 2.8.0

zhaode.wzd 1 year ago
parent
commit
387775be2a
100 changed files with 7876 additions and 688 deletions
  1. CMakeLists.txt  +8 -1
  2. docs/compile/cmake.md  +1 -3
  3. docs/compile/pymnn.md  +3 -3
  4. docs/compile/tools.md  +2 -4
  5. docs/inference/module.md  +35 -0
  6. docs/pymnn/expr.md  +47 -1
  7. docs/tools/test.md  +3 -2
  8. express/Executor.cpp  +2 -2
  9. express/Expr.cpp  +30 -26
  10. express/NeuralNetWorkOp.cpp  +0 -1
  11. express/Utils.cpp  +1 -1
  12. express/module/PipelineModule.cpp  +1 -1
  13. express/module/StaticModule.cpp  +5 -1
  14. include/MNN/HalideRuntime.h  +3 -2
  15. include/MNN/MNNDefine.h  +2 -2
  16. include/MNN/expr/Executor.hpp  +8 -7
  17. package_scripts/win/build_lib.ps1  +40 -35
  18. package_scripts/win/build_lib_release.ps1  +21 -17
  19. project/ios/MNN.xcodeproj/project.pbxproj  +16 -8
  20. pymnn/examples/MNNTrain/simple/grad_loss.py  +22 -0
  21. pymnn/examples/MNNTrain/simple/make_solve_equation_graph.py  +43 -0
  22. pymnn/pip_package/MNN/nn/__init__.py  +8 -0
  23. pymnn/pip_package/build_deps.py  +55 -23
  24. pymnn/pip_package/setup.py  +53 -25
  25. pymnn/src/expr.h  +25 -2
  26. pymnn/src/nn.h  +35 -0
  27. pymnn/src/optim.h  +76 -1
  28. schema/current/CaffeOp_generated.h  +292 -0
  29. schema/current/MNN_generated.h  +125 -25
  30. schema/current/TrainInfo_generated.h  +394 -0
  31. schema/current/UserDefine_generated.h  +32 -11
  32. schema/default/CaffeOp.fbs  +14 -0
  33. schema/default/MNN.fbs  +8 -4
  34. schema/default/TrainInfo.fbs  +18 -0
  35. schema/default/UserDefine.fbs  +3 -1
  36. source/backend/arm82/Arm82Binary.cpp  +6 -6
  37. source/backend/arm82/Arm82Functions.cpp  +33 -66
  38. source/backend/arm82/asm/arm32/MNNGeluFP16.S  +4 -2
  39. source/backend/arm82/asm/arm64/MNNGeluFP16.S  +5 -2
  40. source/backend/arm82/asm/arm64/low_memory/MNNAbsMaxFP16.S  +247 -0
  41. source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantFP16.S  +393 -0
  42. source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt4FP16_sdot.S  +361 -0
  43. source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt4FP16_smmla.S  +894 -0
  44. source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt8FP16_sdot.S  +323 -0
  45. source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt8FP16_smmla.S  +566 -0
  46. source/backend/arm82/asm/arm64/low_memory/MNNQuantScaleFP16.S  +189 -0
  47. source/backend/arm82/asm/arm64/low_memory/MNNQuantSumFP16.S  +106 -0
  48. source/backend/coreml/execution/CoreMLConvolution.cpp  +47 -6
  49. source/backend/coreml/execution/CoreMLConvolution.hpp  +2 -0
  50. source/backend/cpu/CMakeLists.txt  +4 -0
  51. source/backend/cpu/CPUBackend.cpp  +21 -13
  52. source/backend/cpu/CPUBackend.hpp  +9 -1
  53. source/backend/cpu/CPUBinaryInt8.cpp  +1 -0
  54. source/backend/cpu/CPUCast.cpp  +4 -1
  55. source/backend/cpu/CPUDeconvolution.cpp  +75 -22
  56. source/backend/cpu/CPUDeconvolution.hpp  +13 -10
  57. source/backend/cpu/CPUDeconvolutionDepthwise.cpp  +1 -1
  58. source/backend/cpu/CPUDequantizeLinear.cpp  +87 -0
  59. source/backend/cpu/CPUDequantizeLinear.hpp  +81 -0
  60. source/backend/cpu/CPUGridSample.cpp  +77 -0
  61. source/backend/cpu/CPUGridSample.hpp  +1 -1
  62. source/backend/cpu/CPUImageProcess.cpp  +0 -10
  63. source/backend/cpu/CPUMatMul.cpp  +4 -12
  64. source/backend/cpu/CPUOPRegister.cpp  +15 -0
  65. source/backend/cpu/CPUQuantizeLinear.cpp  +85 -0
  66. source/backend/cpu/CPUQuantizeLinear.hpp  +31 -0
  67. source/backend/cpu/CPURaster.cpp  +23 -9
  68. source/backend/cpu/CPUUnary.cpp  +2 -2
  69. source/backend/cpu/GridSampler.hpp  +124 -0
  70. source/backend/cpu/UnaryUtils.hpp  +28 -28
  71. source/backend/cpu/arm/CommonOptFunctionNeon.cpp  +34 -0
  72. source/backend/cpu/arm/arm32/MNNBinaryAddInt8.S  +56 -31
  73. source/backend/cpu/arm/arm32/MNNBinaryMulInt8.S  +2 -2
  74. source/backend/cpu/arm/arm32/MNNBinarySubInt8.S  +56 -30
  75. source/backend/cpu/arm/arm32/MNNGelu.S  +5 -2
  76. source/backend/cpu/arm/arm32/MNNTranspose16Bit8x8.S  +135 -0
  77. source/backend/cpu/arm/arm32/bf16/MNNGelu_BF16.S  +5 -2
  78. source/backend/cpu/arm/arm64/MNNBinaryAddInt8.S  +123 -60
  79. source/backend/cpu/arm/arm64/MNNBinaryMaxInt8.S  +1 -1
  80. source/backend/cpu/arm/arm64/MNNBinaryMinInt8.S  +1 -1
  81. source/backend/cpu/arm/arm64/MNNBinaryMulInt8.S  +4 -1
  82. source/backend/cpu/arm/arm64/MNNBinarySqdInt8.S  +1 -1
  83. source/backend/cpu/arm/arm64/MNNBinarySubInt8.S  +125 -60
  84. source/backend/cpu/arm/arm64/MNNGelu.S  +5 -2
  85. source/backend/cpu/arm/arm64/MNNSoftmax.S  +1 -1
  86. source/backend/cpu/arm/arm64/MNNTranspose16Bit8x8.S  +115 -0
  87. source/backend/cpu/arm/arm64/bf16/MNNGelu_BF16.S  +5 -2
  88. source/backend/cpu/arm/arm64/low_memory/MNNAbsMaxFP32.S  +173 -0
  89. source/backend/cpu/arm/arm64/low_memory/MNNDynamicQuantFP32.S  +155 -0
  90. source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32_sdot.S  +232 -0
  91. source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32_smmla.S  +373 -0
  92. source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32_sdot.S  +209 -0
  93. source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32_smmla.S  +314 -0
  94. source/backend/cpu/arm/arm64/low_memory/MNNQuantScaleFP32.S  +102 -0
  95. source/backend/cpu/arm/arm64/low_memory/MNNQuantSumFP32.S  +100 -0
  96. source/backend/cpu/bf16/BF16Binary.cpp  +6 -6
  97. source/backend/cpu/compute/CommonOptFunction.cpp  +212 -74
  98. source/backend/cpu/compute/CommonOptFunction.h  +23 -11
  99. source/backend/cpu/compute/ConvolutionFloatFactory.cpp  +10 -0
  100. source/backend/cpu/compute/ConvolutionHybrid.cpp  +0 -0

+ 8 - 1
CMakeLists.txt

@@ -43,6 +43,7 @@ option(MNN_SUPPORT_DEPRECATED_OP "Enable MNN's tflite quantized op" ON)
 option(MNN_DEBUG_MEMORY "MNN Debug Memory Access" OFF)
 option(MNN_DEBUG_TENSOR_SIZE "Enable Tensor Size" OFF)
 option(MNN_GPU_TRACE "Enable MNN Gpu Debug" OFF)
+option(MNN_SUPPORT_RENDER "Enable MNN Render Ops" OFF)
 option(MNN_PORTABLE_BUILD "Link the static version of third party libraries where possible to improve the portability of built executables" OFF)
 option(MNN_SEP_BUILD "Build MNN Backends and expression separately. Only works with MNN_BUILD_SHARED_LIBS=ON" ON)
 option(NATIVE_LIBRARY_OUTPUT "Native Library Path" OFF)
@@ -162,6 +163,9 @@ endif()
 if(MNN_SUPPORT_DEPRECATED_OP)
     add_definitions(-DMNN_SUPPORT_DEPRECATED_OP)
 endif()
+if(MNN_SUPPORT_RENDER)
+    add_definitions(-DMNN_SUPPORT_RENDER)
+endif()
 
 # debug options
 if(MNN_DEBUG_MEMORY)
@@ -372,7 +376,7 @@ list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNNMath>)
 list(APPEND MNN_TARGETS MNNMath)
 
 # Transform
-FILE(GLOB MNN_Transform_SRC ${CMAKE_CURRENT_LIST_DIR}/source/shape/* ${CMAKE_CURRENT_LIST_DIR}/source/geometry/*)
+FILE(GLOB_RECURSE MNN_Transform_SRC ${CMAKE_CURRENT_LIST_DIR}/source/shape/* ${CMAKE_CURRENT_LIST_DIR}/source/geometry/*)
 add_library(MNNTransform OBJECT ${MNN_Transform_SRC})
 IF (NOT MNN_BUILD_MINI)
     list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNNTransform>)
@@ -601,9 +605,12 @@ IF(MNN_BUILD_TRAIN OR MNN_BUILD_QUANTOOLS)
   add_subdirectory(tools/train)
   IF(MNN_SEP_BUILD)
     list(APPEND MNN_DEPS MNNTrain)
+    list(APPEND MNN_DEPS MNNTrainUtils)
   ELSE()
     list(APPEND MNN_TARGETS MNNTrain)
+    list(APPEND MNN_TARGETS MNNTrainUtils)
     list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNNTrain>)
+    list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNNTrainUtils>)
   ENDIF()
 ENDIF()
 

+ 1 - 3
docs/compile/cmake.md

@@ -63,11 +63,8 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下:
 | MNN_VULKAN_DEBUG     | 是否打开Vulkan的DEBUG模式,该宏仅在`MNN_VULKAN=ON`时生效,默认为`OFF` |
 | MNN_OPENGL_REGEN     | 是否重新生成OpenGL Kenel,该宏仅在`MNN_OPENGL=ON`时生效,默认为`OFF` |
 | MNN_TRT_DYNAMIC      | 是否通过dlopen的方式引入TRT的动态库,该宏仅在`MNN_TENSORRT=ON`时生效,默认为`OFF |
-| TF_CONVERT_ORIGIN    | 构建的`MNNConvert`是否使用原始TF转换模式,该宏仅在`MNN_BUILD_CONVERTER=ON`时生效,默认为`OFF` |
-| TFMODEL_OPTIMIZE     | 构建的`MNNConvert`是否对Tensorflow模型执行优化,该宏仅在`MNN_BUILD_CONVERTER=ON`时生效,默认为`OFF` |
 | MNN_BUILD_TORCH      | 构建的`MNNConvert`是否支持`TorchScript`,该宏仅在`MNN_BUILD_CONVERTER=ON`时生效,默认为`OFF` |
 | MNN_TRAIN_DEBUG      | 构建的训练模块是否支持调试,该宏仅在`MNN_BUILD_TRAIN=ON`时生效,默认为`OFF` |
-| MNN_BUILD_TRAIN_MINI | 构建删减版训练模块,不构建`Dataset`与`model`,该宏仅在`MNN_BUILD_TRAIN=ON`时生效,默认为`OFF` |
 | MNN_USE_OPENCV       | 构建的训练Demo是否使用`OpenCV`依赖,该宏仅在`MNN_BUILD_TRAIN=ON`时生效,默认为`OFF` |
 | MNN_IMGPROC_COLOR    | 构建MNN的OpenCV功能是否开启`颜色空间转换`,默认为`ON` |
 | MNN_IMGPROC_GEOMETRIC | 构建MNN的OpenCV功能是否开启`形变`,默认为`ON` |
@@ -83,4 +80,5 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下:
 | MNN_OPENCV_BENCH     | 构建MNN的OpenCV功能是否开启性能benchmark,默认为`OFF` |
 | MNN_VULKAN_IMAGE     | 构建MNN的Vulkan后端时采用Image内存模式,以便支持FP16和部分移动端上GPU的加速,默认为`ON` |
 | MNN_LOW_MEMORY       | 是否支持低内存模式,支持低内存模式使用权值量化模型并设置`low_memory`则会使用计算时反量化,默认为`OFF` |
+| MNN_SUPPORT_RENDER       | 是否支持图形渲染相关算子实现,默认为 `OFF` |
 | MNN_BUILD_LLM        | 是否构建基于MNN的llm库和demo,默认为`OFF` |

+ 3 - 3
docs/compile/pymnn.md

@@ -2,8 +2,8 @@
 ## 本地安装
 ```bash
 cd /path/to/MNN/pymnn/pip_package
-python build_deps.py
-python setup.py install --version {MNN版本}
+python build_deps.py {MNN依赖包组合} #internal,cuda,trt,cuda_tune,opencl,vulkan,render,no_sse,torch这几个字符串的任意组合,例如字符串可为:"cuda,reder,no_sse"
+python setup.py install --version {MNN版本} --deps {MNN依赖包组合}
 ```
 ## 构建Python Wheel包
 - Linux
@@ -41,4 +41,4 @@ python setup.py install --version {MNN版本}
     .\package_scripts\win\build_whl.ps1 -version {MNN版本} -backends "opencl,vulkan" -path MNN-CPU-OPENCL/py_whl/x64 -pyenvs "py27,py37,py38,py39"
     # CPU+OpenCL+Vulkan,32位编译
     .\package_scripts\win\build_whl.ps1 -version {MNN版本} -backends "opencl,vulkan" -x86 -path MNN-CPU-OPENCL/py_whl/x86 -pyenvs "py27-win32,py37-win32,py38-win32,py39-win32"
-    ```
+    ```

+ 2 - 4
docs/compile/tools.md

@@ -29,10 +29,8 @@
 - 编译产物
   - `MNNTrain` 训练框架库
   - `runTrainDemo.out` 运行训练框架demo的入口程序
-  - `transformer.out` 训练模型转换器
-  - `train.out` 训练功能入口程序
-  - `rawDataTransform.out` 将json文件转换为flatbuffers文件
-  - `dataTransformer.out` 将图片转换为flatbuffers文件
+  - `transformer` 训练模型转换器,将推理用的MNN模型转换为执行训练的MNN模型
+  - `extractForInfer` 从执行训练的MNN模型中提取参数,对应更新推理用的MNN模型
 ## 测试工具
 - 相关编译选项
   - `MNN_BUILD_TOOL` 是否编译测试工具

+ 35 - 0
docs/inference/module.md

@@ -56,6 +56,41 @@ std::unique_ptr<Module> module; // module
 module.reset(Module::load(input_names, output_names, model_filename.c_str(), rtMgr, &mdconfig));
 ```
 
+### Module::Config 
+创建`Module`时可传入`Module::Config`,具体结构如下:
+
+```cpp
+struct Config {
+    // Load module as dynamic, default static
+    bool dynamic = false;
+
+    // for static mode, if the shape is mutable, set true, otherwise set false to avoid resizeSession freqencily
+    bool shapeMutable = true;
+    // Pre-rearrange weights or not. Disabled by default.
+    // The weights will be rearranged in a general way, so the best implementation
+    // may not be adopted if `rearrange` is enabled.
+    bool rearrange = false;
+
+    BackendInfo* backend = nullptr;
+};
+```
+
+#### dynamic
+- 默认为 false ,输出的变量为const ,只能得到数据
+- 若 dynamic = true ,加载出的模型将按动态图方式运行,会增加额外构图耗时,但可以保存输出变量的计算路径,存成模型
+- 若 dynamic = true ,后面的 shapeMutable / rearrange 不再生效
+
+#### shapeMutable
+- 默认为 true ,表示输入形状易变,将延迟进行形状相关计算
+- 设置为 false 时,会提前申请内存,在 onForward 时做输入数据的拷贝而不是直接使用指针
+
+#### rearrange
+- 若为 true ,在创建 Module 时会预先创建卷积算子,做权重重排,以降低运行时的内存
+- 目前只支持 CPU 和 CUDA 后端
+
+#### backend
+已经废弃,不要设置此项
+
 ### 获取模型信息
 调用`getInfo`函数可获取`Module`信息,可以参考代码:`tools/cpp/GetMNNInfo.cpp`,[工具](../tools/test.html#getmnninfo)
 ```cpp
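
For reference, a minimal, self-contained C++ sketch of loading a `Module` with this `Config`. The model path and tensor names are placeholders, and the `Module::load` overload without a `RuntimeManager` is assumed; this is an illustration, not part of the diff:

```cpp
#include <memory>
#include <string>
#include <vector>
#include <MNN/expr/Module.hpp>

using namespace MNN::Express;

int main() {
    // Placeholder I/O names and model path, for illustration only.
    std::vector<std::string> input_names  = {"input"};
    std::vector<std::string> output_names = {"output"};

    Module::Config mdconfig;
    mdconfig.shapeMutable = false; // shapes are fixed, so memory can be planned up front
    mdconfig.rearrange    = true;  // pre-rearrange conv weights (CPU/CUDA) to lower runtime memory

    std::unique_ptr<Module> module(
        Module::load(input_names, output_names, "model.mnn", &mdconfig));
    return module == nullptr ? 1 : 0;
}
```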

+ 47 - 1
docs/pymnn/expr.md

@@ -145,6 +145,52 @@ array([0., 1., 2., 3.], dtype=float32)
 'Input'
 ```
 ---
+### `set_lazy_mode(mode)`
+设置惰性计算的模式,仅在开启惰性求值的状态下生效,
+
+- 0 : 所有计算均延迟执行
+- 1 : 立即进行几何计算,内容计算延迟执行,适用于构建静态模型或训练时求导
+
+默认为0
+
+
+参数:
+- `x:int` 模式类型
+
+返回:`None`
+
+返回类型:`None`
+
+示例:
+```python
+>>> expr.lazy_eval(True)
+>>> expr.set_lazy_mode(0)
+>>> y = expr.concat([x], -1)
+>>> expr.save([y], "concat.mnn") # 模型中为 concat 算子
+>>> expr.set_lazy_mode(1)
+>>> y = expr.concat([x], -1)
+>>> expr.save([y], "concat_static.mnn") # 模型中为 raster 算子
+```
+
+---
+### `set_global_executor_config(backend, precision, threadnum)`
+设置expr运行后端、精度、线程数(gpu代表mode):
+
+参数:
+- `backend:int` 例如:0->CPU 1->Metal 2->CUDA 3->OPENCL 
+- `precision:int` 例如:0—>Normal 1->High 2->Low 
+- `threadnum:int` 例如:CPU表示线程数  GPU表示Mode
+
+返回:`None`
+
+返回类型:`None`
+
+示例:
+
+```python
+>>> expr.set_global_executor_config(2, 2, 1)
+```
+---
 ### `sign(x)`
 返回输入值的符号,正数返回1,负数返回-1
 
@@ -3054,4 +3100,4 @@ dict_keys(['conv1', 'conv2_1/dw', 'conv2_1/sep', 'conv2_2/dw', 'conv2_2/sep', 'c
 dict_keys(['data'])
 >>> outputs.keys()
 dict_keys(['prob'])
-```
+```

+ 3 - 2
docs/tools/test.md

@@ -87,11 +87,12 @@ Avg= 5.570600 ms, OpSum = 7.059200 ms min= 3.863000 ms, max= 11.596001 ms
 - 16 : 适用于使用 GPU 的情况,由 MNN 优先选择 CPU 运行,并将 GPU 的 tuning 信息存到 cache 文件,所有算子 tuning 完成则启用 GPU
 - 32 : rearrange 设为 true ,降低模型加载后的内存大小,但会增加模型加载的初始化时间
 - 64 : 创建模型后,clone 出一个新的模型运行,用于测试 clone 功能(主要用于多并发推理)的正确性
+- 128 : 使用文件夹下面的 input.mnn 和 output.mnn 做为输入和对比输出,对于数据量较大的情况宜用此方案
 
 
 ### 示例
 ```bash
-$ python ../tools/script/fastTestOnnx.py mobilenetv2-7.onnx
+$ python ../tools/script/testMNNFromOnnx.py mobilenetv2-7.onnx
 $ ./ModuleBasic.out mobilenetv2-7.mnn onnx 0 0 10   
 Test mobilenetv2-7.mnn from input info: onnx
 input
@@ -114,7 +115,7 @@ Avg= 9.946699 ms, min= 9.472000 ms, max= 10.227000 ms
 - `model:str` 模型文件路径
 - `forwardType:int` 执行推理的计算设备,有效值为:0(CPU)、1(Metal)、2(CUDA)、3(OpenCL)、6(OpenGL),7(Vulkan) ,9 (TensorRT)
 - `shapeMutable:int` 输入形状是否可变
-- `dir_n:str` 输入输出信息文件夹,可使用 fastTestOnnx.py / fastTestTf.py / fastTestTflite.py 等脚本生成,参考模型转换的正确性校验部分
+- `dir_n:str` 输入输出信息文件夹,可使用 testMNNFromOnnx.py 等脚本生成,参考模型转换的正确性校验部分
 ```bash
 ./SequenceModuleTest.out transformer.mnn 0 1 tr tr1 tr2 tr3 tr4 > error.txt
 ```

+ 2 - 2
express/Executor.cpp

@@ -145,6 +145,7 @@ std::shared_ptr<Executor> Executor::getGlobalExecutor() {
         info.type = MNN_FORWARD_CPU;
         info.numThread = 1;
         std::shared_ptr<Runtime> bn(creator->onCreate(info));
+        bn->setAllocatorType(info.allocator);
         gExecutor = new std::shared_ptr<Executor>(new Executor(bn, MNN_FORWARD_CPU, 1));
     });
     return *gExecutor;
@@ -668,10 +669,9 @@ std::shared_ptr<Executor::SubGraph> Executor::findSubGraph(const std::string& su
     }
     return iter->second;
 }
-void Executor::setLazyComputeMode(LazyMode mode) {
+void Executor::setLazyComputeMode(uint32_t mode) {
     mLazyMode = mode;
 }
 
-
 } // namespace Express
 } // namespace MNN

+ 30 - 26
express/Expr.cpp

@@ -193,8 +193,11 @@ EXPRP Expr::create(std::shared_ptr<BufferStorage> extra, std::vector<VARP>&& inp
     expr->mStorage = extra;
     expr->mOp = flatbuffers::GetRoot<Op>(extra->buffer());
     expr->mInputs   = std::move(inputs);
-    expr->mInside->mReq = ExecutorScope::Current()->getRequirement(expr.get());
-    _addLinkForInputs(expr);
+    auto exe = ExecutorScope::Current();
+    expr->mInside->mReq = exe->getRequirement(expr.get());
+    if (!(exe->getLazyMode() & Executor::LAZY_COMPUTE_ONCE)) {
+        _addLinkForInputs(expr);
+    }
     return expr;
 }
 
@@ -350,7 +353,7 @@ VARP Variable::create(EXPRP expr, int index) {
     }
     // CONTENT Mode
     do {
-        if (executor->getLazyMode() != Executor::LAZY_CONTENT) {
+        if (!(executor->getLazyMode() & Executor::LAZY_CONTENT)) {
             break;
         }
         if (expr->get() == nullptr) {
@@ -1016,7 +1019,6 @@ blob->dataType = DataType_DT_##TYPE;
 
 void Variable::save(const std::vector<VARP>& vars, NetT* dest) {
     auto executeOrder = getExecuteOrder(vars);
-
     // Search subgraphs
     std::map<std::string, std::shared_ptr<Executor::SubGraph>> subgraphs;
     auto exe = ExecutorScope::Current();
@@ -1086,15 +1088,9 @@ void Variable::save(const std::vector<VARP>& vars, NetT* dest) {
                 blob->dataFormat = (MNN_DATA_FORMAT)Utils::convertFormat(info.order);
                 blob->dims       = info.dim;
                 if (info.type.code == halide_type_float) {
-                    if (info.type.bits == 16) {
-                        blob->dataType = DataType_DT_BFLOAT16;
-                        blob->uint8s.resize(info.size * 2);
-                        ::memcpy(blob->uint8s.data(), ptr, info.size * sizeof(int16_t));
-                    } else {
-                        blob->dataType = DataType_DT_FLOAT;
-                        blob->float32s.resize(info.size);
-                        ::memcpy(blob->float32s.data(), ptr, info.size * sizeof(float));
-                    }
+                    blob->dataType = DataType_DT_FLOAT;
+                    blob->float32s.resize(info.size);
+                    ::memcpy(blob->float32s.data(), ptr, info.size * sizeof(float));
                 } else if (info.type.code == halide_type_int && info.type.bits == 32) {
                     blob->dataType = DataType_DT_INT32;
                     blob->int32s.resize(info.size);
@@ -1107,6 +1103,10 @@ void Variable::save(const std::vector<VARP>& vars, NetT* dest) {
                     blob->dataType = DataType_DT_UINT8;
                     blob->uint8s.resize(info.size);
                     ::memcpy(blob->uint8s.data(), ptr, info.size * sizeof(uint8_t));
+                } else if (info.type.code == halide_type_bfloat && info.type.bits == 16) {
+                    blob->dataType = DataType_DT_BFLOAT16;
+                    blob->uint8s.resize(info.size * 2);
+                    ::memcpy(blob->uint8s.data(), ptr, info.size * sizeof(int16_t));
                 }
                 op->type       = OpType_Const;
                 if (expr->mType == VARP::TRAINABLE) {
@@ -1163,12 +1163,14 @@ void Variable::save(const std::vector<VARP>& vars, NetT* dest) {
                     dest->tensorName[subindex] = op->name + numberToString(v);
                 }
             }
-            if (staticModel) {
-                auto tensor = expr->inside()->mOutputTensors[v];
+            auto tensor = expr->inside()->mOutputTensors[v];
+
+            if (staticModel || TensorUtils::getDescribe(tensor)->quantAttr) {
                 auto des = TensorUtils::getDescribe(tensor);
                 auto describe = std::unique_ptr<MNN::TensorDescribeT>(new MNN::TensorDescribeT);
                 describe->index = varIndexInfo[expr] + v;
                 describe->blob = std::unique_ptr<MNN::BlobT>(new MNN::BlobT);
+                describe->name = dest->tensorName[subindex];
                 auto& blob = describe->blob;
                 blob->dataFormat = des->dimensionFormat;
                 if (tensor->getType() == halide_type_of<float>()) {
@@ -1190,18 +1192,20 @@ void Variable::save(const std::vector<VARP>& vars, NetT* dest) {
                     describe->quantInfo->zero = tensorDes->quantAttr->zero;
                     describe->quantInfo->scale = tensorDes->quantAttr->scale;
                 }
-                for (auto& reg : des->regions) {
-                    auto regionT = std::unique_ptr<MNN::RegionT>(new MNN::RegionT);
-                    regionT->src = std::unique_ptr<MNN::ViewT>(new MNN::ViewT);
-                    regionT->dst = std::unique_ptr<MNN::ViewT>(new MNN::ViewT);
-                    regionT->src->offset = reg.src.offset;
-                    regionT->dst->offset = reg.dst.offset;
-                    for (int s = 0; s < 3; s++) {
-                        regionT->src->stride.push_back(reg.src.stride[s]);
-                        regionT->dst->stride.push_back(reg.dst.stride[s]);
-                        regionT->size.push_back(reg.size[s]);
+                if (staticModel) {
+                    for (auto& reg : des->regions) {
+                        auto regionT = std::unique_ptr<MNN::RegionT>(new MNN::RegionT);
+                        regionT->src = std::unique_ptr<MNN::ViewT>(new MNN::ViewT);
+                        regionT->dst = std::unique_ptr<MNN::ViewT>(new MNN::ViewT);
+                        regionT->src->offset = reg.src.offset;
+                        regionT->dst->offset = reg.dst.offset;
+                        for (int s = 0; s < 3; s++) {
+                            regionT->src->stride.push_back(reg.src.stride[s]);
+                            regionT->dst->stride.push_back(reg.dst.stride[s]);
+                            regionT->size.push_back(reg.size[s]);
+                        }
+                        describe->regions.emplace_back(std::move(regionT));
                     }
-                    describe->regions.emplace_back(std::move(regionT));
                 }
                 dest->extraTensorDescribe.emplace_back(std::move(describe));
             }

+ 0 - 1
express/NeuralNetWorkOp.cpp

@@ -1327,7 +1327,6 @@ VARP _Range(VARP start, VARP limit, VARP delta) {
     std::unique_ptr<OpT> op(new OpT);
     op->type       = OpType_Range;
     auto rangeParam = new RangeT;
-    rangeParam->Tidx = (MNN::DataType)Utils::convertDataType(start->getInfo()->type);
     op->main.type = OpParameter_Range;
     op->main.value = rangeParam;
     return Variable::create(Expr::create(std::move(op), {start, limit, delta}));

+ 1 - 1
express/Utils.cpp

@@ -81,7 +81,7 @@ halide_type_t Utils::revertDataType(DataType dataType) {
     CONVERT(DataType_DT_UINT8, halide_type_of<uint8_t>(), dataType);
     CONVERT(DataType_DT_INT8, halide_type_of<int8_t>(), dataType);
     CONVERT(DataType_DT_HALF, halide_type_of<float>(), dataType);
-    CONVERT(DataType_DT_BFLOAT16, halide_type_t(halide_type_float, 16), dataType);
+    CONVERT(DataType_DT_BFLOAT16, halide_type_t(halide_type_bfloat, 16), dataType);
     return halide_type_of<float>();
 }
 Express::Dimensionformat Utils::revertFormat(int format) {

+ 1 - 1
express/module/PipelineModule.cpp

@@ -518,7 +518,7 @@ static Module* _createSubModule(std::shared_ptr<BufferStorage> bufferStorage, co
     scheduleInfo.defaultBackend = sharedConst->defaultBackend;
     scheduleInfo.constReplaceBackend = sharedConst->constReplaceBackend;
     scheduleInfo.allTensors = sharedConst->allTensors;
-    initTensors(scheduleInfo.allTensors, net);
+    scheduleInfo.validForResize = initTensors(scheduleInfo.allTensors, net);
     std::vector<Schedule::OpCacheInfo> oplists;
     std::vector<const Op*> ops;
     ops.reserve(info.opList.size());

+ 5 - 1
express/module/StaticModule.cpp

@@ -367,7 +367,11 @@ std::vector<Express::VARP> StaticModule::onForward(const std::vector<Express::VA
         if (mResource->mUseContentInputs) {
             mSession->setNeedResize();
         }
-        mSession->resize();
+        auto code = mSession->resize();
+        if (NO_ERROR != code) {
+            FUNC_PRINT(code);
+            return {};
+        }
     } else {
         // Resize
         for (int i = 0; i < inputs.size(); ++i) {

+ 3 - 2
include/MNN/HalideRuntime.h

@@ -60,8 +60,9 @@ typedef enum halide_type_code_t
 {
     halide_type_int = 0,   //!< signed integers
     halide_type_uint = 1,  //!< unsigned integers
-    halide_type_float = 2, //!< floating point numbers
-    halide_type_handle = 3 //!< opaque pointer type (void *)
+    halide_type_float = 2, //!< IEEE floating point numbers
+    halide_type_handle = 3, //!< opaque pointer type (void *)
+    halide_type_bfloat = 4  //!< floating point numbers in the bfloat format
 } halide_type_code_t;
 
 // Note that while __attribute__ can go before or after the declaration,
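
A tiny sketch showing the new type code in use, consistent with the `halide_type_t(halide_type_bfloat, 16)` mapping added in `express/Utils.cpp` above; the `MNN_PRINT` output is only illustrative:

```cpp
#include <MNN/HalideRuntime.h>
#include <MNN/MNNDefine.h>

int main() {
    // bfloat16 now has its own type code instead of reusing halide_type_float with 16 bits.
    halide_type_t bf16(halide_type_bfloat, 16);
    halide_type_t fp16(halide_type_float, 16);
    MNN_PRINT("bf16: code=%d bits=%d, fp16: code=%d bits=%d\n",
              (int)bf16.code, (int)bf16.bits, (int)fp16.code, (int)fp16.bits);
    return 0;
}
```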

+ 2 - 2
include/MNN/MNNDefine.h

@@ -68,7 +68,7 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
 #define STR_IMP(x) #x
 #define STR(x) STR_IMP(x)
 #define MNN_VERSION_MAJOR 2
-#define MNN_VERSION_MINOR 7
-#define MNN_VERSION_PATCH 2
+#define MNN_VERSION_MINOR 8
+#define MNN_VERSION_PATCH 0
 #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
 #endif /* MNNDefine_h */

+ 8 - 7
include/MNN/expr/Executor.hpp

@@ -40,16 +40,17 @@ public:
 
     bool lazyEval = true;
     enum LazyMode {
-        // Don't compute at all until user needed.
-        LAZY_FULL,
-        
+        LAZY_FULL = 0,
         // Don't compute content until user needed.
-        LAZY_CONTENT
+        LAZY_CONTENT = 1 << 0,
+        
+        // Expr can only compute once, it can reduce the create cost of expr
+        LAZY_COMPUTE_ONCE = 1 << 1,
     };
-    LazyMode getLazyMode() const {
+    uint32_t getLazyMode() const {
         return mLazyMode;
     }
-    void setLazyComputeMode(LazyMode mode);
+    void setLazyComputeMode(uint32_t mode);
     void setGlobalExecutorConfig(MNNForwardType type, const BackendConfig& config, int numberThread);
     int getCurrentRuntimeStatus(RuntimeStatus statusEnum);
     enum GCFlag {
@@ -139,7 +140,7 @@ private:
     RuntimeInfo mRuntimeInfo;
     std::shared_ptr<DebugTools> mDebug;
     std::map<std::string, std::shared_ptr<SubGraph>> mSubGraph;
-    LazyMode mLazyMode = LAZY_FULL;
+    uint32_t mLazyMode = 0;
     std::shared_ptr<ExecutorAttr> mAttr;
     std::mutex mMutex;
 };
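
Because `LazyMode` values are now bit flags and the setter takes a `uint32_t`, modes can be combined with bitwise OR. A minimal sketch (the executor is obtained via `Executor::getGlobalExecutor()` as seen in `express/Executor.cpp`; the rest is an assumed usage pattern):

```cpp
#include <MNN/expr/Executor.hpp>

using namespace MNN::Express;

int main() {
    auto exe = Executor::getGlobalExecutor();
    // lazyEval defaults to true (see the header above), so lazy modes take effect.
    // Delay content computation and let each Expr compute only once:
    exe->setLazyComputeMode(Executor::LAZY_CONTENT | Executor::LAZY_COMPUTE_ONCE);
    // Content-lazy only, the previous LAZY_CONTENT behavior:
    exe->setLazyComputeMode(Executor::LAZY_CONTENT);
    return 0;
}
```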

+ 40 - 35
package_scripts/win/build_lib.ps1

@@ -42,6 +42,7 @@ mkdir -p Debug\Dynamic\MD, Debug\Dynamic\MT, Debug\Static\MD, Debug\Static\MT, R
 popd
 
 $CMAKE_ARGS = "-DMNN_SEP_BUILD=OFF -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_AVX512=ON"
+$ONLY_DYNAMIC_MT = $False
 if ($backends -ne $null) {
     Foreach ($backend in $backends.Split(",")) {
         if ($backend -eq "opencl") {
@@ -50,6 +51,7 @@ if ($backends -ne $null) {
             $CMAKE_ARGS = "$CMAKE_ARGS -DMNN_VULKAN=ON"
         } elseif ($backend -eq "cuda") {
             $CMAKE_ARGS = "$CMAKE_ARGS -DMNN_CUDA=ON"
+            $ONLY_DYNAMIC_MT = $True
         }
     }
 }
@@ -90,26 +92,28 @@ Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_WIN_RUNTIME_MT=
 cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Debug\Dynamic\MT
 rm MNN.*
 
-##### Debug/Dynamic/MD ####
-log "Debug/Dynamic/MD"
-Remove-Item CMakeCache.txt -ErrorAction Ignore
-Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_WIN_RUNTIME_MT=OFF .."
-cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Debug\Dynamic\MD
-rm MNN.*
+if ($ONLY_DYNAMIC_MT -eq $False) {
+    ##### Debug/Dynamic/MD ####
+    log "Debug/Dynamic/MD"
+    Remove-Item CMakeCache.txt -ErrorAction Ignore
+    Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_WIN_RUNTIME_MT=OFF .."
+    cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Debug\Dynamic\MD
+    rm MNN.*
 
-##### Debug/Static/MT ####
-log "Debug/Static/MT"
-Remove-Item CMakeCache.txt -ErrorAction Ignore
-Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_WIN_RUNTIME_MT=ON -DMNN_BUILD_SHARED_LIBS=OFF .."
-cp MNN.lib $PACKAGE_LIB_PATH\Debug\Static\MT
-rm MNN.*
+    ##### Debug/Static/MT ####
+    log "Debug/Static/MT"
+    Remove-Item CMakeCache.txt -ErrorAction Ignore
+    Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_WIN_RUNTIME_MT=ON -DMNN_BUILD_SHARED_LIBS=OFF .."
+    cp MNN.lib $PACKAGE_LIB_PATH\Debug\Static\MT
+    rm MNN.*
 
-##### Debug/Static/MD ####
-log "Debug/Static/MD"
-Remove-Item CMakeCache.txt -ErrorAction Ignore
-Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_WIN_RUNTIME_MT=OFF -DMNN_BUILD_SHARED_LIBS=OFF .."
-cp MNN.lib $PACKAGE_LIB_PATH\Debug\Static\MD
-rm MNN.*
+    ##### Debug/Static/MD ####
+    log "Debug/Static/MD"
+    Remove-Item CMakeCache.txt -ErrorAction Ignore
+    Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Debug -DMNN_WIN_RUNTIME_MT=OFF -DMNN_BUILD_SHARED_LIBS=OFF .."
+    cp MNN.lib $PACKAGE_LIB_PATH\Debug\Static\MD
+    rm MNN.*
+}
 
 ##### Release/Dynamic/MT ####
 log "Release/Dynamic/MT"
@@ -118,23 +122,24 @@ Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_M
 cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Release\Dynamic\MT
 rm MNN.*
 
-##### Release/Dynamic/MD ####
-log "Release/Dynamic/MD"
-Remove-Item CMakeCache.txt -ErrorAction Ignore
-Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=OFF .."
-cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Release\Dynamic\MD
-rm MNN.*
-
-##### Release/Static/MT ####
-log "Release/Static/MT"
-Remove-Item CMakeCache.txt -ErrorAction Ignore
-Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON -DMNN_BUILD_SHARED_LIBS=OFF .."
-cp MNN.lib $PACKAGE_LIB_PATH\Release\Static\MT
+if ($ONLY_DYNAMIC_MT -eq $False) {
+    ##### Release/Dynamic/MD ####
+    log "Release/Dynamic/MD"
+    Remove-Item CMakeCache.txt -ErrorAction Ignore
+    Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=OFF .."
+    cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Release\Dynamic\MD
+    rm MNN.*
 
-##### Release/Static/MD ####
-log "Release/Static/MD"
-Remove-Item CMakeCache.txt -ErrorAction Ignore
-Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=OFF -DMNN_BUILD_SHARED_LIBS=OFF .."
-cp MNN.lib $PACKAGE_LIB_PATH\Release\Static\MD
+    ##### Release/Static/MT ####
+    log "Release/Static/MT"
+    Remove-Item CMakeCache.txt -ErrorAction Ignore
+    Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON -DMNN_BUILD_SHARED_LIBS=OFF .."
+    cp MNN.lib $PACKAGE_LIB_PATH\Release\Static\MT
 
+    ##### Release/Static/MD ####
+    log "Release/Static/MD"
+    Remove-Item CMakeCache.txt -ErrorAction Ignore
+    Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=OFF -DMNN_BUILD_SHARED_LIBS=OFF .."
+    cp MNN.lib $PACKAGE_LIB_PATH\Release\Static\MD
+}
 popd

+ 21 - 17
package_scripts/win/build_lib_release.ps1

@@ -41,10 +41,13 @@ if ($cibuild) {
 popd
 
 $CMAKE_ARGS = "-DMNN_SEP_BUILD=OFF -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON  -DMNN_OPENCL=ON -DMNN_VULKAN=ON -DMNN_AVX512=ON -DMNN_LOW_MEMORY=ON"
+$ONLY_DYNAMIC_MT = $False
+
 if ($backends -ne $null) {
     Foreach ($backend in $backends.Split(",")) {
         if ($backend -eq "cuda") {
             $CMAKE_ARGS = "$CMAKE_ARGS -DMNN_CUDA=ON"
+            $ONLY_DYNAMIC_MT = $True
         }
     }
 }
@@ -96,23 +99,24 @@ if ($cibuild) {
     return
 }
 
-##### Release/Dynamic/MD ####
-log "Release/Dynamic/MD"
-Remove-Item CMakeCache.txt -ErrorAction Ignore
-Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=OFF .."
-cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Release\Dynamic\MD
-rm MNN.*
+if ($ONLY_DYNAMIC_MT -eq $False) {
+    ##### Release/Dynamic/MD ####
+    log "Release/Dynamic/MD"
+    Remove-Item CMakeCache.txt -ErrorAction Ignore
+    Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=OFF .."
+    cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Release\Dynamic\MD
+    rm MNN.*
 
-##### Release/Static/MT ####
-log "Release/Static/MT"
-Remove-Item CMakeCache.txt -ErrorAction Ignore
-Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON -DMNN_BUILD_SHARED_LIBS=OFF .."
-cp MNN.lib $PACKAGE_LIB_PATH\Release\Static\MT
-
-##### Release/Static/MD ####
-log "Release/Static/MD"
-Remove-Item CMakeCache.txt -ErrorAction Ignore
-Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=OFF -DMNN_BUILD_SHARED_LIBS=OFF .."
-cp MNN.lib $PACKAGE_LIB_PATH\Release\Static\MD
+    ##### Release/Static/MT ####
+    log "Release/Static/MT"
+    Remove-Item CMakeCache.txt -ErrorAction Ignore
+    Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON -DMNN_BUILD_SHARED_LIBS=OFF .."
+    cp MNN.lib $PACKAGE_LIB_PATH\Release\Static\MT
 
+    ##### Release/Static/MD ####
+    log "Release/Static/MD"
+    Remove-Item CMakeCache.txt -ErrorAction Ignore
+    Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=OFF -DMNN_BUILD_SHARED_LIBS=OFF .."
+    cp MNN.lib $PACKAGE_LIB_PATH\Release\Static\MD
+}
 popd

+ 16 - 8
project/ios/MNN.xcodeproj/project.pbxproj

@@ -357,6 +357,8 @@
 		4DD1791B2684815A00B0098F /* ShapeSetDiff1D.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4DD1791A2684815A00B0098F /* ShapeSetDiff1D.cpp */; };
 		4DD1793A2694076700B0098F /* MNNSoftmax.S in Sources */ = {isa = PBXBuildFile; fileRef = 4DD179392694076700B0098F /* MNNSoftmax.S */; };
 		4DD1793C2694078000B0098F /* MNNSoftmax.S in Sources */ = {isa = PBXBuildFile; fileRef = 4DD1793B2694078000B0098F /* MNNSoftmax.S */; };
+		4DDD8E102B1D70C1005065D1 /* MNNTranspose16Bit8x8.S in Sources */ = {isa = PBXBuildFile; fileRef = 4DDD8E0F2B1D70C1005065D1 /* MNNTranspose16Bit8x8.S */; };
+		4DDD8E122B1D70CC005065D1 /* MNNTranspose16Bit8x8.S in Sources */ = {isa = PBXBuildFile; fileRef = 4DDD8E112B1D70CC005065D1 /* MNNTranspose16Bit8x8.S */; };
 		4DDE2019263809920085AC8F /* CoreMLExecutorWrapper.mm in Sources */ = {isa = PBXBuildFile; fileRef = 4DDE2017263809920085AC8F /* CoreMLExecutorWrapper.mm */; };
 		4DDE201A263809920085AC8F /* CoreMLExecutorWrapper.h in Headers */ = {isa = PBXBuildFile; fileRef = 4DDE2018263809920085AC8F /* CoreMLExecutorWrapper.h */; };
 		4DE4E82C275E307B0016A916 /* cv in Headers */ = {isa = PBXBuildFile; fileRef = 4DE4E82B275E307B0016A916 /* cv */; settings = {ATTRIBUTES = (Public, ); }; };
@@ -1178,6 +1180,8 @@
 		4DD1791A2684815A00B0098F /* ShapeSetDiff1D.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ShapeSetDiff1D.cpp; sourceTree = "<group>"; };
 		4DD179392694076700B0098F /* MNNSoftmax.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSoftmax.S; sourceTree = "<group>"; };
 		4DD1793B2694078000B0098F /* MNNSoftmax.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSoftmax.S; sourceTree = "<group>"; };
+		4DDD8E0F2B1D70C1005065D1 /* MNNTranspose16Bit8x8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNTranspose16Bit8x8.S; sourceTree = "<group>"; };
+		4DDD8E112B1D70CC005065D1 /* MNNTranspose16Bit8x8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNTranspose16Bit8x8.S; sourceTree = "<group>"; };
 		4DDE2017263809920085AC8F /* CoreMLExecutorWrapper.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = CoreMLExecutorWrapper.mm; sourceTree = "<group>"; };
 		4DDE2018263809920085AC8F /* CoreMLExecutorWrapper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CoreMLExecutorWrapper.h; sourceTree = "<group>"; };
 		4DE4E82B275E307B0016A916 /* cv */ = {isa = PBXFileReference; lastKnownFileType = folder; name = cv; path = ../tools/cv/include/cv; sourceTree = "<group>"; };
@@ -2508,6 +2512,7 @@
 		92FF013A23AA0B4E00AC97F6 /* arm32 */ = {
 			isa = PBXGroup;
 			children = (
+				4DDD8E112B1D70CC005065D1 /* MNNTranspose16Bit8x8.S */,
 				95CE1DFE2AC57F6200EFB51E /* MNNReluWithSlopeChannelInt8.S */,
 				CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */,
 				CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */,
@@ -2589,6 +2594,7 @@
 		92FF017C23AA0B4E00AC97F6 /* arm64 */ = {
 			isa = PBXGroup;
 			children = (
+				4DDD8E0F2B1D70C1005065D1 /* MNNTranspose16Bit8x8.S */,
 				95CE1E002AC57F7600EFB51E /* MNNReluWithSlopeChannelInt8.S */,
 				CEE9B9572A3AA4D4006438F2 /* MNNBilinearLineC8.S */,
 				CEE9B9582A3AA4D4006438F2 /* MNNBilinearSampleC8.S */,
@@ -3536,6 +3542,7 @@
 				92FF029A23AA0B5A00AC97F6 /* CPUQuantizedMaxPool.cpp in Sources */,
 				48F5881124DEA3F000C484A2 /* GeometryPooling3D.cpp in Sources */,
 				92FF042423AA0B7100AC97F6 /* ShapeROIPooling.cpp in Sources */,
+				4DDD8E122B1D70CC005065D1 /* MNNTranspose16Bit8x8.S in Sources */,
 				92FF033723AA0B5A00AC97F6 /* MNNConvDwF23MulTransUnit.S in Sources */,
 				4896D37A25FE2A6B00717702 /* MNNPackedMatMulRemainFP16.S in Sources */,
 				92FF043023AA0B7100AC97F6 /* ShapeQuantizedAvgPool.cpp in Sources */,
@@ -3597,6 +3604,7 @@
 				48FA474523AA127B00172C3B /* Executor.cpp in Sources */,
 				92FF02EA23AA0B5A00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S in Sources */,
 				48A8A61A21D101DE00C2B9A7 /* Matrix_CV.cpp in Sources */,
+				4DDD8E102B1D70C1005065D1 /* MNNTranspose16Bit8x8.S in Sources */,
 				489D7A8C2550FDC900AD896A /* MetalDeconvolution.mm in Sources */,
 				489D7AA62550FDC900AD896A /* MetalBackend.mm in Sources */,
 				92FF031823AA0B5A00AC97F6 /* MNNConvRunForUnitDepthWiseUint8.S in Sources */,
@@ -4019,7 +4027,7 @@
 				CODE_SIGN_STYLE = Automatic;
 				DEAD_CODE_STRIPPING = YES;
 				DEFINES_MODULE = YES;
-				DEVELOPMENT_TEAM = Q48UX93J22;
+				DEVELOPMENT_TEAM = 6G7464HHUS;
 				DYLIB_COMPATIBILITY_VERSION = 1;
 				DYLIB_CURRENT_VERSION = 1;
 				DYLIB_INSTALL_NAME_BASE = "@rpath";
@@ -4062,7 +4070,7 @@
 				METAL_LIBRARY_FILE_BASE = mnn;
 				ONLY_ACTIVE_ARCH = YES;
 				OTHER_CFLAGS = "";
-				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111;
+				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.v3;
 				PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
 				PROVISIONING_PROFILE_SPECIFIER = "";
 				"PROVISIONING_PROFILE_SPECIFIER[sdk=macosx*]" = "";
@@ -4083,7 +4091,7 @@
 				CODE_SIGN_STYLE = Automatic;
 				DEAD_CODE_STRIPPING = YES;
 				DEFINES_MODULE = YES;
-				DEVELOPMENT_TEAM = Q48UX93J22;
+				DEVELOPMENT_TEAM = 6G7464HHUS;
 				DYLIB_COMPATIBILITY_VERSION = 1;
 				DYLIB_CURRENT_VERSION = 1;
 				DYLIB_INSTALL_NAME_BASE = "@rpath";
@@ -4124,7 +4132,7 @@
 				MACH_O_TYPE = staticlib;
 				METAL_LIBRARY_FILE_BASE = mnn;
 				OTHER_CFLAGS = "";
-				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111;
+				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.v3;
 				PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
 				PROVISIONING_PROFILE_SPECIFIER = "";
 				"PROVISIONING_PROFILE_SPECIFIER[sdk=macosx*]" = "";
@@ -4143,7 +4151,7 @@
 				ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
 				ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
 				CODE_SIGN_STYLE = Automatic;
-				DEVELOPMENT_TEAM = Q48UX93J22;
+				DEVELOPMENT_TEAM = 6G7464HHUS;
 				GCC_ENABLE_CPP_EXCEPTIONS = NO;
 				GCC_ENABLE_CPP_RTTI = NO;
 				HEADER_SEARCH_PATHS = (
@@ -4156,7 +4164,7 @@
 				IPHONEOS_DEPLOYMENT_TARGET = 9.0;
 				LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
 				OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
-				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111;
+				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.v3;
 				PRODUCT_NAME = "$(TARGET_NAME)";
 				TARGETED_DEVICE_FAMILY = "1,2";
 			};
@@ -4168,7 +4176,7 @@
 				ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
 				ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
 				CODE_SIGN_STYLE = Automatic;
-				DEVELOPMENT_TEAM = Q48UX93J22;
+				DEVELOPMENT_TEAM = 6G7464HHUS;
 				GCC_ENABLE_CPP_EXCEPTIONS = NO;
 				GCC_ENABLE_CPP_RTTI = NO;
 				HEADER_SEARCH_PATHS = (
@@ -4181,7 +4189,7 @@
 				IPHONEOS_DEPLOYMENT_TARGET = 9.0;
 				LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
 				OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
-				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111;
+				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.v3;
 				PRODUCT_NAME = "$(TARGET_NAME)";
 				TARGETED_DEVICE_FAMILY = "1,2";
 			};

+ 22 - 0
pymnn/examples/MNNTrain/simple/grad_loss.py

@@ -0,0 +1,22 @@
+import MNN.numpy as np
+import MNN
+import sys
+nn = MNN.nn
+F = MNN.expr
+F.lazy_eval(True)
+F.set_lazy_mode(1)
+
+opt = MNN.optim.Grad()
+
+vars = F.load_as_dict(sys.argv[1])
+output = vars['loss']
+parameters = [vars['weight']]
+rgbdiff = F.placeholder(output.shape, output.data_format, output.dtype)
+rgbdiff.name = 'loss_diff'
+rgbdiff.write([1.0])
+rgbdiff.fix_as_const()
+
+parameters, grad = opt.grad([output], [rgbdiff], parameters)
+for i in range(0, len(parameters)):
+    grad[i].name = 'grad::' + parameters[i].name
+F.save(grad, sys.argv[2])

+ 43 - 0
pymnn/examples/MNNTrain/simple/make_solve_equation_graph.py

@@ -0,0 +1,43 @@
+import time
+import MNN.numpy as np
+import MNN
+nn = MNN.nn
+F = MNN.expr
+
+# open lazy evaluation for train
+F.lazy_eval(True)
+
+# month_pay=pow(rate/12+1, times)*(rate/12)*total/(pow(rate/12+1,times)-1)
+# Know month_pa, total, times, solve rate
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        one = np.array([0.001])
+        one.fix_as_trainable()
+        self.rate = one
+
+    def forward(self, times, total):
+        r12 = self.rate / 12.0
+        r12_1 = r12 + np.array([1.0])
+        total_rate = np.power(r12_1, times)
+        p0 = (total_rate * r12 * total) / (total_rate-np.array([1.0]))
+        return p0
+
+model = Net()
+opt = MNN.optim.SGD(model, 0.0000000001, 0.9, 0.0005)
+
+times = np.array([60.0])
+month_diff = np.array([1.0])
+month_diff.fix_as_placeholder()
+month_diff.name = "month_diff"
+total = np.array([630000.0])
+month_comp = model.forward(times, total)
+rates, rates_grad  = opt.grad([month_comp], [month_diff], [model.rate])
+lr_rate = np.array([0.0000001])
+lr_rate.fix_as_placeholder()
+lr_rate.name = "lr_rate"
+
+rates, rates_update = opt.get_update_graph(rates, rates_grad, [lr_rate])
+opt.save_graph("update.mnn", [], rates, rates_update)
+
+

+ 8 - 0
pymnn/pip_package/MNN/nn/__init__.py

@@ -52,3 +52,11 @@ class Module(_nn._Module):
             else:
                 self._vars[name] = value
                 self._add_parameter(value)
+
+
+class EmptyModule(_nn._Module):
+    def __init(self):
+        super(EmptyModule, self).__init__()
+    def forward(self):
+        return None
+dummy = EmptyModule()

+ 55 - 23
pymnn/pip_package/build_deps.py

@@ -19,23 +19,47 @@ IS_DARWIN = (platform.system() == 'Darwin')
 IS_LINUX = (platform.system() == 'Linux')
 BUILD_DIR = 'pymnn_build' # avoid overwrite temporary product when build pymnn
 
-USE_TRT=False
-if len(sys.argv) > 1 and sys.argv[1] == '-trt':
-    USE_TRT=True
+USE_TRT      = False
+USE_CUDA     = False
+USE_CUDA_TUNE= False
+USE_OPENCL   = False
+USE_VULKAN   = False
+USE_TORCH    = False
+USE_INTERNAL = False
+USE_RENDER   = False
+USE_SSE      = True
 
-USE_CUDA=False
-if len(sys.argv) > 1 and sys.argv[1] == '-cuda':
-    USE_CUDA=True
+if len(sys.argv) > 1 and sys.argv[1] != None:
+    if "trt" in sys.argv[1]:
+        USE_TRT = True
+    if "cuda" in sys.argv[1]:
+        USE_CUDA = True
+    if "cuda_tune" in sys.argv[1]:
+        USE_CUDA_TUNE = True
+    if "opencl" in sys.argv[1]:
+        USE_OPENCL = True
+    if "vulkan" in sys.argv[1]:
+        USE_VULKAN = True
+    if "torch" in sys.argv[1]:
+        USE_TORCH = True
+    if "internal" in sys.argv[1]:
+        USE_INTERNAL = True
+    if "render" in sys.argv[1]:
+        USE_RENDER = True
+    if "no_sse" in sys.argv[1]:
+        USE_SSE = False
 
-def build_deps():
-    if os.path.isdir('../../schema/private'):
-        IS_INTERNAL_BUILD = args.internal
-        # public not build torch
-        IS_BUILD_TORCH = args.torch
-    else:
-        IS_INTERNAL_BUILD = False
-        IS_BUILD_TORCH = False
+print ("USE_INTERNAL:", USE_INTERNAL)
+print ("USE_TRT:", USE_TRT)
+print ("USE_CUDA:", USE_CUDA)
+if USE_CUDA_TUNE:
+    print ("USE_CUDA_TUNE, please note: this function only support Ampere Arch now!")
+print ("USE_OPENCL:", USE_OPENCL)
+print ("USE_VULKAN:", USE_VULKAN)
+print ("USE_RENDER:", USE_RENDER)
+print ("USE_SSE:", USE_SSE)
 
+def build_deps():
     """ build depency """
     root_dir = os.path.dirname(os.path.dirname(os.getcwd()))
     #build_main_project
@@ -45,8 +69,12 @@ def build_deps():
     os.makedirs(cmake_build_dir)
     os.chdir(cmake_build_dir)
     extra_opts = '-DMNN_LOW_MEMORY=ON'
-    extra_opts += ' -DMNN_VULKAN=ON -DMNN_VULKAN_IMAGE=OFF'
-    extra_opts += ' -DMNN_OPENCL=ON'
+    if USE_RENDER:
+        extra_opts += ' -DMNN_SUPPORT_RENDER=ON'
+    if USE_VULKAN:
+        extra_opts += ' -DMNN_VULKAN=ON -DMNN_VULKAN_IMAGE=OFF'
+    if USE_OPENCL:
+        extra_opts += ' -DMNN_OPENCL=ON'
     if IS_WINDOWS:
         os.system('cmake -G "Ninja" ' + extra_opts +' -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TORCH=OFF\
             -DMNN_BUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON\
@@ -54,21 +82,25 @@ def build_deps():
     elif IS_LINUX:
         extra_opts += '-DMNN_TENSORRT=ON \
         -DCMAKE_LIBRARY_PATH=/usr/local/cuda/lib64/stubs/ ' if USE_TRT else ' '
-        extra_opts += ' -DMNN_INTERNAL=ON ' if IS_INTERNAL_BUILD else ' '
-        extra_opts += ' -DMNN_BUILD_TORCH=ON ' if IS_BUILD_TORCH else ' '
-        extra_opts += ' -DMNN_CUDA=ON ' if USE_CUDA else ' '
+        extra_opts += ' -DMNN_INTERNAL=ON ' if USE_INTERNAL else ' '
+        extra_opts += ' -DMNN_BUILD_TORCH=ON ' if USE_TORCH else ' '
+        if USE_CUDA:
+            extra_opts += ' -DMNN_CUDA=ON '
+            if USE_CUDA_TUNE:
+                extra_opts += ' -DMNN_CUDA_TUNE_PARAM=ON '
+        extra_opts += ' ' if USE_SSE else ' -DMNN_USE_SSE=OFF '
         os.system('cmake ' + extra_opts +
             '-DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TRAIN=ON -DCMAKE_BUILD_TYPE=Release \
             -DMNN_BUILD_SHARED_LIBS=OFF -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON \
-            -DMNN_USE_THREAD_POOL=ON -DMNN_OPENMP=OFF .. && make MNN MNNTrain MNNConvertDeps -j4')
+            -DMNN_USE_THREAD_POOL=ON -DMNN_OPENMP=OFF .. && make MNN MNNTrain MNNConvertDeps -j32')
     else:
-        extra_opts += ' -DMNN_INTERNAL=ON ' if IS_INTERNAL_BUILD else ' '
-        extra_opts += ' -DMNN_BUILD_TORCH=ON ' if IS_BUILD_TORCH else ' '
+        extra_opts += ' -DMNN_INTERNAL=ON ' if USE_INTERNAL else ' '
+        extra_opts += ' -DMNN_BUILD_TORCH=ON ' if USE_TORCH else ' '
         print(extra_opts)
         os.system('cmake ' + extra_opts + '-DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TRAIN=ON -DCMAKE_BUILD_TYPE=Release \
             -DMNN_BUILD_SHARED_LIBS=OFF -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF\
             -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON \
-            .. && make MNN MNNTrain MNNConvertDeps -j4')
+            .. && make MNN MNNTrain MNNConvertDeps -j32')
 ################################################################################
 # Building dependent libraries
 ################################################################################

+ 53 - 25
pymnn/pip_package/setup.py

@@ -26,6 +26,8 @@ parser.add_argument('--version', dest='version', type=str, default=get_version()
                     help='MNN dist version')
 parser.add_argument('--serving', dest='serving', action='store_true', default=False,
                     help='build for internal serving, default False')
+parser.add_argument('--deps', dest='deps', type=str, required=False,
+                    help='MNN library deps')
 parser.add_argument('--env', dest='env', type=str, required=False,
                     help='build environment, e.g. :daily/pre/production')
 args, unknown = parser.parse_known_args()
@@ -58,21 +60,46 @@ def report(*args):
     print(*args)
 
 package_name = 'MNN'
-USE_TRT=check_env_flag('USE_TRT')
-USE_CUDA = check_env_flag("USE_CUDA")
-IS_INTERNAL_BUILD = False
-
-print ("USE_TRT ", USE_TRT)
-print("USE_CUDA:", USE_CUDA)
-
-if os.path.isdir('../../schema/private'):
-    IS_INTERNAL_BUILD = args.serving
-    if USE_TRT:
-        print("Build Internal NNN with TRT")
-        package_name = 'MNN_Internal_TRT'
-    else:
-        print("Build Internal NNN")
-        package_name = 'MNN_Internal'
+USE_INTERNAL = False
+USE_TRT      = False
+USE_CUDA     = False
+USE_OPENCL   = False
+USE_VULKAN   = False
+USE_RENDER   = False
+
+if args.deps != None:
+    if "trt" in args.deps:
+        USE_TRT = True
+    if "cuda" in args.deps:
+        USE_CUDA = True
+    if "opencl" in args.deps:
+        USE_OPENCL = True
+    if "vulkan" in args.deps:
+        USE_VULKAN = True
+    if "internal" in args.deps:
+        USE_INTERNAL = True
+    if "render" in args.deps:
+        USE_RENDER = True
+
+print ("USE_INTERNAL:", USE_INTERNAL)
+print ("USE_TRT:", USE_TRT)
+print ("USE_CUDA:", USE_CUDA)
+print ("USE_OPENCL:", USE_OPENCL)
+print ("USE_VULKAN:", USE_VULKAN)
+print ("USE_RENDER:", USE_RENDER)
+
+if USE_INTERNAL:
+    package_name += '_Internal'
+if USE_TRT:
+    package_name += '_TRT'
+if USE_CUDA:
+    package_name += '_CUDA'
+if USE_VULKAN:
+    package_name += '_VULKAN'
+if USE_OPENCL:
+    package_name += '_OPENCL'
+if USE_RENDER:
+    package_name += '_RENDER'
 
 print ('Building with python wheel with package name ', package_name)
 
@@ -137,7 +164,7 @@ def configure_extension_build():
         if check_env_flag('WERROR'):
             extra_compile_args.append('-Werror')
     extra_compile_args += ['-DPYMNN_EXPR_API', '-DPYMNN_NUMPY_USABLE', '-DPYMNN_OPENCV_API']
-    if IS_LINUX and IS_INTERNAL_BUILD:
+    if IS_LINUX and USE_INTERNAL:
         extra_compile_args += ['-DPYMNN_INTERNAL_SERVING']
         if args.env == 'daily':
             extra_compile_args += ['-DPYMNN_INTERNAL_SERVING_DAILY']
@@ -154,13 +181,13 @@ def configure_extension_build():
         engine_library_dirs += ['/usr/local/cuda/lib64/']
 
     # Logging is enabled on Linux. Add the dependencies.
-    if IS_LINUX and IS_INTERNAL_BUILD:
+    if IS_LINUX and USE_INTERNAL:
         engine_library_dirs += ['/usr/include/curl/']
 
     print(engine_library_dirs)
     engine_link_args = []
     engine_sources = [os.path.join(root_dir, "pymnn", "src", "MNN.cc")]
-    if IS_LINUX and IS_INTERNAL_BUILD:
+    if IS_LINUX and USE_INTERNAL:
         engine_sources += [os.path.join(root_dir, "pymnn", "src", "internal", "monitor_service.cc")]
         engine_sources += [os.path.join(root_dir, "pymnn", "src", "internal", "verify_service.cc")]
         engine_sources += [os.path.join(root_dir, "pymnn", "src", "internal", "http_util.cc")]
@@ -180,18 +207,19 @@ def configure_extension_build():
     engine_include_dirs += [os.path.join(root_dir, "schema", "current")]
     engine_include_dirs += [os.path.join(root_dir, "3rd_party",\
                                           "flatbuffers", "include")]
-    if IS_LINUX and IS_INTERNAL_BUILD:
+    if IS_LINUX and USE_INTERNAL:
         engine_include_dirs += [os.path.join(root_dir, "3rd_party", "rapidjson")]
     # cv include
     engine_include_dirs += [os.path.join(root_dir, "tools", "cv", "include")]
     engine_include_dirs += [np.get_include()]
 
+    lib_files = []
     trt_depend = ['-lTRT_CUDA_PLUGIN', '-lnvinfer', '-lnvparsers', '-lnvinfer_plugin', '-lcudart']
     cuda_depend = ['-lMNN_Cuda_Main']
     engine_depend = ['-lMNN']
 
     # enable logging & model authentication on linux.
-    if IS_LINUX and IS_INTERNAL_BUILD:
+    if IS_LINUX and USE_INTERNAL:
         engine_depend += ['-lcurl', '-lssl', '-lcrypto']
 
     if USE_TRT:
@@ -199,6 +227,7 @@ def configure_extension_build():
 
     if USE_CUDA:
         engine_depend += cuda_depend
+        lib_files += [('lib', [os.path.join(root_dir, BUILD_DIR, "source", "backend", "cuda", "libMNN_Cuda_Main.so")])]
 
     tools_compile_args = []
     tools_libraries = []
@@ -210,7 +239,6 @@ def configure_extension_build():
     tools_library_dirs += [os.path.join(root_dir, BUILD_DIR, "3rd_party", "protobuf", "cmake")]
 
     # add libTorch dependency
-    lib_files = []
     torch_lib = None
     cmakecache = os.path.join(root_dir, BUILD_DIR, 'CMakeCache.txt')
     for line in open(cmakecache, 'rt').readlines():
@@ -224,11 +252,11 @@ def configure_extension_build():
         elif IS_DARWIN:
             torch_path = os.path.dirname(torch_lib)
             tools_library_dirs += [torch_lib]
-            lib_files = [('lib', [os.path.join(torch_lib, 'libtorch.dylib'), os.path.join(torch_lib, 'libtorch_cpu.dylib'),
+            lib_files += [('lib', [os.path.join(torch_lib, 'libtorch.dylib'), os.path.join(torch_lib, 'libtorch_cpu.dylib'),
                                   os.path.join(torch_lib, 'libc10.dylib')]),
                          ('.dylibs', [os.path.join(torch_lib, 'libiomp5.dylib')])]
             '''
-            lib_files = [('lib', [os.path.join(torch_lib, 'libtorch.dylib'), os.path.join(torch_lib, 'libtorch_cpu.dylib'),
+            lib_files += [('lib', [os.path.join(torch_lib, 'libtorch.dylib'), os.path.join(torch_lib, 'libtorch_cpu.dylib'),
                                   os.path.join(torch_lib, 'libc10.dylib')]),
                          ('.dylibs', [os.path.join(torch_path, '.dylibs', 'libiomp5.dylib')])]
             '''
@@ -236,7 +264,7 @@ def configure_extension_build():
         # Note: TensorRT-5.1.5.0/lib should be set in $LIBRARY_PATH of the build system.
         tools_library_dirs += ['/usr/local/cuda/lib64/']
 
-    if IS_LINUX and IS_INTERNAL_BUILD:
+    if IS_LINUX and USE_INTERNAL:
         tools_library_dirs += ['/usr/include/curl/']
 
     tools_link_args = []
@@ -268,7 +296,7 @@ def configure_extension_build():
     tools_include_dirs += [np.get_include()]
 
     # enable logging and model authentication on linux.
-    if IS_LINUX and IS_INTERNAL_BUILD:
+    if IS_LINUX and USE_INTERNAL:
         tools_depend += ['-lcurl', '-lssl', '-lcrypto']
 
     if USE_TRT:

+ 25 - 2
pymnn/src/expr.h

@@ -1021,6 +1021,27 @@ static PyObject* PyMNNExpr_lazy_eval(PyObject *self, PyObject *args) {
     Py_RETURN_NONE;
 }
 
+static PyObject* PyMNNExpr_set_lazy_mode(PyObject *self, PyObject *args) {
+    int lazy = 0;
+    if (!PyArg_ParseTuple(args, "i", &lazy)) {
+        return NULL;
+    }
+    ExecutorScope::Current()->setLazyComputeMode((Executor::LazyMode)lazy);
+    Py_RETURN_NONE;
+}
+static PyObject* PyMNNExpr_set_global_executor_config(PyObject *self, PyObject *args) {
+    int numberThread, backendType, precisionType;
+    if (!PyArg_ParseTuple(args, "iii", &backendType, &precisionType, &numberThread)) {
+        Py_RETURN_NONE;
+    }
+
+    auto exe = ExecutorScope::Current();
+    BackendConfig config;
+    config.precision = (BackendConfig::PrecisionMode)precisionType;
+    exe->setGlobalExecutorConfig((MNNForwardType)backendType, config, numberThread);
+    Py_RETURN_NONE;
+}
+
 def_unary(Expr,
     sign, Express::_Sign,
     abs, Express::_Abs,
@@ -1692,13 +1713,15 @@ static PyMethodDef PyMNNExpr_methods[] = {
     )
     register_methods(Expr,
         // Var methods
-        set_thread_number, "set threan number of expr.",
+        set_thread_number, "set thread number of expr.",
         load_as_list, "load file as var list.",
         save, "save vars to file.",
         load_as_dict, "load file as var dict.",
         get_inputs_and_outputs, "get input and output of var dict.",
         gc, "do gc full or part.",
-        lazy_eval, "expr do lazy evaluation or not."
+        lazy_eval, "expr do lazy evaluation or not.",
+        set_lazy_mode, "set lazy compute mode, content: 0 or full: 1.",
+        set_global_executor_config, "set global executor config for expr."
     )
     register_methods(Expr,
         // unary expr

+ 35 - 0
pymnn/src/nn.h

@@ -18,6 +18,7 @@ def_class_methods(_Module,
     forward, "forward",
     onForward, "onForward",
     set_name, "set name",
+    get_info, "get module info",
     train, "set is_training",
     load_parameters, "load parameters",
     clear_cache, "clear cache",
@@ -143,6 +144,40 @@ static PyObject* PyMNN_Module_forward(PyMNN_Module *self, PyObject *args) {
     }
     PyMNN_ERROR("PyMNN_Module_forward: args must be Var/[Var].");
 }
+static PyObject* PyMNN_Module_get_info(PyMNN_Module *self, PyObject *args) {
+    auto m = (*(self->ptr));
+    auto info = m->getInfo();
+    if (nullptr == info) {
+        PyMNN_ERROR("The module can't get info");
+        Py_RETURN_NONE;
+    }
+    auto res = PyDict_New();
+    PyDict_SetItemString(res, "version", char2Object(info->version.c_str()));
+    {
+        auto names = PyList_New(info->inputNames.size());
+        for (int i=0; i<info->inputNames.size(); ++i) {
+            PyList_SetItem(names, i, char2Object(info->inputNames[i].c_str()));
+        }
+        PyDict_SetItemString(res, "inputNames", names);
+    }
+    {
+        auto names = PyList_New(info->outputNames.size());
+        for (int i=0; i<info->outputNames.size(); ++i) {
+            PyList_SetItem(names, i, char2Object(info->outputNames[i].c_str()));
+        }
+        PyDict_SetItemString(res, "outputNames", names);
+    }
+    {
+        auto inputs = PyList_New(info->inputs.size());
+        for (int i=0; i<info->inputs.size(); ++i) {
+            auto& v = info->inputs[i];
+            auto var = MNN::Express::_Input(v.dim, v.order, v.type);
+            PyList_SetItem(inputs, i, toPyObj(var));
+        }
+        PyDict_SetItemString(res, "inputs", inputs);
+    }
+    return res;
+}
 static PyObject* PyMNN_Module_onForward(PyMNN_Module *self, PyObject *args) {
     PyObject *inputs;
     if (!PyArg_ParseTuple(args, "O", &inputs)) {

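get_info above packs the module's version string, input/output names, and placeholder input variables (rebuilt from the stored dim/order/type) into a Python dict. A small usage sketch, assuming the module is loaded through the existing MNN.nn.load_module_from_file binding and that "model.mnn" is a hypothetical path:

import MNN.nn as nn

net = nn.load_module_from_file("model.mnn", [], [])   # hypothetical file; name lists elided
info = net.get_info()
print(info["version"])
print(info["inputNames"], info["outputNames"])
for v in info["inputs"]:          # placeholder Vars created via Express::_Input
    print(v.shape)
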
+ 76 - 1
pymnn/src/optim.h

@@ -1,4 +1,5 @@
 #include "util.h"
+#include "OpGrad.hpp"
 
 // Optim Module
 def_enum(Regularization_Method, ParameterOptimizer::RegularizationMethod,
@@ -18,7 +19,10 @@ def_class_getset(
     regularization_method, 1
 )
 def_class_methods(Optimizer,
-    step, "Optimizer step"
+    step, "Optimizer step",
+    grad, "Grad for variables",
+    get_update_graph, "Get Update Graph for parameters",
+    save_graph, "Save Update Graph to MNN File"
 )
 def_class_end(Optimizer, ParameterOptimizer)
 // impl
@@ -104,6 +108,70 @@ static int PyMNNOptimizer_setregularization_method(PyMNNOptimizer *self, PyObjec
     }
     return 0;  
 }
+
+static PyObject* _makeTupleFromPairVector(const std::pair<std::vector<Express::VARP>, std::vector<Express::VARP>>& values) {
+    PyObject* obj0 = PyList_New(values.first.size());
+    for (int i = 0; i < values.first.size(); i++) {
+        PyList_SetItem(obj0, i, toPyObj(values.first[i]));
+    }
+    PyObject* obj1 = PyList_New(values.second.size());
+    for (int i = 0; i < values.second.size(); i++) {
+        PyList_SetItem(obj1, i, toPyObj(values.second[i]));
+    }
+    PyObject* obj = PyTuple_New(2);
+    PyTuple_SetItem(obj, 0, obj0);
+    PyTuple_SetItem(obj, 1, obj1);
+    return obj;
+}
+static PyObject* PyMNNOptimizer_grad(PyMNNOptimizer *self, PyObject *args) {
+    PyObject* outputs;
+    PyObject* outputDiffs;
+    PyObject* parameters;
+    if (PyArg_ParseTuple(args, "OOO", &outputs, &outputDiffs, &parameters)) {
+        if (isVars(outputs) && isVals(outputDiffs) && isVars(parameters)) {
+            auto values = OpGrad::gradCommon(toVars(outputs), toVars(outputDiffs), toVars(parameters));
+            return _makeTupleFromPairVector(values);
+        }
+    }
+    PyMNN_ERROR("grad require args: ([Var](outputs),[Var](output Diff), [Var](parameters))");
+    return Py_None;
+}
+static PyObject* PyMNNOptimizer_get_update_graph(PyMNNOptimizer *self, PyObject *args) {
+    PyObject* parameter;
+    PyObject* parameterGrad;
+    PyObject* learningRate;
+    if (PyArg_ParseTuple(args, "OOO", &parameter, &parameterGrad, &learningRate)) {
+        if (isVars(parameter) && isVals(parameterGrad) && isVars(learningRate)) {
+            if (self->ptr) {
+                auto p = toVars(parameter);
+                auto pd = toVars(parameterGrad);
+                auto lr = toVars(learningRate);
+                auto values = static_cast<ParameterOptimizer*>(self->ptr)->makeParameterUpdateGraphByGrad(p, pd, lr);
+                return _makeTupleFromPairVector(values);
+            }
+        }
+    }
+    PyMNN_ERROR("get_update_graph require args: ([Var](parameter),[Var](parameter grad), [Var](learningRate))");
+    return Py_None;
+}
+static PyObject* PyMNNOptimizer_save_graph(PyMNNOptimizer *self, PyObject *args) {
+    const char* modelFile      = NULL;
+    PyObject* outputs;
+    PyObject* parameter;
+    PyObject* parameterUpdate;
+    if (PyArg_ParseTuple(args, "sOOO", &modelFile, &outputs, &parameter, &parameterUpdate)) {
+        if (isVars(parameter) && isVals(parameterUpdate) && isVars(outputs)) {
+            auto o = toVars(outputs);
+            auto p = toVars(parameter);
+            auto pu = toVars(parameterUpdate);
+            ParameterOptimizer::makeLoopModel(modelFile, o, std::make_pair(p, pu));
+            return Py_None;
+        }
+    }
+    PyMNN_ERROR("save_graph require args: ([string](outputPath),[Var](outputs), [Var](parameter),  [Var](parameterUpdate))");
+    return Py_None;
+}
+
 // PyMNNOptimizer methods impl
 static PyObject* PyMNNOptimizer_step(PyMNNOptimizer *self, PyObject *args) {
     PyObject *loss;
@@ -112,6 +180,12 @@ static PyObject* PyMNNOptimizer_step(PyMNNOptimizer *self, PyObject *args) {
     }
     return toPyObj(self->ptr->step(toVar(loss)));
 }
+static PyObject* PyMNNOptim_Grad(PyObject *self, PyObject *args, PyObject *kwargs) {
+    float learning_rate = 1e-3, momentum = 0.9, weight_decay = 0.0;
+    std::shared_ptr<Module> m;
+    return toPyObj(ParameterOptimizer::createSGD(m, learning_rate, momentum,
+                                                 weight_decay, RegularizationMethod::L2));
+}
 static PyObject* PyMNNOptim_SGD(PyObject *self, PyObject *args, PyObject *kwargs) {
     PyObject *module = nullptr, *method = nullptr /* L2 */;
     float learning_rate = 1e-3, momentum = 0.9, weight_decay = 0.0;
@@ -141,6 +215,7 @@ static PyObject* PyMNNOptim_ADAM(PyObject *self, PyObject *args, PyObject *kwarg
 }
 static PyMethodDef PyMNNOptim_methods[] = {
     register_methods_kw(Optim,
+        Grad, "Grad Only",
         SGD, "SGD Optimizer",
         ADAM, "ADAM Optimizer"
     )

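Beyond the SGD/ADAM factories, the hunk above exposes grad, get_update_graph and save_graph on the optimizer, plus an Optim.Grad factory that builds a gradient-only helper. A hedged sketch of chaining them to bake a parameter-update graph into an MNN file; the toy loss, the interpretation of the returned pairs, and the output file name are assumptions read off the bindings, not documented behaviour:

import MNN.expr as F
import MNN.optim as optim

w = F.const([3.0], [1])                     # toy "parameter"
x = F.const([2.0], [1])
loss = (w * x) * (w * x)                    # toy loss depending on w

opt = optim.Grad()                          # gradient-only helper added above
one = F.const([1.0], [1])                   # seed diff for the output
# grad(outputs, output_diffs, parameters) returns a pair of Var lists
# (OpGrad::gradCommon); which list holds the gradients is assumed here.
res_a, res_b = opt.grad([loss], [one], [w])
lr = [F.const([0.01], [1])]
# get_update_graph(parameters, parameter_grads, learning_rates) -> (params, updates)
params, updates = opt.get_update_graph([w], res_a, lr)
# save_graph(path, outputs, parameters, parameter_updates) writes the loop model.
opt.save_graph("train_update.mnn", [loss], params, updates)
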
+ 292 - 0
schema/current/CaffeOp_generated.h

@@ -76,6 +76,12 @@ struct BatchNormT;
 struct Scale;
 struct ScaleT;
 
+struct QuantizeLinear;
+struct QuantizeLinearT;
+
+struct DequantizeLinear;
+struct DequantizeLinearT;
+
 struct Eltwise;
 struct EltwiseT;
 
@@ -159,6 +165,10 @@ inline const flatbuffers::TypeTable *BatchNormTypeTable();
 
 inline const flatbuffers::TypeTable *ScaleTypeTable();
 
+inline const flatbuffers::TypeTable *QuantizeLinearTypeTable();
+
+inline const flatbuffers::TypeTable *DequantizeLinearTypeTable();
+
 inline const flatbuffers::TypeTable *EltwiseTypeTable();
 
 inline const flatbuffers::TypeTable *FlattenTypeTable();
@@ -2912,6 +2922,180 @@ inline flatbuffers::Offset<Scale> CreateScale(
 
 flatbuffers::Offset<Scale> CreateScale(flatbuffers::FlatBufferBuilder &_fbb, const ScaleT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
+struct QuantizeLinearT : public flatbuffers::NativeTable {
+  typedef QuantizeLinear TableType;
+  int32_t scaleSize;
+  int32_t scaleAxis;
+  std::vector<float> scaleData;
+  std::vector<int8_t> zeroPointData;
+  QuantizeLinearT()
+      : scaleSize(0),
+        scaleAxis(0) {
+  }
+};
+
+struct QuantizeLinear FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef QuantizeLinearT NativeTableType;
+  static const flatbuffers::TypeTable *MiniReflectTypeTable() {
+    return QuantizeLinearTypeTable();
+  }
+  int32_t scaleSize() const {
+    return GetField<int32_t>(4, 0);
+  }
+  int32_t scaleAxis() const {
+    return GetField<int32_t>(6, 0);
+  }
+  const flatbuffers::Vector<float> *scaleData() const {
+    return GetPointer<const flatbuffers::Vector<float> *>(8);
+  }
+  const flatbuffers::Vector<int8_t> *zeroPointData() const {
+    return GetPointer<const flatbuffers::Vector<int8_t> *>(10);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int32_t>(verifier, 4) &&
+           VerifyField<int32_t>(verifier, 6) &&
+           VerifyOffset(verifier, 8) &&
+           verifier.VerifyVector(scaleData()) &&
+           VerifyOffset(verifier, 10) &&
+           verifier.VerifyVector(zeroPointData()) &&
+           verifier.EndTable();
+  }
+  QuantizeLinearT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(QuantizeLinearT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<QuantizeLinear> Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeLinearT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct QuantizeLinearBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_scaleSize(int32_t scaleSize) {
+    fbb_.AddElement<int32_t>(4, scaleSize, 0);
+  }
+  void add_scaleAxis(int32_t scaleAxis) {
+    fbb_.AddElement<int32_t>(6, scaleAxis, 0);
+  }
+  void add_scaleData(flatbuffers::Offset<flatbuffers::Vector<float>> scaleData) {
+    fbb_.AddOffset(8, scaleData);
+  }
+  void add_zeroPointData(flatbuffers::Offset<flatbuffers::Vector<int8_t>> zeroPointData) {
+    fbb_.AddOffset(10, zeroPointData);
+  }
+  explicit QuantizeLinearBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  QuantizeLinearBuilder &operator=(const QuantizeLinearBuilder &);
+  flatbuffers::Offset<QuantizeLinear> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<QuantizeLinear>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<QuantizeLinear> CreateQuantizeLinear(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    int32_t scaleSize = 0,
+    int32_t scaleAxis = 0,
+    flatbuffers::Offset<flatbuffers::Vector<float>> scaleData = 0,
+    flatbuffers::Offset<flatbuffers::Vector<int8_t>> zeroPointData = 0) {
+  QuantizeLinearBuilder builder_(_fbb);
+  builder_.add_zeroPointData(zeroPointData);
+  builder_.add_scaleData(scaleData);
+  builder_.add_scaleAxis(scaleAxis);
+  builder_.add_scaleSize(scaleSize);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<QuantizeLinear> CreateQuantizeLinear(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeLinearT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct DequantizeLinearT : public flatbuffers::NativeTable {
+  typedef DequantizeLinear TableType;
+  int32_t scaleSize;
+  int32_t scaleAxis;
+  std::vector<float> scaleData;
+  std::vector<int8_t> zeroPointData;
+  DequantizeLinearT()
+      : scaleSize(0),
+        scaleAxis(0) {
+  }
+};
+
+struct DequantizeLinear FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef DequantizeLinearT NativeTableType;
+  static const flatbuffers::TypeTable *MiniReflectTypeTable() {
+    return DequantizeLinearTypeTable();
+  }
+  int32_t scaleSize() const {
+    return GetField<int32_t>(4, 0);
+  }
+  int32_t scaleAxis() const {
+    return GetField<int32_t>(6, 0);
+  }
+  const flatbuffers::Vector<float> *scaleData() const {
+    return GetPointer<const flatbuffers::Vector<float> *>(8);
+  }
+  const flatbuffers::Vector<int8_t> *zeroPointData() const {
+    return GetPointer<const flatbuffers::Vector<int8_t> *>(10);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int32_t>(verifier, 4) &&
+           VerifyField<int32_t>(verifier, 6) &&
+           VerifyOffset(verifier, 8) &&
+           verifier.VerifyVector(scaleData()) &&
+           VerifyOffset(verifier, 10) &&
+           verifier.VerifyVector(zeroPointData()) &&
+           verifier.EndTable();
+  }
+  DequantizeLinearT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(DequantizeLinearT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<DequantizeLinear> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeLinearT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct DequantizeLinearBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_scaleSize(int32_t scaleSize) {
+    fbb_.AddElement<int32_t>(4, scaleSize, 0);
+  }
+  void add_scaleAxis(int32_t scaleAxis) {
+    fbb_.AddElement<int32_t>(6, scaleAxis, 0);
+  }
+  void add_scaleData(flatbuffers::Offset<flatbuffers::Vector<float>> scaleData) {
+    fbb_.AddOffset(8, scaleData);
+  }
+  void add_zeroPointData(flatbuffers::Offset<flatbuffers::Vector<int8_t>> zeroPointData) {
+    fbb_.AddOffset(10, zeroPointData);
+  }
+  explicit DequantizeLinearBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  DequantizeLinearBuilder &operator=(const DequantizeLinearBuilder &);
+  flatbuffers::Offset<DequantizeLinear> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<DequantizeLinear>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<DequantizeLinear> CreateDequantizeLinear(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    int32_t scaleSize = 0,
+    int32_t scaleAxis = 0,
+    flatbuffers::Offset<flatbuffers::Vector<float>> scaleData = 0,
+    flatbuffers::Offset<flatbuffers::Vector<int8_t>> zeroPointData = 0) {
+  DequantizeLinearBuilder builder_(_fbb);
+  builder_.add_zeroPointData(zeroPointData);
+  builder_.add_scaleData(scaleData);
+  builder_.add_scaleAxis(scaleAxis);
+  builder_.add_scaleSize(scaleSize);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<DequantizeLinear> CreateDequantizeLinear(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeLinearT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct EltwiseT : public flatbuffers::NativeTable {
   typedef Eltwise TableType;
   EltwiseType type;
@@ -5158,6 +5342,76 @@ inline flatbuffers::Offset<Scale> CreateScale(flatbuffers::FlatBufferBuilder &_f
       _external);
 }
 
+inline QuantizeLinearT *QuantizeLinear::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new QuantizeLinearT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void QuantizeLinear::UnPackTo(QuantizeLinearT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = scaleSize(); _o->scaleSize = _e; };
+  { auto _e = scaleAxis(); _o->scaleAxis = _e; };
+  { auto _e = scaleData(); if (_e) { _o->scaleData.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->scaleData[_i] = _e->Get(_i); } } };
+  { auto _e = zeroPointData(); if (_e) { _o->zeroPointData.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->zeroPointData[_i] = _e->Get(_i); } } };
+}
+
+inline flatbuffers::Offset<QuantizeLinear> QuantizeLinear::Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeLinearT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateQuantizeLinear(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<QuantizeLinear> CreateQuantizeLinear(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeLinearT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const QuantizeLinearT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _scaleSize = _o->scaleSize;
+  auto _scaleAxis = _o->scaleAxis;
+  auto _scaleData = _o->scaleData.size() ? _fbb.CreateVector(_o->scaleData) : 0;
+  auto _zeroPointData = _o->zeroPointData.size() ? _fbb.CreateVector(_o->zeroPointData) : 0;
+  return MNN::CreateQuantizeLinear(
+      _fbb,
+      _scaleSize,
+      _scaleAxis,
+      _scaleData,
+      _zeroPointData);
+}
+
+inline DequantizeLinearT *DequantizeLinear::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new DequantizeLinearT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void DequantizeLinear::UnPackTo(DequantizeLinearT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = scaleSize(); _o->scaleSize = _e; };
+  { auto _e = scaleAxis(); _o->scaleAxis = _e; };
+  { auto _e = scaleData(); if (_e) { _o->scaleData.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->scaleData[_i] = _e->Get(_i); } } };
+  { auto _e = zeroPointData(); if (_e) { _o->zeroPointData.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->zeroPointData[_i] = _e->Get(_i); } } };
+}
+
+inline flatbuffers::Offset<DequantizeLinear> DequantizeLinear::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeLinearT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateDequantizeLinear(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<DequantizeLinear> CreateDequantizeLinear(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeLinearT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DequantizeLinearT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _scaleSize = _o->scaleSize;
+  auto _scaleAxis = _o->scaleAxis;
+  auto _scaleData = _o->scaleData.size() ? _fbb.CreateVector(_o->scaleData) : 0;
+  auto _zeroPointData = _o->zeroPointData.size() ? _fbb.CreateVector(_o->zeroPointData) : 0;
+  return MNN::CreateDequantizeLinear(
+      _fbb,
+      _scaleSize,
+      _scaleAxis,
+      _scaleData,
+      _zeroPointData);
+}
+
 inline EltwiseT *Eltwise::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
   auto _o = new EltwiseT();
   UnPackTo(_o, _resolver);
@@ -6394,6 +6648,44 @@ inline const flatbuffers::TypeTable *ScaleTypeTable() {
   return &tt;
 }
 
+inline const flatbuffers::TypeTable *QuantizeLinearTypeTable() {
+  static const flatbuffers::TypeCode type_codes[] = {
+    { flatbuffers::ET_INT, 0, -1 },
+    { flatbuffers::ET_INT, 0, -1 },
+    { flatbuffers::ET_FLOAT, 1, -1 },
+    { flatbuffers::ET_CHAR, 1, -1 }
+  };
+  static const char * const names[] = {
+    "scaleSize",
+    "scaleAxis",
+    "scaleData",
+    "zeroPointData"
+  };
+  static const flatbuffers::TypeTable tt = {
+    flatbuffers::ST_TABLE, 4, type_codes, nullptr, nullptr, names
+  };
+  return &tt;
+}
+
+inline const flatbuffers::TypeTable *DequantizeLinearTypeTable() {
+  static const flatbuffers::TypeCode type_codes[] = {
+    { flatbuffers::ET_INT, 0, -1 },
+    { flatbuffers::ET_INT, 0, -1 },
+    { flatbuffers::ET_FLOAT, 1, -1 },
+    { flatbuffers::ET_CHAR, 1, -1 }
+  };
+  static const char * const names[] = {
+    "scaleSize",
+    "scaleAxis",
+    "scaleData",
+    "zeroPointData"
+  };
+  static const flatbuffers::TypeTable tt = {
+    flatbuffers::ST_TABLE, 4, type_codes, nullptr, nullptr, names
+  };
+  return &tt;
+}
+
 inline const flatbuffers::TypeTable *EltwiseTypeTable() {
   static const flatbuffers::TypeCode type_codes[] = {
     { flatbuffers::ET_CHAR, 0, 0 },

File changes are not displayed because they are too large
+ 125 - 25
schema/current/MNN_generated.h


+ 394 - 0
schema/current/TrainInfo_generated.h

@@ -0,0 +1,394 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_TRAININFO_MNNTRAIN_H_
+#define FLATBUFFERS_GENERATED_TRAININFO_MNNTRAIN_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+namespace MNNTrain {
+
+struct OpInfo;
+struct OpInfoT;
+
+struct KV;
+struct KVT;
+
+struct TrainInfo;
+struct TrainInfoT;
+
+inline const flatbuffers::TypeTable *OpInfoTypeTable();
+
+inline const flatbuffers::TypeTable *KVTypeTable();
+
+inline const flatbuffers::TypeTable *TrainInfoTypeTable();
+
+struct OpInfoT : public flatbuffers::NativeTable {
+  typedef OpInfo TableType;
+  std::string op;
+  std::string weight;
+  std::string bias;
+  OpInfoT() {
+  }
+};
+
+struct OpInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef OpInfoT NativeTableType;
+  static const flatbuffers::TypeTable *MiniReflectTypeTable() {
+    return OpInfoTypeTable();
+  }
+  const flatbuffers::String *op() const {
+    return GetPointer<const flatbuffers::String *>(4);
+  }
+  const flatbuffers::String *weight() const {
+    return GetPointer<const flatbuffers::String *>(6);
+  }
+  const flatbuffers::String *bias() const {
+    return GetPointer<const flatbuffers::String *>(8);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, 4) &&
+           verifier.VerifyString(op()) &&
+           VerifyOffset(verifier, 6) &&
+           verifier.VerifyString(weight()) &&
+           VerifyOffset(verifier, 8) &&
+           verifier.VerifyString(bias()) &&
+           verifier.EndTable();
+  }
+  OpInfoT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(OpInfoT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<OpInfo> Pack(flatbuffers::FlatBufferBuilder &_fbb, const OpInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct OpInfoBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_op(flatbuffers::Offset<flatbuffers::String> op) {
+    fbb_.AddOffset(4, op);
+  }
+  void add_weight(flatbuffers::Offset<flatbuffers::String> weight) {
+    fbb_.AddOffset(6, weight);
+  }
+  void add_bias(flatbuffers::Offset<flatbuffers::String> bias) {
+    fbb_.AddOffset(8, bias);
+  }
+  explicit OpInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  OpInfoBuilder &operator=(const OpInfoBuilder &);
+  flatbuffers::Offset<OpInfo> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<OpInfo>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<OpInfo> CreateOpInfo(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    flatbuffers::Offset<flatbuffers::String> op = 0,
+    flatbuffers::Offset<flatbuffers::String> weight = 0,
+    flatbuffers::Offset<flatbuffers::String> bias = 0) {
+  OpInfoBuilder builder_(_fbb);
+  builder_.add_bias(bias);
+  builder_.add_weight(weight);
+  builder_.add_op(op);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<OpInfo> CreateOpInfo(flatbuffers::FlatBufferBuilder &_fbb, const OpInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct KVT : public flatbuffers::NativeTable {
+  typedef KV TableType;
+  std::string key;
+  std::string value;
+  KVT() {
+  }
+};
+
+struct KV FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef KVT NativeTableType;
+  static const flatbuffers::TypeTable *MiniReflectTypeTable() {
+    return KVTypeTable();
+  }
+  const flatbuffers::String *key() const {
+    return GetPointer<const flatbuffers::String *>(4);
+  }
+  const flatbuffers::String *value() const {
+    return GetPointer<const flatbuffers::String *>(6);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, 4) &&
+           verifier.VerifyString(key()) &&
+           VerifyOffset(verifier, 6) &&
+           verifier.VerifyString(value()) &&
+           verifier.EndTable();
+  }
+  KVT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(KVT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<KV> Pack(flatbuffers::FlatBufferBuilder &_fbb, const KVT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct KVBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_key(flatbuffers::Offset<flatbuffers::String> key) {
+    fbb_.AddOffset(4, key);
+  }
+  void add_value(flatbuffers::Offset<flatbuffers::String> value) {
+    fbb_.AddOffset(6, value);
+  }
+  explicit KVBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  KVBuilder &operator=(const KVBuilder &);
+  flatbuffers::Offset<KV> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<KV>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<KV> CreateKV(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    flatbuffers::Offset<flatbuffers::String> key = 0,
+    flatbuffers::Offset<flatbuffers::String> value = 0) {
+  KVBuilder builder_(_fbb);
+  builder_.add_value(value);
+  builder_.add_key(key);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<KV> CreateKV(flatbuffers::FlatBufferBuilder &_fbb, const KVT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct TrainInfoT : public flatbuffers::NativeTable {
+  typedef TrainInfo TableType;
+  std::vector<std::unique_ptr<KVT>> trainables;
+  std::vector<std::unique_ptr<OpInfoT>> convolutions;
+  std::vector<std::unique_ptr<KVT>> batchnormal;
+  TrainInfoT() {
+  }
+};
+
+struct TrainInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef TrainInfoT NativeTableType;
+  static const flatbuffers::TypeTable *MiniReflectTypeTable() {
+    return TrainInfoTypeTable();
+  }
+  const flatbuffers::Vector<flatbuffers::Offset<KV>> *trainables() const {
+    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<KV>> *>(4);
+  }
+  const flatbuffers::Vector<flatbuffers::Offset<OpInfo>> *convolutions() const {
+    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<OpInfo>> *>(6);
+  }
+  const flatbuffers::Vector<flatbuffers::Offset<KV>> *batchnormal() const {
+    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<KV>> *>(8);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, 4) &&
+           verifier.VerifyVector(trainables()) &&
+           verifier.VerifyVectorOfTables(trainables()) &&
+           VerifyOffset(verifier, 6) &&
+           verifier.VerifyVector(convolutions()) &&
+           verifier.VerifyVectorOfTables(convolutions()) &&
+           VerifyOffset(verifier, 8) &&
+           verifier.VerifyVector(batchnormal()) &&
+           verifier.VerifyVectorOfTables(batchnormal()) &&
+           verifier.EndTable();
+  }
+  TrainInfoT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(TrainInfoT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<TrainInfo> Pack(flatbuffers::FlatBufferBuilder &_fbb, const TrainInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct TrainInfoBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_trainables(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<KV>>> trainables) {
+    fbb_.AddOffset(4, trainables);
+  }
+  void add_convolutions(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OpInfo>>> convolutions) {
+    fbb_.AddOffset(6, convolutions);
+  }
+  void add_batchnormal(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<KV>>> batchnormal) {
+    fbb_.AddOffset(8, batchnormal);
+  }
+  explicit TrainInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  TrainInfoBuilder &operator=(const TrainInfoBuilder &);
+  flatbuffers::Offset<TrainInfo> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<TrainInfo>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<TrainInfo> CreateTrainInfo(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<KV>>> trainables = 0,
+    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OpInfo>>> convolutions = 0,
+    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<KV>>> batchnormal = 0) {
+  TrainInfoBuilder builder_(_fbb);
+  builder_.add_batchnormal(batchnormal);
+  builder_.add_convolutions(convolutions);
+  builder_.add_trainables(trainables);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<TrainInfo> CreateTrainInfo(flatbuffers::FlatBufferBuilder &_fbb, const TrainInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+inline OpInfoT *OpInfo::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new OpInfoT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void OpInfo::UnPackTo(OpInfoT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = op(); if (_e) _o->op = _e->str(); };
+  { auto _e = weight(); if (_e) _o->weight = _e->str(); };
+  { auto _e = bias(); if (_e) _o->bias = _e->str(); };
+}
+
+inline flatbuffers::Offset<OpInfo> OpInfo::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OpInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateOpInfo(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<OpInfo> CreateOpInfo(flatbuffers::FlatBufferBuilder &_fbb, const OpInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OpInfoT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _op = _o->op.empty() ? 0 : _fbb.CreateString(_o->op);
+  auto _weight = _o->weight.empty() ? 0 : _fbb.CreateString(_o->weight);
+  auto _bias = _o->bias.empty() ? 0 : _fbb.CreateString(_o->bias);
+  return MNNTrain::CreateOpInfo(
+      _fbb,
+      _op,
+      _weight,
+      _bias);
+}
+
+inline KVT *KV::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new KVT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void KV::UnPackTo(KVT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = key(); if (_e) _o->key = _e->str(); };
+  { auto _e = value(); if (_e) _o->value = _e->str(); };
+}
+
+inline flatbuffers::Offset<KV> KV::Pack(flatbuffers::FlatBufferBuilder &_fbb, const KVT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateKV(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<KV> CreateKV(flatbuffers::FlatBufferBuilder &_fbb, const KVT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const KVT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _key = _o->key.empty() ? 0 : _fbb.CreateString(_o->key);
+  auto _value = _o->value.empty() ? 0 : _fbb.CreateString(_o->value);
+  return MNNTrain::CreateKV(
+      _fbb,
+      _key,
+      _value);
+}
+
+inline TrainInfoT *TrainInfo::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new TrainInfoT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void TrainInfo::UnPackTo(TrainInfoT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = trainables(); if (_e) { _o->trainables.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->trainables[_i] = std::unique_ptr<KVT>(_e->Get(_i)->UnPack(_resolver)); } } };
+  { auto _e = convolutions(); if (_e) { _o->convolutions.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->convolutions[_i] = std::unique_ptr<OpInfoT>(_e->Get(_i)->UnPack(_resolver)); } } };
+  { auto _e = batchnormal(); if (_e) { _o->batchnormal.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->batchnormal[_i] = std::unique_ptr<KVT>(_e->Get(_i)->UnPack(_resolver)); } } };
+}
+
+inline flatbuffers::Offset<TrainInfo> TrainInfo::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TrainInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateTrainInfo(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<TrainInfo> CreateTrainInfo(flatbuffers::FlatBufferBuilder &_fbb, const TrainInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TrainInfoT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _trainables = _o->trainables.size() ? _fbb.CreateVector<flatbuffers::Offset<KV>> (_o->trainables.size(), [](size_t i, _VectorArgs *__va) { return CreateKV(*__va->__fbb, __va->__o->trainables[i].get(), __va->__rehasher); }, &_va ) : 0;
+  auto _convolutions = _o->convolutions.size() ? _fbb.CreateVector<flatbuffers::Offset<OpInfo>> (_o->convolutions.size(), [](size_t i, _VectorArgs *__va) { return CreateOpInfo(*__va->__fbb, __va->__o->convolutions[i].get(), __va->__rehasher); }, &_va ) : 0;
+  auto _batchnormal = _o->batchnormal.size() ? _fbb.CreateVector<flatbuffers::Offset<KV>> (_o->batchnormal.size(), [](size_t i, _VectorArgs *__va) { return CreateKV(*__va->__fbb, __va->__o->batchnormal[i].get(), __va->__rehasher); }, &_va ) : 0;
+  return MNNTrain::CreateTrainInfo(
+      _fbb,
+      _trainables,
+      _convolutions,
+      _batchnormal);
+}
+
+inline const flatbuffers::TypeTable *OpInfoTypeTable() {
+  static const flatbuffers::TypeCode type_codes[] = {
+    { flatbuffers::ET_STRING, 0, -1 },
+    { flatbuffers::ET_STRING, 0, -1 },
+    { flatbuffers::ET_STRING, 0, -1 }
+  };
+  static const char * const names[] = {
+    "op",
+    "weight",
+    "bias"
+  };
+  static const flatbuffers::TypeTable tt = {
+    flatbuffers::ST_TABLE, 3, type_codes, nullptr, nullptr, names
+  };
+  return &tt;
+}
+
+inline const flatbuffers::TypeTable *KVTypeTable() {
+  static const flatbuffers::TypeCode type_codes[] = {
+    { flatbuffers::ET_STRING, 0, -1 },
+    { flatbuffers::ET_STRING, 0, -1 }
+  };
+  static const char * const names[] = {
+    "key",
+    "value"
+  };
+  static const flatbuffers::TypeTable tt = {
+    flatbuffers::ST_TABLE, 2, type_codes, nullptr, nullptr, names
+  };
+  return &tt;
+}
+
+inline const flatbuffers::TypeTable *TrainInfoTypeTable() {
+  static const flatbuffers::TypeCode type_codes[] = {
+    { flatbuffers::ET_SEQUENCE, 1, 0 },
+    { flatbuffers::ET_SEQUENCE, 1, 1 },
+    { flatbuffers::ET_SEQUENCE, 1, 0 }
+  };
+  static const flatbuffers::TypeFunction type_refs[] = {
+    KVTypeTable,
+    OpInfoTypeTable
+  };
+  static const char * const names[] = {
+    "trainables",
+    "convolutions",
+    "batchnormal"
+  };
+  static const flatbuffers::TypeTable tt = {
+    flatbuffers::ST_TABLE, 3, type_codes, type_refs, nullptr, names
+  };
+  return &tt;
+}
+
+}  // namespace MNNTrain
+
+#endif  // FLATBUFFERS_GENERATED_TRAININFO_MNNTRAIN_H_

+ 32 - 11
schema/current/UserDefine_generated.h

@@ -59,15 +59,17 @@ enum BorderMode {
   BorderMode_ZEROS = 0,
   BorderMode_CLAMP = 1,
   BorderMode_REFLECTION = 2,
+  BorderMode_CUBE = 3,
   BorderMode_MIN = BorderMode_ZEROS,
-  BorderMode_MAX = BorderMode_REFLECTION
+  BorderMode_MAX = BorderMode_CUBE
 };
 
-inline const BorderMode (&EnumValuesBorderMode())[3] {
+inline const BorderMode (&EnumValuesBorderMode())[4] {
   static const BorderMode values[] = {
     BorderMode_ZEROS,
     BorderMode_CLAMP,
-    BorderMode_REFLECTION
+    BorderMode_REFLECTION,
+    BorderMode_CUBE
   };
   return values;
 }
@@ -77,13 +79,14 @@ inline const char * const *EnumNamesBorderMode() {
     "ZEROS",
     "CLAMP",
     "REFLECTION",
+    "CUBE",
     nullptr
   };
   return names;
 }
 
 inline const char *EnumNameBorderMode(BorderMode e) {
-  if (e < BorderMode_ZEROS || e > BorderMode_REFLECTION) return "";
+  if (e < BorderMode_ZEROS || e > BorderMode_CUBE) return "";
   const size_t index = static_cast<int>(e);
   return EnumNamesBorderMode()[index];
 }
@@ -293,10 +296,12 @@ struct GridSampleT : public flatbuffers::NativeTable {
   SampleMode mode;
   BorderMode paddingMode;
   bool alignCorners;
+  bool backward;
   GridSampleT()
       : mode(SampleMode_BILINEAR),
         paddingMode(BorderMode_ZEROS),
-        alignCorners(false) {
+        alignCorners(false),
+        backward(false) {
   }
 };
 
@@ -314,11 +319,15 @@ struct GridSample FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   bool alignCorners() const {
     return GetField<uint8_t>(8, 0) != 0;
   }
+  bool backward() const {
+    return GetField<uint8_t>(10, 0) != 0;
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyField<int8_t>(verifier, 4) &&
            VerifyField<int8_t>(verifier, 6) &&
            VerifyField<uint8_t>(verifier, 8) &&
+           VerifyField<uint8_t>(verifier, 10) &&
            verifier.EndTable();
   }
   GridSampleT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -338,6 +347,9 @@ struct GridSampleBuilder {
   void add_alignCorners(bool alignCorners) {
     fbb_.AddElement<uint8_t>(8, static_cast<uint8_t>(alignCorners), 0);
   }
+  void add_backward(bool backward) {
+    fbb_.AddElement<uint8_t>(10, static_cast<uint8_t>(backward), 0);
+  }
   explicit GridSampleBuilder(flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -354,8 +366,10 @@ inline flatbuffers::Offset<GridSample> CreateGridSample(
     flatbuffers::FlatBufferBuilder &_fbb,
     SampleMode mode = SampleMode_BILINEAR,
     BorderMode paddingMode = BorderMode_ZEROS,
-    bool alignCorners = false) {
+    bool alignCorners = false,
+    bool backward = false) {
   GridSampleBuilder builder_(_fbb);
+  builder_.add_backward(backward);
   builder_.add_alignCorners(alignCorners);
   builder_.add_paddingMode(paddingMode);
   builder_.add_mode(mode);
@@ -569,6 +583,7 @@ inline void GridSample::UnPackTo(GridSampleT *_o, const flatbuffers::resolver_fu
   { auto _e = mode(); _o->mode = _e; };
   { auto _e = paddingMode(); _o->paddingMode = _e; };
   { auto _e = alignCorners(); _o->alignCorners = _e; };
+  { auto _e = backward(); _o->backward = _e; };
 }
 
 inline flatbuffers::Offset<GridSample> GridSample::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GridSampleT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -582,11 +597,13 @@ inline flatbuffers::Offset<GridSample> CreateGridSample(flatbuffers::FlatBufferB
   auto _mode = _o->mode;
   auto _paddingMode = _o->paddingMode;
   auto _alignCorners = _o->alignCorners;
+  auto _backward = _o->backward;
   return MNN::CreateGridSample(
       _fbb,
       _mode,
       _paddingMode,
-      _alignCorners);
+      _alignCorners,
+      _backward);
 }
 
 inline ImageProcessParamT *ImageProcessParam::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -667,6 +684,7 @@ inline const flatbuffers::TypeTable *BorderModeTypeTable() {
   static const flatbuffers::TypeCode type_codes[] = {
     { flatbuffers::ET_CHAR, 0, 0 },
     { flatbuffers::ET_CHAR, 0, 0 },
+    { flatbuffers::ET_CHAR, 0, 0 },
     { flatbuffers::ET_CHAR, 0, 0 }
   };
   static const flatbuffers::TypeFunction type_refs[] = {
@@ -675,10 +693,11 @@ inline const flatbuffers::TypeTable *BorderModeTypeTable() {
   static const char * const names[] = {
     "ZEROS",
     "CLAMP",
-    "REFLECTION"
+    "REFLECTION",
+    "CUBE"
   };
   static const flatbuffers::TypeTable tt = {
-    flatbuffers::ST_ENUM, 3, type_codes, type_refs, nullptr, names
+    flatbuffers::ST_ENUM, 4, type_codes, type_refs, nullptr, names
   };
   return &tt;
 }
@@ -789,6 +808,7 @@ inline const flatbuffers::TypeTable *GridSampleTypeTable() {
   static const flatbuffers::TypeCode type_codes[] = {
     { flatbuffers::ET_CHAR, 0, 0 },
     { flatbuffers::ET_CHAR, 0, 1 },
+    { flatbuffers::ET_BOOL, 0, -1 },
     { flatbuffers::ET_BOOL, 0, -1 }
   };
   static const flatbuffers::TypeFunction type_refs[] = {
@@ -798,10 +818,11 @@ inline const flatbuffers::TypeTable *GridSampleTypeTable() {
   static const char * const names[] = {
     "mode",
     "paddingMode",
-    "alignCorners"
+    "alignCorners",
+    "backward"
   };
   static const flatbuffers::TypeTable tt = {
-    flatbuffers::ST_TABLE, 3, type_codes, type_refs, nullptr, names
+    flatbuffers::ST_TABLE, 4, type_codes, type_refs, nullptr, names
   };
   return &tt;
 }

+ 14 - 0
schema/default/CaffeOp.fbs

@@ -247,6 +247,20 @@ table Scale {
     external:[int64]; // [offset, scaleData_bytes_size, biasData_bytes_size]
 }
 
+table QuantizeLinear {
+    scaleSize: int;
+    scaleAxis: int;
+    scaleData:[float];
+    zeroPointData:[byte];
+}
+
+table DequantizeLinear {
+    scaleSize: int;
+    scaleAxis: int;
+    scaleData:[float];
+    zeroPointData:[byte];
+}
+
 enum EltwiseType : byte {
     PROD = 0,
     SUM = 1,

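The QuantizeLinear/DequantizeLinear tables above only store the quantization parameters (a scale vector of scaleSize entries along scaleAxis, plus int8 zero points); the arithmetic itself follows the usual ONNX-style linear quantization. A small Python reference of that element-wise formula, written as an assumption about the intended semantics rather than a transcription of MNN's kernel:

def quantize_linear(x, scale, zero_point=0):
    """q = clamp(round(x / scale) + zero_point, -128, 127)."""
    q = round(x / scale) + zero_point
    return max(-128, min(127, q))

def dequantize_linear(q, scale, zero_point=0):
    """x ~= (q - zero_point) * scale."""
    return (q - zero_point) * scale

# Per-axis variants broadcast scaleData/zeroPointData along scaleAxis.
print(quantize_linear(0.5, 0.01))      # 50
print(dequantize_linear(50, 0.01))     # 0.5
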
+ 8 - 4
schema/default/MNN.fbs

@@ -72,10 +72,10 @@ enum OpType : int {
     QuantizedConcat,
     QuantizedDepthwiseConv2D,
     QuantizedLogistic,
-    QuantizedMatMul,
+    RasterAndInterpolate,
     QuantizedMaxPool,
-    QuantizedRelu,
-    QuantizedRelu6,
+    Texture,
+    RasterDiff,
     QuantizedReshape,
     QuantizedSoftmax,
     QuantizeMaxMin,
@@ -167,6 +167,8 @@ enum OpType : int {
     GatherElements = 152,
     Svd = 153,
     Histogram = 154,
+    QuantizeLinear = 155,
+    DequantizeLinear = 156,
 
     Plugin = 256, //The Type load from plugin
     //Training Op Start from 257
@@ -389,7 +391,9 @@ union OpParameter {
     GridSample,
     LoopParam,
     ImageProcessParam,
-    CumSum
+    CumSum,
+    QuantizeLinear,
+    DequantizeLinear,
 }
 
 table Op {

+ 18 - 0
schema/default/TrainInfo.fbs

@@ -0,0 +1,18 @@
+namespace MNNTrain;
+
+table OpInfo {
+    op:string;
+    weight:string;
+    bias:string;
+}
+
+table KV {
+    key:string;
+    value:string;
+}
+
+table TrainInfo {
+    trainables:[KV];
+    convolutions:[OpInfo];
+    batchnormal:[KV];
+}

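TrainInfo is a small sidecar table: key/value pairs marking trainable tensors, per-convolution weight/bias tensor names, and batch-norm bookkeeping. A purely illustrative Python mirror of the structure; the concrete keys and values are placeholders, since the schema itself does not fix their semantics:

train_info = {
    "trainables":   [{"key": "conv1_weight", "value": "true"}],                          # KV pairs
    "convolutions": [{"op": "conv1", "weight": "conv1_weight", "bias": "conv1_bias"}],   # OpInfo
    "batchnormal":  [{"key": "bn1", "value": "bn1_running_stats"}],                      # KV pairs
}
print(train_info["convolutions"][0]["weight"])
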
+ 3 - 1
schema/default/UserDefine.fbs

@@ -12,13 +12,15 @@ enum SampleMode : byte {
 enum BorderMode : byte {
     ZEROS=0,
     CLAMP,
-    REFLECTION
+    REFLECTION,
+    CUBE
 }
 
 table GridSample {
     mode:SampleMode;
     paddingMode:BorderMode;
     alignCorners:bool=false;
+    backward:bool=false;
 }
 
 enum ImageFormatType : int {

+ 6 - 6
source/backend/arm82/Arm82Binary.cpp

@@ -199,37 +199,37 @@ void Arm82Binary(void *dstRaw, const void *src0Raw, const void *src1Raw, const i
 }
 
 
-struct VecBinaryAdd : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
+struct VecBinaryAdd {
     float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
         return vaddq_f16(x, y);
     }
 };
 
-struct VecBinarySub : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
+struct VecBinarySub {
     float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
         return vsubq_f16(x, y);
     }
 };
 
-struct VecBinaryMul : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
+struct VecBinaryMul {
     float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
         return vmulq_f16(x, y);
     }
 };
 
-struct VecBinaryMin : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
+struct VecBinaryMin {
     float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
         return vminq_f16(x, y);
     }
 };
 
-struct VecBinaryMax : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
+struct VecBinaryMax {
     float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
         return vmaxq_f16(x, y);
     }
 };
 
-struct VecBinarySqd : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
+struct VecBinarySqd {
     float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
         return vmulq_f16(vsubq_f16(x, y), vsubq_f16(x, y));
     }

+ 33 - 66
source/backend/arm82/Arm82Functions.cpp

@@ -10,6 +10,13 @@
 #include "Arm82Relu.hpp"
 #include "backend/cpu/compute/CommonOptFunction.h"
 #include "backend/cpu/CPUPool.hpp"
+#include "backend/cpu/CPURuntime.hpp"
+
+#define FLOAT FLOAT16
+#define PACK 8
+using Vec = MNN::Math::Vec<FLOAT16, 8>;
+
+#include "backend/cpu/GridSampler.hpp"
 
 #if defined(MNN_USE_NEON)
 #include <arm_neon.h>
@@ -33,6 +40,13 @@ void MNNPackedMatMulFP16_int4(float* C, const float* A, const float* B, const si
 void MNNPackedMatMulRemainFP16_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
 void MNNPackedMatMulFP16_int8(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
 void MNNPackedMatMulRemainFP16_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
+
+void MNNAbsMaxFP16(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
+void MNNQuantScaleFP16(float* sum, float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch);
+void MNNDynamicQuantFP16(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack);
+void MNNQuantSumFP16(float* sum, const float* dequant_scale, size_t thread, size_t batch);
+void MNNGemmHybridInt8FP16_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param);
+void MNNGemmHybridInt4FP16_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param);
 #endif
 
 void MNNConvDwF23MulTransUnitFP16(FLOAT16 **cacheLine, const FLOAT16 *weight, FLOAT16 *dest, size_t ow);
@@ -43,7 +57,6 @@ void MNNConvRunForLineDepthwiseFP16(float* dst, const float* src, const float* w
                                 size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height, size_t srcHStep, size_t dstHStep);
 }
 
-using Vec = MNN::Math::Vec<FLOAT16, 8>;
 
 namespace MNN {
 
@@ -153,70 +166,6 @@ static void MNNGridSampleComputeCordFP16(FLOAT16* dst, const FLOAT16* src, size_
     ::memcpy(dst, tempDst, areaRemain * 2 * sizeof(int16_t));
 }
 
-static size_t MNNGridSampleComputeOffsetFP16(int h, int w, int height, int width, bool padMode) {
-    if (padMode == true) { //padMode == BorderMode_ZEROS
-        if (h < 0 || h >= height || w < 0 || w >= width) {
-            return -1;
-        }
-    } else {
-        // Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER
-        // For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1),
-        // the leftover reflections degrade to GridSamplePaddingMode_BORDER
-        h = h < 0 ? 0 : (h > (height - 1) ? (height - 1) : h);
-        w = w < 0 ? 0 : (w > (width - 1) ? (width - 1) : w);
-    }
-    return h * width * 8 + w * 8;
-}
-
-static void MNNGridSampleInterpFP16(FLOAT16* outputPtr, const FLOAT16* inputPtr, const FLOAT16* cordPtr, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode) {
-    for (auto ow = 0; ow < outW; ++ow) {
-        auto w_fp16 = cordPtr[2 * ow + 0];
-        auto h_fp16 = cordPtr[2 * ow + 1];
-        float w = (float)(w_fp16);
-        float h = (float)(h_fp16);
-        Vec interp;
-
-        if (sampleMode == true) { //sampleMode == SampleMode_NEAREST
-            int nh = ::floor(h + 0.5f);
-            int nw = ::floor(w + 0.5f);
-            size_t ns = MNNGridSampleComputeOffsetFP16(nh, nw, inH, inW, padMode);
-            for (int k = 0; k < channelCUnit; ++k) {
-                interp = ns == -1 ? Vec(0.f) : Vec::load(inputPtr + k * inOffset + ns);
-                Vec::save(outputPtr + k * outOffset + 8 * ow, interp);
-            }
-        } else { //sampleMode == GridSampleMode_BILINEAR
-            int w0_h = ::floor(h);
-            int w0_w = ::floor(w);
-            int w1_h = ::ceil(h);
-            int w1_w = ::ceil(w);
-            auto oneV = Vec((FLOAT16)1);
-
-            auto f0 = Vec((FLOAT16)w1_w - w_fp16);
-            auto f1 = oneV - f0;
-            auto h0 = Vec((FLOAT16)w1_h - h_fp16);
-            auto h1 = oneV - h0;
-
-            size_t s00 = MNNGridSampleComputeOffsetFP16(w0_h, w0_w, inH, inW, padMode);
-            size_t s01 = MNNGridSampleComputeOffsetFP16(w0_h, w1_w, inH, inW, padMode);
-            size_t s10 = MNNGridSampleComputeOffsetFP16(w1_h, w0_w, inH, inW, padMode);
-            size_t s11 = MNNGridSampleComputeOffsetFP16(w1_h, w1_w, inH, inW, padMode);
-
-            for (int k = 0; k < channelCUnit; ++k) {
-                Vec i00 = s00 == -1 ? Vec(0.f) : Vec::load(inputPtr + k * inOffset + s00);
-                Vec i01 = s01 == -1 ? Vec(0.f) : Vec::load(inputPtr + k * inOffset + s01);
-                Vec i10 = s10 == -1 ? Vec(0.f) : Vec::load(inputPtr + k * inOffset + s10);
-                Vec i11 = s11 == -1 ? Vec(0.f) : Vec::load(inputPtr + k * inOffset + s11);
-
-                Vec i0 = i00 * f0 + i01 * f1;
-                Vec i1 = i10 * f0 + i11 * f1;
-
-                interp = i0 * h0 + i1 * h1;
-                Vec::save(outputPtr + k * outOffset + 8 * ow, interp);
-            }
-        }
-    }
-}
-
 static void MNNRoiPoolingMaxFP16(FLOAT16* dst, const FLOAT16* src, int hLen, int wLen, int iw) {
     Vec max = Vec(-65504.0f);
     for (int h = 0; h < hLen; h++, src += iw * 8) {
@@ -697,7 +646,8 @@ bool Arm82Functions::init() {
     gInstance->penalty = 2.0f;
     FUNC_PTR_ASSIGN(gInstance->MNNScaleAndAddBias, MNNScaleAndAddBiasFP16);
     FUNC_PTR_ASSIGN(gInstance->MNNGridSampleComputeCord, MNNGridSampleComputeCordFP16);
-    FUNC_PTR_ASSIGN(gInstance->MNNGridSampleInterp, MNNGridSampleInterpFP16);
+    FUNC_PTR_ASSIGN(gInstance->MNNGridSampleInterp, MNNGridSampleInterp);
+    FUNC_PTR_ASSIGN(gInstance->MNNGridSampleInterpGrad, MNNGridSampleInterpGrad);
     FUNC_PTR_ASSIGN(gInstance->MNNRoiPoolingMax, MNNRoiPoolingMaxFP16);
     FUNC_PTR_ASSIGN(gInstance->MNNRoiAlignMax, MNNRoiAlignMaxFP16);
     FUNC_PTR_ASSIGN(gInstance->MNNRoiAlignAvg, MNNRoiAlignAvgFP16);
@@ -712,6 +662,23 @@ bool Arm82Functions::init() {
     FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain_int4, MNNPackedMatMulRemainFP16_int4);
     FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul_int8, MNNPackedMatMulFP16_int8);
     FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain_int8, MNNPackedMatMulRemainFP16_int8);
+    FUNC_PTR_ASSIGN(gInstance->MNNAbsMax, MNNAbsMaxFP16);
+    FUNC_PTR_ASSIGN(gInstance->MNNQuantScale, MNNQuantScaleFP16);
+    FUNC_PTR_ASSIGN(gInstance->MNNDynamicQuant, MNNDynamicQuantFP16);
+    FUNC_PTR_ASSIGN(gInstance->MNNQuantSum, MNNQuantSumFP16);
+    cpuinfo_arm_isa gCPUInfo;
+    cpuinfo_arm_init(&gCPUInfo);
+    gInstance->supportFp16arith = gCPUInfo.fp16arith;
+    gInstance->supportSDot = gCPUInfo.dot;
+    gInstance->supportI8mm = gCPUInfo.i8mm;
+    if (gInstance->supportSDot) {
+        gInstance->MNNGemmHybridInt8 = MNNGemmHybridInt8FP16_sdot;
+        gInstance->MNNGemmHybridInt4 = MNNGemmHybridInt4FP16_sdot;
+    }
+    if (gInstance->supportI8mm) {
+        gInstance->MNNGemmHybridInt8 = MNNGemmHybridInt8FP16_smmla;
+        gInstance->MNNGemmHybridInt4 = MNNGemmHybridInt4FP16_smmla;
+    }
 #endif
     FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A);
     FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Arm82MNNGetMatMulPackMode);

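The Arm82Functions.cpp hunk registers the new low-memory kernels (MNNAbsMax/MNNQuantScale/MNNDynamicQuant/MNNQuantSum) and then selects the hybrid GEMM pair from runtime CPU features, with I8MM taking precedence over SDOT because it is assigned last. A small Python mirror of that selection order; the function names are the symbols from the hunk, and the boolean flags stand in for the cpuinfo_arm_isa fields:

def pick_hybrid_gemm(support_sdot, support_i8mm):
    """Mirrors the assignment order in Arm82Functions::init(): i8mm overrides sdot."""
    gemm_int8 = gemm_int4 = None
    if support_sdot:
        gemm_int8, gemm_int4 = "MNNGemmHybridInt8FP16_sdot", "MNNGemmHybridInt4FP16_sdot"
    if support_i8mm:
        gemm_int8, gemm_int4 = "MNNGemmHybridInt8FP16_smmla", "MNNGemmHybridInt4FP16_smmla"
    return gemm_int8, gemm_int4

print(pick_hybrid_gemm(True, False))   # sdot pair
print(pick_hybrid_gemm(True, True))    # smmla pair wins
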
+ 4 - 2
source/backend/arm82/asm/arm32/MNNGeluFP16.S

@@ -62,8 +62,10 @@ vadd.f16 q3, q3, q1
 vmul.f16 q2, q2, q14
 vmul.f16 q3, q3, q14
 
-vmov.f16 q4, #5.0
-vmov.f16 q5, #-5.0
+mov lr, #5.0
+vdup.16 q4, lr
+mov lr, #-5.0
+vdup.16 q5, lr
 vmax.f16 q2, q2, q5
 vmin.f16 q2, q2, q4
 vmax.f16 q3, q3, q5

+ 5 - 2
source/backend/arm82/asm/arm64/MNNGeluFP16.S

@@ -45,6 +45,9 @@ dup v10.8h, w9        // v10: [28.f]x4
 dup v9.8h, w10        // v9: [3150.f]x4
 dup v8.8h, w11        // v8: [62370.f]x4
 
+mov w4, #5.0
+mov w5, #-5.0
+
 GeluZLoop:
 
 ld1 {v0.8h, v1.8h}, [x1], #32   // v0, v1: fp32x4
@@ -62,8 +65,8 @@ fadd v3.8h, v3.8h, v1.8h
 fmul v2.8h, v2.8h, v14.8h
 fmul v3.8h, v3.8h, v14.8h
 
-fmov v6.8h, #-5
-fmov v7.8h, #5
+dup v6.8h, w5
+dup v7.8h, w4
 fmin v2.8h, v2.8h, v7.8h
 fmin v3.8h, v3.8h, v7.8h
 fmax v2.8h, v2.8h, v6.8h

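The two MNNGeluFP16.S hunks above replace vector float-immediate moves for ±5.0 with a GPR move plus dup; that ±5 bound is the clamp applied to the tanh argument in MNN's polynomial GELU approximation. A Python reference of that approximation for comparison: the 28/3150/62370 constants are visible in the surrounding asm, while 135135/17325/378 complete the standard rational tanh form and are assumed here:

import math

def gelu_fp16_reference(x):
    # value = sqrt(2/pi) * (x + 0.044715 * x^3), clamped to [-5, 5] as in the asm
    value = 0.79788458 * (x + 0.044715 * x * x * x)
    value = max(-5.0, min(5.0, value))
    x2 = value * value
    a = value * (135135.0 + x2 * (17325.0 + x2 * (378.0 + x2)))
    b = 135135.0 + x2 * (62370.0 + x2 * (3150.0 + 28.0 * x2))
    return 0.5 * x * (1.0 + a / b)

print(gelu_fp16_reference(1.0))                                   # ~0.8412
print(0.5 * (1.0 + math.tanh(0.79788458 * (1.0 + 0.044715))))     # exact tanh form
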
+ 247 - 0
source/backend/arm82/asm/arm64/low_memory/MNNAbsMaxFP16.S

@@ -0,0 +1,247 @@
+//
+//  MNNAbsMaxFP16.S
+//  MNN
+//
+//  Created by MNN on 2023/10/31.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+.text
+.align 5
+
+.macro Add d0, d1, d2, d3, z0, z1, z2, z3
+    fadd \d0\().8h, \d0\().8h, \z0\().8h
+    fadd \d1\().8h, \d1\().8h, \z1\().8h
+    fadd \d2\().8h, \d2\().8h, \z2\().8h
+    fadd \d3\().8h, \d3\().8h, \z3\().8h
+.endm
+
+.macro Abs z0, z1, z2, z3
+    fabs \z0\().8h, \z0\().8h
+    fabs \z1\().8h, \z1\().8h
+    fabs \z2\().8h, \z2\().8h
+    fabs \z3\().8h, \z3\().8h
+.endm
+
+.macro Max d0, d1, d2, d3, z0, z1, z2, z3
+    fmax \d0\().8h, \d0\().8h, \z0\().8h
+    fmax \d1\().8h, \d1\().8h, \z1\().8h
+    fmax \d2\().8h, \d2\().8h, \z2\().8h
+    fmax \d3\().8h, \d3\().8h, \z3\().8h
+.endm
+
+.macro ReduceSum s0, s1, s2, s3, zero
+    faddp \s0\().8h, \s0\().8h, \s1\().8h // 0 0 0 0 1 1 1 1
+    faddp \s2\().8h, \s2\().8h, \s3\().8h // 2 2 2 2 3 3 3 3
+    faddp \s0\().8h, \s0\().8h, \s2\().8h // 0 0 1 1 2 2 3 3
+    faddp \s0\().8h, \s0\().8h, \zero\().8h // 0 1 2 3
+.endm
+
+.macro ReduceMax s0, s1, s2, s3, zero
+    fmaxp \s0\().8h, \s0\().8h, \s1\().8h // 0 0 0 0 1 1 1 1
+    fmaxp \s2\().8h, \s2\().8h, \s3\().8h // 2 2 2 2 3 3 3 3
+    fmaxp \s0\().8h, \s0\().8h, \s2\().8h // 0 0 1 1 2 2 3 3
+    fmaxp \s0\().8h, \s0\().8h, \zero\().8h // 0 1 2 3
+.endm
+
+//void MNNAbsMaxFP16(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack)
+asm_function MNNAbsMaxFP16
+
+// x0: source, x1:absmax, x2:src_depth_quad, x3:realSize, x4: pack(no used)
+stp d14, d15, [sp, #(-16 * 4)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+
+Start:
+lsl x6, x3, #4 // src_step = batch * 8 * sizeof(float16_t) = batch << 4
+
+TILE_12:
+cmp x3, #12
+blt TILE_10
+mov x5, x2  // src_depth_quad
+mov x7, x0  // src
+sub x8, x6, #128 // src_step
+
+// absmax: v0-11
+ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x7], #64
+ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x7], #64
+ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x7], x8
+Abs v0, v1, v2, v3
+Abs v4, v5, v6, v7
+Abs v8, v9, v10, v11
+subs x5, x5, #1
+beq Tile12End
+
+LoopSz_12:
+ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x7], #64
+ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x7], #64
+ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x7], x8
+// absmax = fmax(absmax, abs(x))
+Abs v12, v13, v14, v15
+Abs v16, v17, v18, v19
+Abs v20, v21, v22, v23
+Max v0, v1, v2, v3, v12, v13, v14, v15
+Max v4, v5, v6, v7, v16, v17, v18, v19
+Max v8, v9, v10, v11, v20, v21, v22, v23
+
+subs x5, x5, #1
+bne LoopSz_12
+
+Tile12End:
+movi v28.8h, #0
+scvtf v28.8h, v28.8h
+
+ReduceMax v0, v1, v2, v3, v28
+ReduceMax v4, v5, v6, v7, v28
+ReduceMax v8, v9, v10, v11, v28
+mov v0.d[1], v4.d[0]
+st1 {v0.8h}, [x1], #16
+st1 {v8.d}[0], [x1], #8
+
+sub x3, x3, #12
+add x0, x0, #192 // src += 12 * 8 * 2
+b TILE_12
+
+TILE_10:
+cmp x3, #10
+blt TILE_8
+mov x5, x2  // src_depth_quad
+mov x7, x0  // src
+sub x8, x6, #128 // src_step
+
+// absmax: v0-9
+ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x7], #64
+ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x7], #64
+ld1 {v8.8h, v9.8h}, [x7], x8
+Abs v0, v1, v2, v3
+Abs v4, v5, v6, v7
+fabs v8.8h, v8.8h
+fabs v9.8h, v9.8h
+
+subs x5, x5, #1
+beq Tile10End
+
+LoopSz_10:
+ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x7], #64
+ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x7], #64
+ld1 {v28.8h, v29.8h}, [x7], x8
+
+// absmax = fmax(absmax, abs(x))
+Abs v20, v21, v22, v23
+Abs v24, v25, v26, v27
+fabs v28.8h, v28.8h
+fabs v29.8h, v29.8h
+
+Max v0, v1, v2, v3, v20, v21, v22, v23
+Max v4, v5, v6, v7, v24, v25, v26, v27
+fmax v8.8h, v8.8h, v28.8h
+fmax v9.8h, v9.8h, v29.8h
+
+subs x5, x5, #1
+bne LoopSz_10
+
+Tile10End:
+movi v24.8h, #0
+scvtf v24.8h, v24.8h
+fmaxp v0.8h, v0.8h, v1.8h
+fmaxp v0.8h, v0.8h, v24.8h
+fmaxp v0.8h, v0.8h, v24.8h
+st1 {v0.s}[0], [x1], #4
+ReduceMax v2, v3, v4, v5, v24
+ReduceMax v6, v7, v8, v9, v24
+mov v2.d[1], v6.d[0]
+st1 {v2.8h}, [x1], #16
+
+sub x3, x3, #10
+add x0, x0, #160 // src += 10 * 8 * 2
+b TILE_10
+
+TILE_8:
+cmp x3, #8
+blt TILE_1
+mov x5, x2  // src_depth_quad
+mov x7, x0  // src
+sub x8, x6, #64 // src_step
+
+// absmax: v0-7
+ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x7], #64
+ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x7], x8
+
+Abs v0, v1, v2, v3
+Abs v4, v5, v6, v7
+
+subs x5, x5, #1
+beq Tile8End
+
+LoopSz_8:
+ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x7], #64
+ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x7], x8
+
+// absmax = fmax(absmax, abs(x))
+Abs v16, v17, v18, v19
+Abs v20, v21, v22, v23
+Max v0, v1, v2, v3, v16, v17, v18, v19
+Max v4, v5, v6, v7, v20, v21, v22, v23
+
+subs x5, x5, #1
+bne LoopSz_8
+
+Tile8End:
+movi v24.8h, #0
+scvtf v24.8h, v24.8h
+ReduceMax v0, v1, v2, v3, v24
+ReduceMax v4, v5, v6, v7, v24
+
+mov v0.d[1], v4.d[0]
+st1 {v0.8h}, [x1], #16
+sub x3, x3, #8
+add x0, x0, #128 // src += 8 * 8 * 2
+b TILE_8
+
+TILE_1:
+cmp x3, #1
+blt End
+movi v17.8h, #0
+scvtf v17.8h, v17.8h
+mov x5, x2  // src_depth_quad
+mov x7, x0  // src
+
+// absmax: v0
+ld1 {v0.8h}, [x7], x6
+fabs v0.8h, v0.8h
+subs x5, x5, #1
+beq Tile1End
+
+LoopSz_1:
+ld1 {v16.8h}, [x7], x6
+
+// absmax = fmax(absmax, abs(x))
+fabs v16.8h, v16.8h
+fmax v0.8h, v0.8h, v16.8h
+
+subs x5, x5, #1
+bne LoopSz_1
+
+Tile1End:
+
+fmaxp v2.8h, v0.8h, v17.8h
+fmaxp v0.8h, v2.8h, v17.8h
+fmaxp v2.8h, v0.8h, v17.8h
+st1 {v2.h}[0], [x1], #2
+
+sub x3, x3, #1
+add x0, x0, #16 // src += 1 * 8 * 2
+b TILE_1
+
+End:
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 4)
+ret
+
+#endif
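
For reference, every tile above computes the same reduction: the absolute maximum of each batch element's values across all src_depth_quad packs of 8 half-precision lanes, written out as one value per element. A minimal scalar sketch, assuming a [src_depth_quad][realSize][8] packing and using float in place of float16_t; the function and parameter names are illustrative, not the project's C fallback:

```cpp
#include <cmath>
#include <cstddef>

// Scalar sketch of the absolute-maximum reduction performed by the tiles above.
// Layout assumption: src is packed as [src_depth_quad][realSize][8] values;
// plain float stands in for float16_t to keep the sketch portable.
static void AbsMaxSketch(const float* src, float* absmax,
                         size_t src_depth_quad, size_t realSize) {
    const size_t pack = 8; // one 8-lane group per batch element
    for (size_t i = 0; i < realSize; ++i) {
        float maxVal = 0.0f;
        for (size_t k = 0; k < src_depth_quad; ++k) {
            const float* x = src + (k * realSize + i) * pack;
            for (size_t j = 0; j < pack; ++j) {
                maxVal = std::fmax(maxVal, std::fabs(x[j]));
            }
        }
        absmax[i] = maxVal; // one scalar per batch element
    }
}
```

The 12/10/8/1 tiles differ only in how many batch elements are kept live in registers per pass; the arithmetic per element is identical.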

+ 393 - 0
source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantFP16.S

@@ -0,0 +1,393 @@
+//
+//  MNNDynamicQuantFP16.S
+//  MNN
+//
+//  Created by MNN on 2023/10/31.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+.text
+.align 5
+
+.macro Round z0, z1, z2, z3
+    fcvtas \z0\().8h, \z0\().8h
+    fcvtas \z1\().8h, \z1\().8h
+    fcvtas \z2\().8h, \z2\().8h
+    fcvtas \z3\().8h, \z3\().8h
+.endm
+
+//void MNNDynamicQuantFP16(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize)
+asm_function MNNDynamicQuantFP16
+
+// x0: src, x1:dst, x2:scale, x3:sum, x4:src_depth_quad, x5:realSize
+stp d14, d15, [sp, #(-16 * 4)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+
+Start:
+lsl x6, x5, #3  // dst_step = batch * unit * sizeof(int8_t) = batch * 8 = batch << 3
+lsl x7, x6, #1  // src_step = dst_step * 2 (float16_t) = dst_step << 1
+
+movi v29.16b, #1
+
+TILE_12:
+cmp x5, #12
+blt TILE_10
+mov x9, x0   // src
+mov x10, x1  // dst
+mov x12, x4  // src_depth_quad
+sub x13, x7, #128 // src_step - 128
+sub x14, x6, #64 // dst_step - 64
+
+// quant_scale: v12, v13
+ld1 {v12.8h}, [x2], #16
+ld1 {v13.d}[0], [x2], #8
+movi v23.4s, #0
+movi v24.4s, #0
+movi v25.4s, #0
+movi v26.4s, #0
+movi v27.4s, #0
+movi v28.4s, #0
+
+LoopSz_12:
+ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x9], #64
+ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x9], #64
+ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x9], x13
+
+// float16_t x = x * quant_scale
+fmul v0.8h, v0.8h, v12.h[0]
+fmul v1.8h, v1.8h, v12.h[1]
+fmul v2.8h, v2.8h, v12.h[2]
+fmul v3.8h, v3.8h, v12.h[3]
+fmul v4.8h, v4.8h, v12.h[4]
+fmul v5.8h, v5.8h, v12.h[5]
+fmul v6.8h, v6.8h, v12.h[6]
+fmul v7.8h, v7.8h, v12.h[7]
+fmul v8.8h, v8.8h, v13.h[0]
+fmul v9.8h, v9.8h, v13.h[1]
+fmul v10.8h, v10.8h, v13.h[2]
+fmul v11.8h, v11.8h, v13.h[3]
+
+// int16_t x = round(x)
+Round v0, v1, v2, v3
+Round v4, v5, v6, v7
+Round v8, v9, v10, v11
+
+// y = (int8_t)x
+sqxtn v0.8b, v0.8h
+sqxtn2 v0.16b, v1.8h
+sqxtn v1.8b, v2.8h
+sqxtn2 v1.16b, v3.8h
+sqxtn v2.8b, v4.8h
+sqxtn2 v2.16b, v5.8h
+sqxtn v3.8b, v6.8h
+sqxtn2 v3.16b, v7.8h
+sqxtn v4.8b, v8.8h
+sqxtn2 v4.16b, v9.8h
+sqxtn v5.8b, v10.8h
+sqxtn2 v5.16b, v11.8h
+
+.inst 0x4e9d9417 // sdot v23.4s, v0.16b, v29.16b
+.inst 0x4e9d9438 // sdot v24.4s, v1.16b, v29.16b
+.inst 0x4e9d9459 // sdot v25.4s, v2.16b, v29.16b
+.inst 0x4e9d947a // sdot v26.4s, v3.16b, v29.16b
+.inst 0x4e9d949b // sdot v27.4s, v4.16b, v29.16b
+.inst 0x4e9d94bc // sdot v28.4s, v5.16b, v29.16b
+
+st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x10], #64
+st1 {v4.16b, v5.16b}, [x10], x14
+
+subs x12, x12, #1
+bne LoopSz_12
+
+addp v22.4s, v23.4s, v24.4s
+addp v23.4s, v25.4s, v26.4s
+addp v24.4s, v27.4s, v28.4s
+st1 {v22.4s, v23.4s, v24.4s}, [x3], #48
+
+Tile12End:
+sub x5, x5, #12   // batch -= 12
+add x0, x0, #192  // src += 12 * 8 * sizeof(float16_t)
+add x1, x1, #96   // dst += 12 * 8 * sizeof(int8_t)
+b TILE_12
+
+TILE_10:
+cmp x5, #10
+blt TILE_8
+mov x9, x0   // src
+mov x10, x1  // dst
+mov x12, x4  // src_depth_quad
+sub x13, x7, #128 // src_step - 128
+sub x14, x6, #64 // dst_step - 64
+
+// quant_scale: v10, v11
+ld1 {v10.8h}, [x2], #16
+ld1 {v11.s}[0], [x2], #4
+movi v24.4s, #0
+movi v25.4s, #0
+movi v26.4s, #0
+movi v27.4s, #0
+movi v28.4s, #0
+
+LoopSz_10:
+ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x9], #64
+ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x9], #64
+ld1 {v8.8h, v9.8h}, [x9], x13
+
+// float16_t x = x * quant_scale
+fmul v0.8h, v0.8h, v10.h[0]
+fmul v1.8h, v1.8h, v10.h[1]
+fmul v2.8h, v2.8h, v10.h[2]
+fmul v3.8h, v3.8h, v10.h[3]
+fmul v4.8h, v4.8h, v10.h[4]
+fmul v5.8h, v5.8h, v10.h[5]
+fmul v6.8h, v6.8h, v10.h[6]
+fmul v7.8h, v7.8h, v10.h[7]
+fmul v8.8h, v8.8h, v11.h[0]
+fmul v9.8h, v9.8h, v11.h[1]
+
+// int16_t x = round(x)
+Round v0, v1, v2, v3
+Round v4, v5, v6, v7
+fcvtas v8.8h, v8.8h
+fcvtas v9.8h, v9.8h
+
+// y = (int8_t)x
+sqxtn v0.8b, v0.8h
+sqxtn2 v0.16b, v1.8h
+sqxtn v1.8b, v2.8h
+sqxtn2 v1.16b, v3.8h
+sqxtn v2.8b, v4.8h
+sqxtn2 v2.16b, v5.8h
+sqxtn v3.8b, v6.8h
+sqxtn2 v3.16b, v7.8h
+sqxtn v4.8b, v8.8h
+sqxtn2 v4.16b, v9.8h
+
+.inst 0x4e9d9418 // sdot v24.4s, v0.16b, v29.16b
+.inst 0x4e9d9439 // sdot v25.4s, v1.16b, v29.16b
+.inst 0x4e9d945a // sdot v26.4s, v2.16b, v29.16b
+.inst 0x4e9d947b // sdot v27.4s, v3.16b, v29.16b
+.inst 0x4e9d949c // sdot v28.4s, v4.16b, v29.16b
+
+st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x10], #64
+st1 {v4.16b}, [x10], x14
+
+subs x12, x12, #1
+bne LoopSz_10
+
+addp v23.4s, v24.4s, v25.4s
+addp v24.4s, v26.4s, v27.4s
+addp v25.4s, v28.4s, v28.4s
+st1 {v23.4s, v24.4s}, [x3], #32
+st1 {v25.d}[0], [x3], #8
+
+Tile10End:
+sub x5, x5, #10   // batch -= 10
+add x0, x0, #160  // src += 10 * 8 * sizeof(float16_t)
+add x1, x1, #80   // dst += 10 * 8 * sizeof(int8_t)
+b TILE_10
+
+
+TILE_8:
+cmp x5, #8
+blt TILE_1
+sub x8, x7, #64 // src_step - 64
+mov x9, x0   // src
+mov x10, x1  // dst
+mov x12, x4  // src_depth_quad
+
+// quant_scale: v8
+ld1 {v8.8h}, [x2], #16
+movi v25.4s, #0
+movi v26.4s, #0
+movi v27.4s, #0
+movi v28.4s, #0
+
+LoopSz_8:
+ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x9], #64
+ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x9], x8
+
+// float16_t x = x * quant_scale
+fmul v0.8h, v0.8h, v8.h[0]
+fmul v1.8h, v1.8h, v8.h[1]
+fmul v2.8h, v2.8h, v8.h[2]
+fmul v3.8h, v3.8h, v8.h[3]
+fmul v4.8h, v4.8h, v8.h[4]
+fmul v5.8h, v5.8h, v8.h[5]
+fmul v6.8h, v6.8h, v8.h[6]
+fmul v7.8h, v7.8h, v8.h[7]
+
+// int16_t x = round(x)
+Round v0, v1, v2, v3
+Round v4, v5, v6, v7
+
+// y = (int8_t)x
+sqxtn v9.8b, v0.8h
+sqxtn2 v9.16b, v1.8h
+sqxtn v10.8b, v2.8h
+sqxtn2 v10.16b, v3.8h
+sqxtn v11.8b, v4.8h
+sqxtn2 v11.16b, v5.8h
+sqxtn v12.8b, v6.8h
+sqxtn2 v12.16b, v7.8h
+
+.inst 0x4e9d9539 // sdot v25.4s, v9.16b, v29.16b
+.inst 0x4e9d955a // sdot v26.4s, v10.16b, v29.16b
+.inst 0x4e9d957b // sdot v27.4s, v11.16b, v29.16b
+.inst 0x4e9d959c // sdot v28.4s, v12.16b, v29.16b
+
+st1 {v9.16b, v10.16b, v11.16b, v12.16b}, [x10], x6
+
+subs x12, x12, #1
+bne LoopSz_8
+
+addp v24.4s, v25.4s, v26.4s
+addp v25.4s, v27.4s, v28.4s
+st1 {v24.4s, v25.4s}, [x3], #32
+
+Tile8End:
+sub x5, x5, #8    // batch -= 8
+add x0, x0, #128  // src += 8 * 8 * sizeof(float16_t)
+add x1, x1, #64   // dst += 8 * 8 * sizeof(int8_t)
+b TILE_8
+
+TILE_4:
+cmp x5, #4
+blt TILE_2
+mov x9, x0   // src
+mov x10, x1  // dst
+mov x12, x4  // src_depth_quad
+
+// quant_scale: v8
+ld1 {v8.d}[0], [x2], #8
+movi v27.4s, #0
+movi v28.4s, #0
+
+LoopSz_4:
+ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x9], x7
+
+// float16_t x = x * quant_scale
+fmul v0.8h, v0.8h, v8.h[0]
+fmul v1.8h, v1.8h, v8.h[1]
+fmul v2.8h, v2.8h, v8.h[2]
+fmul v3.8h, v3.8h, v8.h[3]
+
+// int16_t x = round(x)
+Round v0, v1, v2, v3
+
+// y = (int8_t)x
+sqxtn v4.8b, v0.8h
+sqxtn2 v4.16b, v1.8h
+sqxtn v5.8b, v2.8h
+sqxtn2 v5.16b, v3.8h
+
+.inst 0x4e9d949b // sdot v27.4s, v4.16b, v29.16b
+.inst 0x4e9d94bc // sdot v28.4s, v5.16b, v29.16b
+
+st1 {v4.16b, v5.16b}, [x10], x6
+
+subs x12, x12, #1
+bne LoopSz_4
+
+addp v26.4s, v27.4s, v28.4s
+st1 {v26.4s}, [x3], #16
+
+Tile4End:
+sub x5, x5, #4    // batch -= 4
+add x0, x0, #64   // src += 4 * 8 * sizeof(float16_t)
+add x1, x1, #32   // dst += 4 * 8 * sizeof(int8_t)
+b TILE_4
+
+
+TILE_2:
+cmp x5, #2
+blt TILE_1
+mov x9, x0   // src
+mov x10, x1  // dst
+mov x12, x4  // src_depth_quad
+
+// quant_scale: v8
+ld1 {v8.s}[0], [x2], #4
+movi v28.4s, #0
+
+LoopSz_2:
+ld1 {v0.8h, v1.8h}, [x9], x7
+
+// float16_t x = x * quant_scale
+fmul v0.8h, v0.8h, v8.h[0]
+fmul v1.8h, v1.8h, v8.h[1]
+
+// int16_t x = round(x)
+fcvtas v0.8h, v0.8h
+fcvtas v1.8h, v1.8h
+
+// y = (int8_t)x
+sqxtn v2.8b, v0.8h
+sqxtn2 v2.16b, v1.8h
+.inst 0x4e9d945c // sdot v28.4s, v2.16b, v29.16b
+
+st1 {v2.16b}, [x10], x6
+
+subs x12, x12, #1
+bne LoopSz_2
+
+addp v27.4s, v28.4s, v28.4s
+st1 {v27.d}[0], [x3], #8
+
+Tile2End:
+sub x5, x5, #2    // batch -= 2
+add x0, x0, #32   // src += 2 * 8 * sizeof(float16_t)
+add x1, x1, #16   // dst += 2 * 8 * sizeof(int8_t)
+b TILE_2
+
+
+TILE_1:
+cmp x5, #1
+blt End
+mov x9, x0   // src
+mov x10, x1  // dst
+mov x12, x4  // src_depth_quad
+
+// quant_scale: v8
+ld1 {v8.h}[0], [x2], #2
+movi v28.4s, #0
+
+LoopSz_1:
+ld1 {v0.8h}, [x9], x7
+
+// float16_t x = x * quant_scale
+fmul v0.8h, v0.8h, v8.h[0]
+// int16_t x = round(x)
+fcvtas v0.8h, v0.8h
+// y = (int8_t)x
+sqxtn v0.8b, v0.8h
+.inst 0x4e9d941c // sdot v28.4s, v0.16b, v29.16b
+
+st1 {v0.8b}, [x10], x6
+
+subs x12, x12, #1
+bne LoopSz_1
+
+addp v27.4s, v28.4s, v28.4s
+st1 {v27.s}[0], [x3], #4
+
+Tile1End:
+sub x5, x5, #1   // batch -= 1
+add x0, x0, #16  // src += 1 * 8 * sizeof(float16_t)
+add x1, x1, #8   // dst += 1 * 8 * sizeof(int8_t)
+b TILE_1
+
+
+End:
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 4)
+ret
+
+#endif
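
All tiles of this kernel apply the same per-element recipe: multiply by the per-batch quant scale, round to nearest (fcvtas), saturate to int8 (sqxtn), and accumulate the sum of the quantized values via sdot against a vector of ones. A scalar sketch under the same assumed [src_depth_quad][realSize][8] packing; the function name and the int32 sum type are illustrative simplifications of the actual buffer types:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstddef>

// Scalar sketch of the dynamic quantization done per tile above.
// Layout assumption: src/dst packed as [src_depth_quad][realSize][8];
// scale holds one value per batch element, sum one accumulator per element.
static void DynamicQuantSketch(const float* src, int8_t* dst, const float* scale,
                               int32_t* sum, size_t src_depth_quad, size_t realSize) {
    const size_t pack = 8;
    for (size_t i = 0; i < realSize; ++i) {
        int32_t acc = 0;
        for (size_t k = 0; k < src_depth_quad; ++k) {
            const float* x = src + (k * realSize + i) * pack;
            int8_t* y      = dst + (k * realSize + i) * pack;
            for (size_t j = 0; j < pack; ++j) {
                // x * quant_scale, round to nearest, saturate to int8
                int v = (int)std::lround(x[j] * scale[i]);
                v = std::min(127, std::max(-128, v));
                y[j] = (int8_t)v;
                acc += v; // same role as the sdot-with-ones accumulation
            }
        }
        sum[i] = acc;
    }
}
```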

+ 361 - 0
source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt4FP16_sdot.S

@@ -0,0 +1,361 @@
+//
+//  MNNGemmHybridInt4FP16_sdot.S
+//  MNN
+//
+//  Created by MNN on 2023/11/09.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+
+.text
+.align 5
+
+.macro Int32ToFloat z0, z1, z2, z3
+    scvtf \z0\().4s, \z0\().4s
+    scvtf \z1\().4s, \z1\().4s
+    scvtf \z2\().4s, \z2\().4s
+    scvtf \z3\().4s, \z3\().4s
+.endm
+
+.macro MulScale d0, d1, d2, d3, s, idx0, idx1, alpha0, alpha1
+    fmul \d0\().4s, \d0\().4s, \s\().s[\idx0]
+    fmul \d1\().4s, \d1\().4s, \s\().s[\idx0]
+    fmul \d2\().4s, \d2\().4s, \s\().s[\idx1]
+    fmul \d3\().4s, \d3\().4s, \s\().s[\idx1]
+    fmul \d0\().4s, \d0\().4s, \alpha0\().4s
+    fmul \d1\().4s, \d1\().4s, \alpha1\().4s
+    fmul \d2\().4s, \d2\().4s, \alpha0\().4s
+    fmul \d3\().4s, \d3\().4s, \alpha1\().4s
+.endm
+
+.macro Float32ToHalf s0, s1, s2, s3, d0, d1
+    fcvtn \d0\().4h,  \s0\().4s
+    fcvtn2 \d0\().8h, \s1\().4s
+    fcvtn \d1\().4h,  \s2\().4s
+    fcvtn2 \d1\().8h, \s3\().4s
+.endm
+
+.macro Dequant c0, z0, b0, s0, idx
+    fmla \c0\().8h, \z0\().8h, \s0\().h[\idx]
+    fadd \c0\().8h, \c0\().8h, \b0\().8h
+.endm
+
+asm_function MNNGemmHybridInt4FP16_sdot
+
+//struct QuanPostTreatParameters {
+//    const float* scale;
+//    const int32_t* bias;
+//    int32_t maxValue;
+//    int32_t minValue;
+//    int32_t useInt8;
+//};
+
+//void MNNGemmHybridInt4FP16_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
+
+
+// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
+// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales*
+stp d14, d15, [sp, #(-16 * 9)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+stp x21, x22, [sp, #(16 * 4)]
+stp x19, x20, [sp, #(16 * 5)]
+stp x23, x24, [sp, #(16 * 6)]
+stp x25, x26, [sp, #(16 * 7)]
+stp x27, x28, [sp, #(16 * 8)]
+
+ldr x8, [x7, #0]
+ldr x9, [x7, #8]
+ldr x10, [x7, #16]
+ldr x11, [x7, #24]
+ldr x12, [x7, #32]
+
+Start:
+lsl x13, x3, #5 // x13 = src_depth_quad * UNIT * UNIT_SRC / 2(int4) = src_depth_quad * 32  = src_depth_quad << 5
+
+TILE_4:
+    cmp x6, #4
+    blt TILE_1
+    mov x14, x4       // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_4:
+    // dequant info for batch
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+    // mask
+    mov w27, #0x0f
+    dup v14.16b, w27
+    // offset
+    mov w27, #8
+    dup v15.16b, w27
+LoopSz_TILE_4:
+    // src    : 2 x [2 x 8] : v4-5
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 2 x 4 x [4] : v16-23
+    ld1 {v0.16b, v1.16b}, [x25], #32    // weight
+    // int4 to int8: v0, v1, v2, v3
+    ushr v4.16b, v0.16b, #4
+    and v5.16b, v0.16b, v14.16b
+    sub v4.16b, v4.16b, v15.16b
+    sub v5.16b, v5.16b, v15.16b
+    ushr v6.16b, v1.16b, #4
+    and v7.16b, v1.16b, v14.16b
+    sub v6.16b, v6.16b, v15.16b
+    sub v7.16b, v7.16b, v15.16b
+    zip1 v0.16b, v4.16b, v5.16b
+    zip2 v1.16b, v4.16b, v5.16b
+    zip1 v2.16b, v6.16b, v7.16b
+    zip2 v3.16b, v6.16b, v7.16b
+    ld1 {v4.16b, v5.16b}, [x24], x15   // src
+    mov v10.d[0], v0.d[1]
+    mov v10.d[1], v0.d[0]
+    mov v11.d[1], v1.d[0]
+    mov v11.d[0], v1.d[1]
+    mov v12.d[0], v2.d[1]
+    mov v12.d[1], v2.d[0]
+    mov v13.d[0], v3.d[1]
+    mov v13.d[1], v3.d[0]
+    .inst 0x4e809490 // sdot v16.4s, v4.16b, v0.16b
+    .inst 0x4e8a9498 // sdot v24.4s, v4.16b, v10.16b
+    .inst 0x4e819491 // sdot v17.4s, v4.16b, v1.16b
+    .inst 0x4e8b9499 // sdot v25.4s, v4.16b, v11.16b
+    .inst 0x4e829492 // sdot v18.4s, v4.16b, v2.16b
+    .inst 0x4e8c949a // sdot v26.4s, v4.16b, v12.16b
+    .inst 0x4e839493 // sdot v19.4s, v4.16b, v3.16b
+    .inst 0x4e8d949b // sdot v27.4s, v4.16b, v13.16b
+    .inst 0x4e8094b4 // sdot v20.4s, v5.16b, v0.16b
+    .inst 0x4e8a94bc // sdot v28.4s, v5.16b, v10.16b
+    .inst 0x4e8194b5 // sdot v21.4s, v5.16b, v1.16b
+    .inst 0x4e8b94bd // sdot v29.4s, v5.16b, v11.16b
+    .inst 0x4e8294b6 // sdot v22.4s, v5.16b, v2.16b
+    .inst 0x4e8c94be // sdot v30.4s, v5.16b, v12.16b
+    .inst 0x4e8394b7 // sdot v23.4s, v5.16b, v3.16b
+    .inst 0x4e8d94bf // sdot v31.4s, v5.16b, v13.16b
+
+    subs x26, x26, #1
+    bne LoopSz_TILE_4
+
+    addp v16.4s, v16.4s, v24.4s
+    addp v17.4s, v17.4s, v25.4s
+    addp v18.4s, v18.4s, v26.4s
+    addp v19.4s, v19.4s, v27.4s
+    addp v20.4s, v20.4s, v28.4s
+    addp v21.4s, v21.4s, v29.4s
+    addp v22.4s, v22.4s, v30.4s
+    addp v23.4s, v23.4s, v31.4s
+
+LoopSzEnd_TILE_4:
+    add x18, x18, x13
+    sub x16, x16, #1
+    Int32ToFloat v16, v17, v18, v19
+    Int32ToFloat v20, v21, v22, v23
+    // using float scale dequant for precision
+    ld1 {v4.d}[0], [x23]  // scales
+    ld1 {v31.8h}, [x19], #16  // alpha
+    uzp1 v24.4s, v16.4s, v17.4s // batch=0,oc:0-3
+    uzp2 v26.4s, v16.4s, v17.4s // batch=1,oc:1,0,3,2
+    uzp1 v25.4s, v18.4s, v19.4s // batch=0,oc:4-7
+    uzp2 v27.4s, v18.4s, v19.4s // batch=1,oc:5,4,7,6
+
+    uzp1 v28.4s, v20.4s, v21.4s // batch=2,oc:0-3
+    uzp2 v7.4s, v20.4s, v21.4s  // batch=3,oc:1,0,3,2
+    uzp1 v6.4s, v22.4s, v23.4s  // batch=2,oc:4-7
+    uzp2 v8.4s, v22.4s, v23.4s  // batch=3,oc:5,4,7,6
+
+    trn1 v0.4s, v26.4s, v27.4s // 1,5,3,7
+    trn1 v1.4s, v7.4s, v8.4s   // 1,5,3,7
+    trn2 v2.4s, v26.4s, v27.4s // 0,4,2,6
+    trn2 v3.4s, v7.4s, v8.4s   // 0,4,2,6
+
+    trn1 v10.4s, v2.4s, v0.4s // batch=1
+    trn2 v11.4s, v2.4s, v0.4s
+    trn1 v21.4s, v3.4s, v1.4s // batch=3
+    trn2 v19.4s, v3.4s, v1.4s
+
+    fcvtl v29.4s, v31.4h // oc:0-3
+    fcvtl2 v30.4s, v31.8h // oc:4-7
+    fcvtl v5.4s, v4.4h // scales: 4 batch
+
+    MulScale v24, v25, v10, v11, v5, 0, 1, v29, v30
+    MulScale v28, v6, v21, v19, v5, 2, 3, v29, v30
+    Float32ToHalf v24, v25, v10, v11, v12, v13
+    Float32ToHalf v28, v6, v21, v19, v14, v15
+Tile4Dequant:
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.d}[0], [x22]  // sums
+    // sum + (zero * sumx) + bias
+    Dequant v12, v1, v2, v3, 0
+    Dequant v13, v1, v2, v3, 1
+    Dequant v14, v1, v2, v3, 2
+    Dequant v15, v1, v2, v3, 3
+    st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_4
+Tile4End:
+    sub x6, x6, #4      // batch -= 4
+    add x0, x0, #64     // dst += 4 * 8 * sizeof(float16_t)
+    add x1, x1, #32     // src += 4 * 8 * sizeof(int8_t)
+    add x11, x11, #8    // sum += 4 * sizeof(float16_t)
+    add x12, x12, #8    // scale += 4 * sizeof(float16_t)
+    b TILE_4
+
+TILE_1:
+    cmp x6, #1
+    blt End
+    mov x14, x4       // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_1:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    movi v6.4s, #0
+    movi v7.4s, #0
+    movi v8.4s, #0
+    movi v9.4s, #0
+    movi v10.4s, #0
+    movi v11.4s, #0
+    movi v12.4s, #0
+    movi v13.4s, #0
+    // mask
+    mov w27, #0x0f
+    dup v14.16b, w27
+    // offset
+    mov w27, #8
+    dup v15.16b, w27
+LoopSz_TILE_1:
+    // src    : 1 x [1 x 8] : v4
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 1 x 4 x [2] : v16-v19
+    ld1 {v0.16b, v1.16b}, [x25], #32    // weight
+    // int4 to int8: v0, v1, v2, v3
+    ushr v21.16b, v0.16b, #4
+    and v22.16b, v0.16b, v14.16b
+    sub v21.16b, v21.16b, v15.16b
+    sub v22.16b, v22.16b, v15.16b
+    ushr v23.16b, v1.16b, #4
+    and v24.16b, v1.16b, v14.16b
+    sub v23.16b, v23.16b, v15.16b
+    sub v24.16b, v24.16b, v15.16b
+    zip1 v0.16b, v21.16b, v22.16b
+    zip2 v1.16b, v21.16b, v22.16b
+    zip1 v2.16b, v23.16b, v24.16b
+    zip2 v3.16b, v23.16b, v24.16b
+    ld1 {v4.8b}, [x24], x15   // src
+    mov v31.d[0], v0.d[1]
+    mov v31.d[1], v0.d[0]
+    mov v30.d[0], v1.d[1]
+    mov v30.d[1], v1.d[0]
+    mov v29.d[0], v2.d[1]
+    mov v29.d[1], v2.d[0]
+    mov v28.d[0], v3.d[1]
+    mov v28.d[1], v3.d[0]
+
+
+    .inst 0x4e849406 // sdot v6.4s, v0.16b, v4.16b
+    .inst 0x4e8497e7 // sdot v7.4s, v31.16b, v4.16b
+    .inst 0x4e849428 // sdot v8.4s, v1.16b, v4.16b
+    .inst 0x4e8497c9 // sdot v9.4s, v30.16b, v4.16b
+    .inst 0x4e84944a // sdot v10.4s, v2.16b, v4.16b
+    .inst 0x4e8497ab // sdot v11.4s, v29.16b, v4.16b
+    .inst 0x4e84946c // sdot v12.4s, v3.16b, v4.16b
+    .inst 0x4e84978d // sdot v13.4s, v28.16b, v4.16b
+
+    subs x26, x26, #1
+    bne LoopSz_TILE_1
+    addp v16.4s, v6.4s, v7.4s
+    addp v17.4s, v8.4s, v9.4s
+    addp v18.4s, v10.4s, v11.4s
+    addp v19.4s, v12.4s, v13.4s
+
+LoopSzEnd_TILE_1:
+    add x18, x18, x13
+    sub x16, x16, #1
+    uzp1 v15.4s, v16.4s, v17.4s
+    uzp1 v16.4s, v18.4s, v19.4s
+    scvtf v15.4s, v15.4s
+    scvtf v16.4s, v16.4s
+    // using float scale dequant for precision
+    ld1 {v4.h}[0], [x23]  // scales
+    ld1 {v0.8h}, [x19], #16  // alpha
+    fcvtl v5.4s, v4.4h
+    fmul v15.4s, v15.4s, v5.s[0]
+    fmul v16.4s, v16.4s, v5.s[0]
+    fcvtl v20.4s, v0.4h
+    fcvtl2 v21.4s, v0.8h
+    fmul v15.4s, v15.4s, v20.4s
+    fmul v16.4s, v16.4s, v21.4s
+    fcvtn v17.4h,  v15.4s
+    fcvtn2 v17.8h, v16.4s
+Tile1Dequant:
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.h}[0], [x22]  // sums
+    // sum + (zero * sumx) + bias
+    fadd v2.8h, v2.8h, v17.8h
+    fmla v2.8h, v1.8h, v3.h[0]
+    st1 {v2.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_1
+Tile1End:
+    sub x6, x6, #1      // batch -= 1
+    add x0, x0, #16     // dst += 1 * 8 * sizeof(float16_t)
+    add x1, x1, #8      // src += 1 * 8 * sizeof(int8_t)
+    add x11, x11, #2   // sum += 1 * sizeof(float16_t)
+    add x12, x12, #2   // scale += 1 * sizeof(float16_t)
+    b TILE_1
+
+End:
+ldp x27, x28, [sp, #(16 * 8)]
+ldp x25, x26, [sp, #(16 * 7)]
+ldp x23, x24, [sp, #(16 * 6)]
+ldp x19, x20, [sp, #(16 * 5)]
+ldp x21, x22, [sp, #(16 * 4)]
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 9)
+ret
+
+#endif
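
Both int4 hybrid kernels expand the weights the same way before the dot products: split each byte into its two nibbles (ushr #4 / and 0x0f) and subtract the offset 8, so the unsigned nibbles become signed weights in [-8, 7]. A small sketch of that step; the interleaving done afterwards by zip1/zip2 depends on the weight repack layout and is omitted here:

```cpp
#include <cstdint>
#include <cstddef>

// Sketch of the "int4 to int8" unpacking above: each packed byte carries
// two unsigned 4-bit weights; subtracting the offset 8 recenters them to
// signed values in [-8, 7].
static void UnpackInt4Sketch(const uint8_t* packed, int8_t* out, size_t bytes) {
    for (size_t i = 0; i < bytes; ++i) {
        int8_t hi = (int8_t)(packed[i] >> 4) - 8;    // ushr #4, then sub 8
        int8_t lo = (int8_t)(packed[i] & 0x0f) - 8;  // and 0x0f, then sub 8
        out[2 * i + 0] = hi;
        out[2 * i + 1] = lo;
    }
}
```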

+ 894 - 0
source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt4FP16_smmla.S

@@ -0,0 +1,894 @@
+//
+//  MNNGemmHybridInt4FP16_smmla.S
+//  MNN
+//
+//  Created by MNN on 2023/10/30.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+
+.text
+.align 5
+
+.macro Int32ToFloat z0, z1, z2, z3
+    scvtf \z0\().4s, \z0\().4s
+    scvtf \z1\().4s, \z1\().4s
+    scvtf \z2\().4s, \z2\().4s
+    scvtf \z3\().4s, \z3\().4s
+.endm
+
+.macro MulScale d0, d1, d2, d3, s
+    fmul \d0\().4s, \d0\().4s, \s\().4s
+    fmul \d1\().4s, \d1\().4s, \s\().4s
+    fmul \d2\().4s, \d2\().4s, \s\().4s
+    fmul \d3\().4s, \d3\().4s, \s\().4s
+.endm
+
+.macro Float32ToHalf s0, s1, s2, s3, d0, d1
+    fcvtn \d0\().4h,  \s0\().4s
+    fcvtn2 \d0\().8h, \s1\().4s
+    fcvtn \d1\().4h,  \s2\().4s
+    fcvtn2 \d1\().8h, \s3\().4s
+.endm
+
+.macro Dequant c0, a0, z0, b0, s0, idx
+    fmul \c0\().8h, \c0\().8h, \a0\().8h
+    fmla \c0\().8h, \z0\().8h, \s0\().h[\idx]
+    fadd \c0\().8h, \c0\().8h, \b0\().8h
+.endm
+
+asm_function MNNGemmHybridInt4FP16_smmla
+
+//struct QuanPostTreatParameters {
+//    const float* scale;
+//    const int32_t* bias;
+//    int32_t maxValue;
+//    int32_t minValue;
+//    int32_t useInt8;
+//};
+
+//void MNNGemmHybridInt4FP16_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
+
+
+// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
+// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales*
+stp d14, d15, [sp, #(-16 * 9)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+stp x21, x22, [sp, #(16 * 4)]
+stp x19, x20, [sp, #(16 * 5)]
+stp x23, x24, [sp, #(16 * 6)]
+stp x25, x26, [sp, #(16 * 7)]
+stp x27, x28, [sp, #(16 * 8)]
+
+ldr x8, [x7, #0]
+ldr x9, [x7, #8]
+ldr x10, [x7, #16]
+ldr x11, [x7, #24]
+ldr x12, [x7, #32]
+
+Start:
+lsl x13, x3, #5 // x13 = src_depth_quad * UNIT * UNIT_SRC / 2(int4) = src_depth_quad * 32  = src_depth_quad << 5
+b TILE_4
+
+TILE_12:
+    cmp x6, #12
+    blt TILE_10
+    sub x14, x4, #128  // dst_step
+    lsr x15, x14, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_12:
+    // dequant info for batch
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v8.4s, wzr
+    dup v9.4s, wzr
+    dup v10.4s, wzr
+    dup v11.4s, wzr
+    dup v12.4s, wzr
+    dup v13.4s, wzr
+    dup v14.4s, wzr
+    dup v15.4s, wzr
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+    // mask
+    mov w27, #0x0f
+    dup v6.16b, w27
+    // offset
+    mov w27, #8
+    dup v7.16b, w27
+LoopSz_TILE_12:
+    // src    : 6 x [2 x 8] : (v4-5) * 3
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 6 x 4 x [4] : v8-v31
+    ld1 {v0.16b, v1.16b}, [x25], #32    // weight
+    // int4 to int8: v0, v1, v2, v3
+    ushr v2.16b, v0.16b, #4
+    and v3.16b, v0.16b, v6.16b
+    ushr v4.16b, v1.16b, #4
+    and v5.16b, v1.16b, v6.16b
+    sub v2.16b, v2.16b, v7.16b
+    sub v3.16b, v3.16b, v7.16b
+    sub v4.16b, v4.16b, v7.16b
+    sub v5.16b, v5.16b, v7.16b
+    zip1 v0.16b, v2.16b, v3.16b
+    zip2 v1.16b, v2.16b, v3.16b
+    zip1 v2.16b, v4.16b, v5.16b
+    zip2 v3.16b, v4.16b, v5.16b
+
+    ld1 {v4.16b, v5.16b}, [x24], #32   // src
+    .inst 0x4e80a488 // smmla v8.4s, v4.16b, v0.16b
+    .inst 0x4e81a489 // smmla v9.4s, v4.16b, v1.16b
+    .inst 0x4e82a48a // smmla v10.4s, v4.16b, v2.16b
+    .inst 0x4e83a48b // smmla v11.4s, v4.16b, v3.16b
+    .inst 0x4e80a4ac // smmla v12.4s, v5.16b, v0.16b
+    .inst 0x4e81a4ad // smmla v13.4s, v5.16b, v1.16b
+    .inst 0x4e82a4ae // smmla v14.4s, v5.16b, v2.16b
+    .inst 0x4e83a4af // smmla v15.4s, v5.16b, v3.16b
+    ld1 {v4.16b, v5.16b}, [x24], #32   // src
+    .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
+    .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
+    .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b
+    .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b
+    .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b
+    .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b
+    .inst 0x4e82a4b6 // smmla v22.4s, v5.16b, v2.16b
+    .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b
+    ld1 {v4.16b, v5.16b}, [x24], x15   // src
+    .inst 0x4e80a498 // smmla v24.4s, v4.16b, v0.16b
+    .inst 0x4e81a499 // smmla v25.4s, v4.16b, v1.16b
+    .inst 0x4e82a49a // smmla v26.4s, v4.16b, v2.16b
+    .inst 0x4e83a49b // smmla v27.4s, v4.16b, v3.16b
+    .inst 0x4e80a4bc // smmla v28.4s, v5.16b, v0.16b
+    .inst 0x4e81a4bd // smmla v29.4s, v5.16b, v1.16b
+    .inst 0x4e82a4be // smmla v30.4s, v5.16b, v2.16b
+    .inst 0x4e83a4bf // smmla v31.4s, v5.16b, v3.16b
+    subs x26, x26, #1
+    bne LoopSz_TILE_12
+
+LoopSzEnd_TILE_12:
+    add x18, x18, x13
+    sub x16, x16, #1
+    Int32ToFloat  v8,  v9, v10, v11
+    Int32ToFloat v12, v13, v14, v15
+    Int32ToFloat v16, v17, v18, v19
+    Int32ToFloat v20, v21, v22, v23
+    Int32ToFloat v24, v25, v26, v27
+    Int32ToFloat v28, v29, v30, v31
+    // using float scale dequant for precision
+    ld1 {v4.8h}, [x23], #16    // scales
+    ld1 {v5.d}[0], [x23], #8   // scales
+    fcvtl v6.4s, v4.4h
+    fcvtl2 v7.4s, v4.8h
+    // [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11]
+    dup v0.4s, v6.s[0]
+    mov v0.s[2], v6.s[1]
+    mov v0.s[3], v6.s[1]
+    dup v1.4s, v6.s[2]
+    mov v1.s[2], v6.s[3]
+    mov v1.s[3], v6.s[3]
+    dup v2.4s, v7.s[0]
+    mov v2.s[2], v7.s[1]
+    mov v2.s[3], v7.s[1]
+    dup v3.4s, v7.s[2]
+    mov v3.s[2], v7.s[3]
+    mov v3.s[3], v7.s[3]
+    fcvtl v7.4s, v5.4h
+    dup v4.4s, v7.s[0]
+    mov v4.s[2], v7.s[1]
+    mov v4.s[3], v7.s[1]
+    dup v5.4s, v7.s[2]
+    mov v5.s[2], v7.s[3]
+    mov v5.s[3], v7.s[3]
+    MulScale  v8,  v9, v10, v11, v0
+    MulScale v12, v13, v14, v15, v1
+    MulScale v16, v17, v18, v19, v2
+    MulScale v20, v21, v22, v23, v3
+    MulScale v24, v25, v26, v27, v4
+    MulScale v28, v29, v30, v31, v5
+    Float32ToHalf  v8,  v9, v10, v11,  v6,  v7
+    Float32ToHalf v12, v13, v14, v15,  v8,  v9
+    Float32ToHalf v16, v17, v18, v19, v10, v11
+    Float32ToHalf v20, v21, v22, v23, v12, v13
+    Float32ToHalf v24, v25, v26, v27, v14, v15
+    Float32ToHalf v28, v29, v30, v31, v16, v17
+    uzp1  v5.4s,  v6.4s,  v7.4s
+    uzp2  v6.4s,  v6.4s,  v7.4s
+    uzp1  v7.4s,  v8.4s,  v9.4s
+    uzp2  v8.4s,  v8.4s,  v9.4s
+    uzp1  v9.4s, v10.4s, v11.4s
+    uzp2 v10.4s, v10.4s, v11.4s
+    uzp1 v11.4s, v12.4s, v13.4s
+    uzp2 v12.4s, v12.4s, v13.4s
+    uzp1 v13.4s, v14.4s, v15.4s
+    uzp2 v14.4s, v14.4s, v15.4s
+    uzp1 v15.4s, v16.4s, v17.4s
+    uzp2 v16.4s, v16.4s, v17.4s
+Tile12Dequant:
+    ld1 {v0.8h}, [x19], #16  // alpha
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.8h}, [x22], #16  // sums
+    ld1 {v4.d}[0], [x22], #8 // sums
+    // alpha * sum + (zero * sumx) + bias
+    Dequant  v5, v0, v1, v2, v3, 0
+    Dequant  v6, v0, v1, v2, v3, 1
+    Dequant  v7, v0, v1, v2, v3, 2
+    Dequant  v8, v0, v1, v2, v3, 3
+    Dequant  v9, v0, v1, v2, v3, 4
+    Dequant v10, v0, v1, v2, v3, 5
+    Dequant v11, v0, v1, v2, v3, 6
+    Dequant v12, v0, v1, v2, v3, 7
+    Dequant v13, v0, v1, v2, v4, 0
+    Dequant v14, v0, v1, v2, v4, 1
+    Dequant v15, v0, v1, v2, v4, 2
+    Dequant v16, v0, v1, v2, v4, 3
+    st1 { v5.8h,  v6.8h,  v7.8h,  v8.8h}, [x17], #64
+    st1 { v9.8h, v10.8h, v11.8h, v12.8h}, [x17], #64
+    st1 {v13.8h, v14.8h, v15.8h, v16.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_12
+Tile12End:
+    sub x6, x6, #12     // batch -= 12
+    add x0, x0, #192    // dst += 12 * 8 * sizeof(float16_t)
+    add x1, x1, #96     // src += 12 * 8 * sizeof(int8_t)
+    add x11, x11, #24   // sum += 12 * sizeof(float16_t)
+    add x12, x12, #24   // scale += 12 * sizeof(float16_t)
+    b TILE_12
+
+TILE_10:
+    cmp x6, #10
+    blt TILE_8
+    sub x14, x4, #128  // dst_step
+    lsr x15, x14, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_10:
+    // dequant info for batch
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v12.4s, wzr
+    dup v13.4s, wzr
+    dup v14.4s, wzr
+    dup v15.4s, wzr
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+    // mask
+    mov w27, #0x0f
+    dup v10.16b, w27
+    // offset
+    mov w27, #8
+    dup v11.16b, w27
+LoopSz_TILE_10:
+    // src    : 5 x [2 x 8] : v4-8
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 5 x 4 x [4] : v12-v31
+    ld1 {v0.16b, v1.16b}, [x25], #32    // weight
+    // int4 to int8: v0, v1, v2, v3
+    ushr v4.16b, v0.16b, #4
+    and v5.16b, v0.16b, v10.16b
+    sub v4.16b, v4.16b, v11.16b
+    sub v5.16b, v5.16b, v11.16b
+    ushr v6.16b, v1.16b, #4
+    and v7.16b, v1.16b, v10.16b
+    sub v6.16b, v6.16b, v11.16b
+    sub v7.16b, v7.16b, v11.16b
+    zip1 v0.16b, v4.16b, v5.16b
+    zip2 v1.16b, v4.16b, v5.16b
+    zip1 v2.16b, v6.16b, v7.16b
+    zip2 v3.16b, v6.16b, v7.16b
+    ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x24], #64   // src
+    ld1 {v8.16b}, [x24], x15   // src
+
+    .inst 0x4e80a48c // smmla v12.4s, v4.16b, v0.16b
+    .inst 0x4e81a48d // smmla v13.4s, v4.16b, v1.16b
+    .inst 0x4e82a48e // smmla v14.4s, v4.16b, v2.16b
+    .inst 0x4e83a48f // smmla v15.4s, v4.16b, v3.16b
+    .inst 0x4e80a4b0 // smmla v16.4s, v5.16b, v0.16b
+    .inst 0x4e81a4b1 // smmla v17.4s, v5.16b, v1.16b
+    .inst 0x4e82a4b2 // smmla v18.4s, v5.16b, v2.16b
+    .inst 0x4e83a4b3 // smmla v19.4s, v5.16b, v3.16b
+    .inst 0x4e80a4d4 // smmla v20.4s, v6.16b, v0.16b
+    .inst 0x4e81a4d5 // smmla v21.4s, v6.16b, v1.16b
+    .inst 0x4e82a4d6 // smmla v22.4s, v6.16b, v2.16b
+    .inst 0x4e83a4d7 // smmla v23.4s, v6.16b, v3.16b
+    .inst 0x4e80a4f8 // smmla v24.4s, v7.16b, v0.16b
+    .inst 0x4e81a4f9 // smmla v25.4s, v7.16b, v1.16b
+    .inst 0x4e82a4fa // smmla v26.4s, v7.16b, v2.16b
+    .inst 0x4e83a4fb // smmla v27.4s, v7.16b, v3.16b
+    .inst 0x4e80a51c // smmla v28.4s, v8.16b, v0.16b
+    .inst 0x4e81a51d // smmla v29.4s, v8.16b, v1.16b
+    .inst 0x4e82a51e // smmla v30.4s, v8.16b, v2.16b
+    .inst 0x4e83a51f // smmla v31.4s, v8.16b, v3.16b
+    subs x26, x26, #1
+    bne LoopSz_TILE_10
+
+LoopSzEnd_TILE_10:
+    add x18, x18, x13
+    sub x16, x16, #1
+    Int32ToFloat v12, v13, v14, v15
+    Int32ToFloat v16, v17, v18, v19
+    Int32ToFloat v20, v21, v22, v23
+    Int32ToFloat v24, v25, v26, v27
+    Int32ToFloat v28, v29, v30, v31
+    // using float scale dequant for precision
+    ld1 {v4.8h}, [x23], #16    // scales
+    ld1 {v5.s}[0], [x23], #4   // scales
+    fcvtl v6.4s, v4.4h
+    fcvtl2 v7.4s, v4.8h
+    fcvtl v8.4s, v5.4h
+    // [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9]
+    dup v0.4s, v6.s[0]
+    mov v0.s[2], v6.s[1]
+    mov v0.s[3], v6.s[1]
+    dup v1.4s, v6.s[2]
+    mov v1.s[2], v6.s[3]
+    mov v1.s[3], v6.s[3]
+    dup v2.4s, v7.s[0]
+    mov v2.s[2], v7.s[1]
+    mov v2.s[3], v7.s[1]
+    dup v3.4s, v7.s[2]
+    mov v3.s[2], v7.s[3]
+    mov v3.s[3], v7.s[3]
+    dup v4.4s, v8.s[0]
+    mov v4.s[2], v8.s[1]
+    mov v4.s[3], v8.s[1]
+    MulScale v12, v13, v14, v15, v0
+    MulScale v16, v17, v18, v19, v1
+    MulScale v20, v21, v22, v23, v2
+    MulScale v24, v25, v26, v27, v3
+    MulScale v28, v29, v30, v31, v4
+    Float32ToHalf v12, v13, v14, v15, v10, v11
+    Float32ToHalf v16, v17, v18, v19, v12, v13
+    Float32ToHalf v20, v21, v22, v23, v14, v15
+    Float32ToHalf v24, v25, v26, v27, v16, v17
+    Float32ToHalf v28, v29, v30, v31, v18, v19
+    uzp1  v9.4s, v10.4s, v11.4s
+    uzp2 v10.4s, v10.4s, v11.4s
+    uzp1 v11.4s, v12.4s, v13.4s
+    uzp2 v12.4s, v12.4s, v13.4s
+    uzp1 v13.4s, v14.4s, v15.4s
+    uzp2 v14.4s, v14.4s, v15.4s
+    uzp1 v15.4s, v16.4s, v17.4s
+    uzp2 v16.4s, v16.4s, v17.4s
+    uzp1 v17.4s, v18.4s, v19.4s
+    uzp2 v18.4s, v18.4s, v19.4s
+Tile10Dequant:
+    ld1 {v0.8h}, [x19], #16  // alpha
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.8h}, [x22], #16  // sums
+    ld1 {v4.s}[0], [x22], #4 // sums
+    // alpha * sum + (zero * sumx) + bias
+    Dequant  v9, v0, v1, v2, v3, 0
+    Dequant v10, v0, v1, v2, v3, 1
+    Dequant v11, v0, v1, v2, v3, 2
+    Dequant v12, v0, v1, v2, v3, 3
+    Dequant v13, v0, v1, v2, v3, 4
+    Dequant v14, v0, v1, v2, v3, 5
+    Dequant v15, v0, v1, v2, v3, 6
+    Dequant v16, v0, v1, v2, v3, 7
+    Dequant v17, v0, v1, v2, v4, 0
+    Dequant v18, v0, v1, v2, v4, 1
+    st1 { v9.8h, v10.8h, v11.8h, v12.8h}, [x17], #64
+    st1 {v13.8h, v14.8h, v15.8h, v16.8h}, [x17], #64
+    st1 {v17.8h, v18.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_10
+Tile10End:
+    sub x6, x6, #10     // batch -= 10
+    add x0, x0, #160    // dst += 10 * 8 * sizeof(float16_t)
+    add x1, x1, #80     // src += 10 * 8 * sizeof(int8_t)
+    add x11, x11, #20   // sum += 10 * sizeof(float16_t)
+    add x12, x12, #20   // scale += 10 * sizeof(float16_t)
+    b TILE_10
+
+TILE_8:
+    cmp x6, #8
+    blt TILE_1
+    sub x14, x4, #64  // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_8:
+    // dequant info for batch
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+    // mask
+    mov w27, #0x0f
+    dup v10.16b, w27
+    // offset
+    mov w27, #8
+    dup v11.16b, w27
+LoopSz_TILE_8:
+    // src    : 4 x [2 x 8] : v4-7
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 4 x 4 x [4] : v16-v31
+    ld1 {v0.16b, v1.16b}, [x25], #32    // weight
+    // int4 to int8: v0, v1, v2, v3
+    ushr v4.16b, v0.16b, #4
+    and v5.16b, v0.16b, v10.16b
+    sub v4.16b, v4.16b, v11.16b
+    sub v5.16b, v5.16b, v11.16b
+    ushr v6.16b, v1.16b, #4
+    and v7.16b, v1.16b, v10.16b
+    sub v6.16b, v6.16b, v11.16b
+    sub v7.16b, v7.16b, v11.16b
+    zip1 v0.16b, v4.16b, v5.16b
+    zip2 v1.16b, v4.16b, v5.16b
+    zip1 v2.16b, v6.16b, v7.16b
+    zip2 v3.16b, v6.16b, v7.16b
+    ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x24], x15   // src
+    .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
+    .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
+    .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b
+    .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b
+    .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b
+    .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b
+    .inst 0x4e82a4b6 // smmla v22.4s, v5.16b, v2.16b
+    .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b
+    .inst 0x4e80a4d8 // smmla v24.4s, v6.16b, v0.16b
+    .inst 0x4e81a4d9 // smmla v25.4s, v6.16b, v1.16b
+    .inst 0x4e82a4da // smmla v26.4s, v6.16b, v2.16b
+    .inst 0x4e83a4db // smmla v27.4s, v6.16b, v3.16b
+    .inst 0x4e80a4fc // smmla v28.4s, v7.16b, v0.16b
+    .inst 0x4e81a4fd // smmla v29.4s, v7.16b, v1.16b
+    .inst 0x4e82a4fe // smmla v30.4s, v7.16b, v2.16b
+    .inst 0x4e83a4ff // smmla v31.4s, v7.16b, v3.16b
+    subs x26, x26, #1
+    bne LoopSz_TILE_8
+
+LoopSzEnd_TILE_8:
+    add x18, x18, x13
+    sub x16, x16, #1
+    Int32ToFloat v16, v17, v18, v19
+    Int32ToFloat v20, v21, v22, v23
+    Int32ToFloat v24, v25, v26, v27
+    Int32ToFloat v28, v29, v30, v31
+    // using float scale dequant for precision
+    ld1 {v4.8h}, [x23]  // scales
+    fcvtl v5.4s, v4.4h
+    fcvtl2 v6.4s, v4.8h
+    dup v0.4s, v5.s[0]
+    mov v0.s[2], v5.s[1]
+    mov v0.s[3], v5.s[1]
+    dup v1.4s, v5.s[2]
+    mov v1.s[2], v5.s[3]
+    mov v1.s[3], v5.s[3]
+    dup v2.4s, v6.s[0]
+    mov v2.s[2], v6.s[1]
+    mov v2.s[3], v6.s[1]
+    dup v3.4s, v6.s[2]
+    mov v3.s[2], v6.s[3]
+    mov v3.s[3], v6.s[3]
+    MulScale v16, v17, v18, v19, v0
+    MulScale v20, v21, v22, v23, v1
+    MulScale v24, v25, v26, v27, v2
+    MulScale v28, v29, v30, v31, v3
+    Float32ToHalf v16, v17, v18, v19, v12, v13
+    Float32ToHalf v20, v21, v22, v23, v14, v15
+    Float32ToHalf v24, v25, v26, v27, v16, v17
+    Float32ToHalf v28, v29, v30, v31, v18, v19
+    uzp1 v11.4s, v12.4s, v13.4s
+    uzp2 v12.4s, v12.4s, v13.4s
+    uzp1 v13.4s, v14.4s, v15.4s
+    uzp2 v14.4s, v14.4s, v15.4s
+    uzp1 v15.4s, v16.4s, v17.4s
+    uzp2 v16.4s, v16.4s, v17.4s
+    uzp1 v17.4s, v18.4s, v19.4s
+    uzp2 v18.4s, v18.4s, v19.4s
+Tile8Dequant:
+    ld1 {v0.8h}, [x19], #16  // alpha
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.8h}, [x22]  // sums
+    // alpha * sum + (zero * sumx) + bias
+    Dequant v11, v0, v1, v2, v3, 0
+    Dequant v12, v0, v1, v2, v3, 1
+    Dequant v13, v0, v1, v2, v3, 2
+    Dequant v14, v0, v1, v2, v3, 3
+    Dequant v15, v0, v1, v2, v3, 4
+    Dequant v16, v0, v1, v2, v3, 5
+    Dequant v17, v0, v1, v2, v3, 6
+    Dequant v18, v0, v1, v2, v3, 7
+    st1 {v11.8h, v12.8h, v13.8h, v14.8h}, [x17], #64
+    st1 {v15.8h, v16.8h, v17.8h, v18.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_8
+Tile8End:
+    sub x6, x6, #8      // batch -= 8
+    add x0, x0, #128    // dst += 8 * 8 * sizeof(float16_t)
+    add x1, x1, #64     // src += 8 * 8 * sizeof(int8_t)
+    add x11, x11, #16   // sum += 8 * sizeof(float16_t)
+    add x12, x12, #16   // scale += 8 * sizeof(float16_t)
+    b TILE_8
+
+TILE_4:
+    cmp x6, #4
+    blt TILE_2
+    mov x14, x4       // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_4:
+    // dequant info for batch
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+    // mask
+    mov w27, #0x0f
+    dup v10.16b, w27
+    // offset
+    mov w27, #8
+    dup v11.16b, w27
+LoopSz_TILE_4:
+    // src    : 2 x [2 x 8] : v4-5
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 2 x 4 x [4] : v16-23
+    ld1 {v0.16b, v1.16b}, [x25], #32    // weight
+    // int4 to int8: v0, v1, v2, v3
+    ushr v4.16b, v0.16b, #4
+    and v5.16b, v0.16b, v10.16b
+    sub v4.16b, v4.16b, v11.16b
+    sub v5.16b, v5.16b, v11.16b
+    ushr v6.16b, v1.16b, #4
+    and v7.16b, v1.16b, v10.16b
+    sub v6.16b, v6.16b, v11.16b
+    sub v7.16b, v7.16b, v11.16b
+    zip1 v0.16b, v4.16b, v5.16b
+    zip2 v1.16b, v4.16b, v5.16b
+    zip1 v2.16b, v6.16b, v7.16b
+    zip2 v3.16b, v6.16b, v7.16b
+    ld1 {v4.16b, v5.16b}, [x24], x15   // src
+    .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
+    .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
+    .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b
+    .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b
+    .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b
+    .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b
+    .inst 0x4e82a4b6 // smmla v22.4s, v5.16b, v2.16b
+    .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b
+    subs x26, x26, #1
+    bne LoopSz_TILE_4
+
+LoopSzEnd_TILE_4:
+    add x18, x18, x13
+    sub x16, x16, #1
+    Int32ToFloat v16, v17, v18, v19
+    Int32ToFloat v20, v21, v22, v23
+    // using float scale dequant for precision
+    ld1 {v4.d}[0], [x23]  // scales
+    fcvtl v5.4s, v4.4h
+    dup v0.4s, v5.s[0]
+    mov v0.s[2], v5.s[1]
+    mov v0.s[3], v5.s[1]
+    dup v1.4s, v5.s[2]
+    mov v1.s[2], v5.s[3]
+    mov v1.s[3], v5.s[3]
+    MulScale v16, v17, v18, v19, v0
+    MulScale v20, v21, v22, v23, v1
+    Float32ToHalf v16, v17, v18, v19, v12, v13
+    Float32ToHalf v20, v21, v22, v23, v14, v15
+    uzp1 v11.4s, v12.4s, v13.4s
+    uzp2 v12.4s, v12.4s, v13.4s
+    uzp1 v13.4s, v14.4s, v15.4s
+    uzp2 v14.4s, v14.4s, v15.4s
+Tile4Dequant:
+    ld1 {v0.8h}, [x19], #16  // alpha
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.d}[0], [x22]  // sums
+    // alpha * sum + (zero * sumx) + bias
+    Dequant v11, v0, v1, v2, v3, 0
+    Dequant v12, v0, v1, v2, v3, 1
+    Dequant v13, v0, v1, v2, v3, 2
+    Dequant v14, v0, v1, v2, v3, 3
+    st1 {v11.8h, v12.8h, v13.8h, v14.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_4
+Tile4End:
+    sub x6, x6, #4      // batch -= 4
+    add x0, x0, #64     // dst += 4 * 8 * sizeof(float16_t)
+    add x1, x1, #32     // src += 4 * 8 * sizeof(int8_t)
+    add x11, x11, #8    // sum += 4 * sizeof(float16_t)
+    add x12, x12, #8    // scale += 4 * sizeof(float16_t)
+    b TILE_4
+
+TILE_2:
+    cmp x6, #2
+    blt TILE_1
+    mov x14, x4       // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_2:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    // mask
+    mov w27, #0x0f
+    dup v14.16b, w27
+    // offset
+    mov w27, #8
+    dup v15.16b, w27
+LoopSz_TILE_2:
+    // src    : 1 x [2 x 8] : v4
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 1 x 4 x [4] : v16-19
+    ld1 {v0.16b, v1.16b}, [x25], #32    // weight
+    // int4 to int8: v0, v1, v2, v3
+    ushr v8.16b, v0.16b, #4
+    and v9.16b, v0.16b, v14.16b
+    sub v8.16b, v8.16b, v15.16b
+    sub v9.16b, v9.16b, v15.16b
+    ushr v10.16b, v1.16b, #4
+    and v11.16b, v1.16b, v14.16b
+    sub v10.16b, v10.16b, v15.16b
+    sub v11.16b, v11.16b, v15.16b
+    zip1 v0.16b, v8.16b, v9.16b
+    zip2 v1.16b, v8.16b, v9.16b
+    zip1 v2.16b, v10.16b, v11.16b
+    zip2 v3.16b, v10.16b, v11.16b
+    ld1 {v4.16b}, [x24], x15   // src
+    .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
+    .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
+    .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b
+    .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b
+    subs x26, x26, #1
+    bne LoopSz_TILE_2
+
+LoopSzEnd_TILE_2:
+    add x18, x18, x13
+    sub x16, x16, #1
+    uzp1 v13.2d, v16.2d, v17.2d
+    uzp1 v14.2d, v18.2d, v19.2d
+    uzp2 v15.2d, v16.2d, v17.2d
+    uzp2 v16.2d, v18.2d, v19.2d
+    Int32ToFloat v13, v14, v15, v16
+    // using float scale dequant for precision
+    ld1 {v4.s}[0], [x23]  // scales
+    fcvtl v5.4s, v4.4h
+    fmul v13.4s, v13.4s, v5.s[0]
+    fmul v14.4s, v14.4s, v5.s[0]
+    fmul v15.4s, v15.4s, v5.s[1]
+    fmul v16.4s, v16.4s, v5.s[1]
+    fcvtn v12.4h,  v13.4s
+    fcvtn2 v12.8h, v14.4s
+    fcvtn v13.4h,  v15.4s
+    fcvtn2 v13.8h, v16.4s
+Tile2Dequant:
+    ld1 {v0.8h}, [x19], #16  // alpha
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.s}[0], [x22]  // sums
+    // alpha * sum + (zero * sumx) + bias
+    Dequant v12, v0, v1, v2, v3, 0
+    Dequant v13, v0, v1, v2, v3, 1
+    st1 {v12.8h, v13.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_2
+Tile2End:
+    sub x6, x6, #2      // batch -= 2
+    add x0, x0, #32     // dst += 2 * 8 * sizeof(float16_t)
+    add x1, x1, #16     // src += 2 * 8 * sizeof(int8_t)
+    add x11, x11, #4    // sum += 2 * sizeof(float16_t)
+    add x12, x12, #4    // scale += 2 * sizeof(float16_t)
+    b TILE_2
+
+
+TILE_1:
+    cmp x6, #1
+    blt End
+    mov x14, x4       // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_1:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    // mask
+    mov w27, #0x0f
+    dup v14.16b, w27
+    // offset
+    mov w27, #8
+    dup v15.16b, w27
+
+LoopSz_TILE_1:
+    // src    : 1 x [1 x 8] : v4
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 1 x 4 x [2] : v16-v19
+    prfm pldl1keep, [x25, #64]   // prefetch next weight block
+    prfm pldl1keep, [x24, x15]   // prefetch next source block
+    ld1 {v0.16b, v1.16b}, [x25], #32    // weight
+    // int4 to int8: v0, v1, v2, v3
+    ushr v8.16b, v0.16b, #4
+    and v9.16b, v0.16b, v14.16b
+    ushr v10.16b, v1.16b, #4
+    and v11.16b, v1.16b, v14.16b
+    sub v8.16b, v8.16b, v15.16b
+    sub v9.16b, v9.16b, v15.16b
+    sub v10.16b, v10.16b, v15.16b
+    sub v11.16b, v11.16b, v15.16b
+    zip1 v0.16b, v8.16b, v9.16b
+    zip2 v1.16b, v8.16b, v9.16b
+    zip1 v2.16b, v10.16b, v11.16b
+    zip2 v3.16b, v10.16b, v11.16b
+    ld1 {v4.8b}, [x24], x15   // src
+    .inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b
+    .inst 0x4e84a431 // smmla v17.4s, v1.16b, v4.16b
+    .inst 0x4e84a452 // smmla v18.4s, v2.16b, v4.16b
+    .inst 0x4e84a473 // smmla v19.4s, v3.16b, v4.16b
+    subs x26, x26, #1
+    bne LoopSz_TILE_1
+
+LoopSzEnd_TILE_1:
+    add x18, x18, x13
+    sub x16, x16, #1
+    uzp1 v15.4s, v16.4s, v17.4s
+    uzp1 v16.4s, v18.4s, v19.4s
+    scvtf v15.4s, v15.4s
+    scvtf v16.4s, v16.4s
+    // using float scale dequant for precision
+    ld1 {v4.h}[0], [x23]  // scales
+    fcvtl v5.4s, v4.4h
+    fmul v15.4s, v15.4s, v5.s[0]
+    fmul v16.4s, v16.4s, v5.s[0]
+    fcvtn v17.4h,  v15.4s
+    fcvtn2 v17.8h, v16.4s
+Tile1Dequant:
+    ld1 {v0.8h}, [x19], #16  // alpha
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.h}[0], [x22]  // sums
+    // alpha * sum + (zero * sumx) + bias
+    fmla v2.8h, v0.8h, v17.8h
+    fmla v2.8h, v1.8h, v3.h[0]
+    st1 {v2.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_1
+Tile1End:
+    sub x6, x6, #1      // batch -= 1
+    add x0, x0, #16     // dst += 1 * 8 * sizeof(float16_t)
+    add x1, x1, #8      // src += 1 * 8 * sizeof(int8_t)
+    add x11, x11, #2   // sum += 1 * sizeof(float16_t)
+    add x12, x12, #2   // scale += 1 * sizeof(float16_t)
+    b TILE_1
+
+End:
+ldp x27, x28, [sp, #(16 * 8)]
+ldp x25, x26, [sp, #(16 * 7)]
+ldp x23, x24, [sp, #(16 * 6)]
+ldp x19, x20, [sp, #(16 * 5)]
+ldp x21, x22, [sp, #(16 * 4)]
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 9)
+ret
+
+#endif
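
The Tile*Dequant blocks of the smmla kernels all apply the same affine correction: the int32 accumulator is first scaled by the per-batch input scale (MulScale), then by the per-channel weight alpha, and finally offset by zero * sums plus bias (Dequant). In scalar form, with placeholder names standing in for the register contents:

```cpp
#include <cstddef>

// Sketch of the per-output dequantization in Tile*Dequant above:
//   C[b][oc] = alpha[oc] * (scale[b] * acc[b][oc]) + zero[oc] * sums[b] + bias[oc]
// acc is the int32 smmla accumulator, sums[b] the per-batch sum of quantized inputs.
static float DequantOneSketch(int acc, float scale_b, float sums_b,
                              float alpha_oc, float zero_oc, float bias_oc) {
    float v = (float)acc * scale_b;        // MulScale: per-batch input scale
    v *= alpha_oc;                          // Dequant: per-channel weight scale
    return v + zero_oc * sums_b + bias_oc;  // zero-point correction plus bias
}
```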

+ 323 - 0
source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt8FP16_sdot.S

@@ -0,0 +1,323 @@
+//
+//  MNNGemmHybridInt8FP16_sdot.S
+//  MNN
+//
+//  Created by MNN on 2023/11/09.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+
+.text
+.align 5
+
+.macro Int32ToFloat z0, z1, z2, z3
+    scvtf \z0\().4s, \z0\().4s
+    scvtf \z1\().4s, \z1\().4s
+    scvtf \z2\().4s, \z2\().4s
+    scvtf \z3\().4s, \z3\().4s
+.endm
+
+.macro MulScale d0, d1, d2, d3, s, idx0, idx1, alpha0, alpha1
+    fmul \d0\().4s, \d0\().4s, \s\().s[\idx0]
+    fmul \d1\().4s, \d1\().4s, \s\().s[\idx0]
+    fmul \d2\().4s, \d2\().4s, \s\().s[\idx1]
+    fmul \d3\().4s, \d3\().4s, \s\().s[\idx1]
+    fmul \d0\().4s, \d0\().4s, \alpha0\().4s
+    fmul \d1\().4s, \d1\().4s, \alpha1\().4s
+    fmul \d2\().4s, \d2\().4s, \alpha0\().4s
+    fmul \d3\().4s, \d3\().4s, \alpha1\().4s
+.endm
+
+.macro Float32ToHalf s0, s1, s2, s3, d0, d1
+    fcvtn \d0\().4h,  \s0\().4s
+    fcvtn2 \d0\().8h, \s1\().4s
+    fcvtn \d1\().4h,  \s2\().4s
+    fcvtn2 \d1\().8h, \s3\().4s
+.endm
+
+.macro Dequant c0, z0, b0, s0, idx
+    fmla \c0\().8h, \z0\().8h, \s0\().h[\idx]
+    fadd \c0\().8h, \c0\().8h, \b0\().8h
+.endm
+
+asm_function MNNGemmHybridInt8FP16_sdot
+
+//struct QuanPostTreatParameters {
+//    const float* scale;
+//    const int32_t* bias;
+//    int32_t maxValue;
+//    int32_t minValue;
+//    int32_t useInt8;
+//};
+
+//void MNNGemmHybridInt8FP16_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
+
+
+// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
+// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales*
+stp d14, d15, [sp, #(-16 * 9)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+stp x21, x22, [sp, #(16 * 4)]
+stp x19, x20, [sp, #(16 * 5)]
+stp x23, x24, [sp, #(16 * 6)]
+stp x25, x26, [sp, #(16 * 7)]
+stp x27, x28, [sp, #(16 * 8)]
+
+ldr x8, [x7, #0]
+ldr x9, [x7, #8]
+ldr x10, [x7, #16]
+ldr x11, [x7, #24]
+ldr x12, [x7, #32]
+
+Start:
+lsl x13, x3, #6 // x13 = src_depth_quad * UNIT * UNIT_SRC / 1(int8) = src_depth_quad * 64  = src_depth_quad << 6
+
+TILE_4:
+    cmp x6, #4
+    blt TILE_1
+    mov x14, x4       // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_4:
+    // dequant info for batch
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+LoopSz_TILE_4:
+    // src    : 2 x [2 x 8] : v4-5
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 2 x 4 x [4] : v16-23
+    ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v4.16b, v5.16b}, [x24], x15   // src
+    mov v10.d[0], v0.d[1]
+    mov v10.d[1], v0.d[0]
+    mov v11.d[1], v1.d[0]
+    mov v11.d[0], v1.d[1]
+    mov v12.d[0], v2.d[1]
+    mov v12.d[1], v2.d[0]
+    mov v13.d[0], v3.d[1]
+    mov v13.d[1], v3.d[0]
+    .inst 0x4e809490 // sdot v16.4s, v4.16b, v0.16b
+    .inst 0x4e8a9498 // sdot v24.4s, v4.16b, v10.16b
+    .inst 0x4e819491 // sdot v17.4s, v4.16b, v1.16b
+    .inst 0x4e8b9499 // sdot v25.4s, v4.16b, v11.16b
+    .inst 0x4e829492 // sdot v18.4s, v4.16b, v2.16b
+    .inst 0x4e8c949a // sdot v26.4s, v4.16b, v12.16b
+    .inst 0x4e839493 // sdot v19.4s, v4.16b, v3.16b
+    .inst 0x4e8d949b // sdot v27.4s, v4.16b, v13.16b
+    .inst 0x4e8094b4 // sdot v20.4s, v5.16b, v0.16b
+    .inst 0x4e8a94bc // sdot v28.4s, v5.16b, v10.16b
+    .inst 0x4e8194b5 // sdot v21.4s, v5.16b, v1.16b
+    .inst 0x4e8b94bd // sdot v29.4s, v5.16b, v11.16b
+    .inst 0x4e8294b6 // sdot v22.4s, v5.16b, v2.16b
+    .inst 0x4e8c94be // sdot v30.4s, v5.16b, v12.16b
+    .inst 0x4e8394b7 // sdot v23.4s, v5.16b, v3.16b
+    .inst 0x4e8d94bf // sdot v31.4s, v5.16b, v13.16b
+
+    subs x26, x26, #1
+    bne LoopSz_TILE_4
+
+    addp v16.4s, v16.4s, v24.4s
+    addp v17.4s, v17.4s, v25.4s
+    addp v18.4s, v18.4s, v26.4s
+    addp v19.4s, v19.4s, v27.4s
+    addp v20.4s, v20.4s, v28.4s
+    addp v21.4s, v21.4s, v29.4s
+    addp v22.4s, v22.4s, v30.4s
+    addp v23.4s, v23.4s, v31.4s
+
+LoopSzEnd_TILE_4:
+    add x18, x18, x13
+    sub x16, x16, #1
+    Int32ToFloat v16, v17, v18, v19
+    Int32ToFloat v20, v21, v22, v23
+    // using float scale dequant for precision
+    ld1 {v4.d}[0], [x23]  // scales
+    ld1 {v31.8h}, [x19], #16  // alpha
+    uzp1 v24.4s, v16.4s, v17.4s // batch=0,oc:0-3
+    uzp2 v26.4s, v16.4s, v17.4s // batch=1,oc:1,0,3,2
+    uzp1 v25.4s, v18.4s, v19.4s // batch=0,oc:4-7
+    uzp2 v27.4s, v18.4s, v19.4s // batch=1,oc:5,4,7,6
+
+    uzp1 v28.4s, v20.4s, v21.4s // batch=2,oc:0-3
+    uzp2 v7.4s, v20.4s, v21.4s  // batch=3,oc:1,0,3,2
+    uzp1 v6.4s, v22.4s, v23.4s  // batch=2,oc:4-7
+    uzp2 v8.4s, v22.4s, v23.4s  // batch=3,oc:5,4,7,6
+
+    trn1 v0.4s, v26.4s, v27.4s // 1,5,3,7
+    trn1 v1.4s, v7.4s, v8.4s   // 1,5,3,7
+    trn2 v2.4s, v26.4s, v27.4s // 0,4,2,6
+    trn2 v3.4s, v7.4s, v8.4s   // 0,4,2,6
+
+    trn1 v10.4s, v2.4s, v0.4s // batch=1
+    trn2 v11.4s, v2.4s, v0.4s
+    trn1 v21.4s, v3.4s, v1.4s // batch=3
+    trn2 v19.4s, v3.4s, v1.4s
+
+    fcvtl v29.4s, v31.4h // oc:0-3
+    fcvtl2 v30.4s, v31.8h // oc:4-7
+    fcvtl v5.4s, v4.4h // scales: 4 batch
+
+    MulScale v24, v25, v10, v11, v5, 0, 1, v29, v30
+    MulScale v28, v6, v21, v19, v5, 2, 3, v29, v30
+    Float32ToHalf v24, v25, v10, v11, v12, v13
+    Float32ToHalf v28, v6, v21, v19, v14, v15
+Tile4Dequant:
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.d}[0], [x22]  // sums
+    // sum + (zero * sumx) + bias
+    Dequant v12, v1, v2, v3, 0
+    Dequant v13, v1, v2, v3, 1
+    Dequant v14, v1, v2, v3, 2
+    Dequant v15, v1, v2, v3, 3
+    st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_4
+Tile4End:
+    sub x6, x6, #4      // batch -= 4
+    add x0, x0, #64     // dst += 4 * 8 * sizeof(float16_t)
+    add x1, x1, #32     // src += 4 * 8 * sizeof(int8_t)
+    add x11, x11, #8    // sum += 4 * sizeof(float16_t)
+    add x12, x12, #8    // scale += 4 * sizeof(float16_t)
+    b TILE_4
+
+TILE_1:
+    cmp x6, #1
+    blt End
+    mov x14, x4       // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_1:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    movi v6.4s, #0
+    movi v7.4s, #0
+    movi v8.4s, #0
+    movi v9.4s, #0
+    movi v10.4s, #0
+    movi v11.4s, #0
+    movi v12.4s, #0
+    movi v13.4s, #0
+LoopSz_TILE_1:
+    // src    : 1 x [1 x 8] : v4
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 1 x 4 x [2] : v16-v19
+    ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v4.8b}, [x24], x15   // src
+    mov v31.d[0], v0.d[1]
+    mov v31.d[1], v0.d[0]
+    mov v30.d[0], v1.d[1]
+    mov v30.d[1], v1.d[0]
+    mov v29.d[0], v2.d[1]
+    mov v29.d[1], v2.d[0]
+    mov v28.d[0], v3.d[1]
+    mov v28.d[1], v3.d[0]
+
+
+    .inst 0x4e849406 // sdot v6.4s, v0.16b, v4.16b
+    .inst 0x4e8497e7 // sdot v7.4s, v31.16b, v4.16b
+    .inst 0x4e849428 // sdot v8.4s, v1.16b, v4.16b
+    .inst 0x4e8497c9 // sdot v9.4s, v30.16b, v4.16b
+    .inst 0x4e84944a // sdot v10.4s, v2.16b, v4.16b
+    .inst 0x4e8497ab // sdot v11.4s, v29.16b, v4.16b
+    .inst 0x4e84946c // sdot v12.4s, v3.16b, v4.16b
+    .inst 0x4e84978d // sdot v13.4s, v28.16b, v4.16b
+
+    subs x26, x26, #1
+    bne LoopSz_TILE_1
+    addp v16.4s, v6.4s, v7.4s
+    addp v17.4s, v8.4s, v9.4s
+    addp v18.4s, v10.4s, v11.4s
+    addp v19.4s, v12.4s, v13.4s
+
+LoopSzEnd_TILE_1:
+    add x18, x18, x13
+    sub x16, x16, #1
+    uzp1 v15.4s, v16.4s, v17.4s
+    uzp1 v16.4s, v18.4s, v19.4s
+    scvtf v15.4s, v15.4s
+    scvtf v16.4s, v16.4s
+    // using float scale dequant for precision
+    ld1 {v4.h}[0], [x23]  // scales
+    ld1 {v0.8h}, [x19], #16  // alpha
+    fcvtl v5.4s, v4.4h
+    fmul v15.4s, v15.4s, v5.s[0]
+    fmul v16.4s, v16.4s, v5.s[0]
+    fcvtl v20.4s, v0.4h
+    fcvtl2 v21.4s, v0.8h
+    fmul v15.4s, v15.4s, v20.4s
+    fmul v16.4s, v16.4s, v21.4s
+    fcvtn v17.4h,  v15.4s
+    fcvtn2 v17.8h, v16.4s
+Tile1Dequant:
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.h}[0], [x22]  // sums
+    // sum + (zero * sumx) + bias
+    fadd v2.8h, v2.8h, v17.8h
+    fmla v2.8h, v1.8h, v3.h[0]
+    st1 {v2.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_1
+Tile1End:
+    sub x6, x6, #1      // batch -= 1
+    add x0, x0, #16     // dst += 1 * 8 * sizeof(float16_t)
+    add x1, x1, #8      // src += 1 * 8 * sizeof(int8_t)
+    add x11, x11, #2   // sum += 1 * sizeof(float16_t)
+    add x12, x12, #2   // scale += 1 * sizeof(float16_t)
+    b TILE_1
+
+End:
+ldp x27, x28, [sp, #(16 * 8)]
+ldp x25, x26, [sp, #(16 * 7)]
+ldp x23, x24, [sp, #(16 * 6)]
+ldp x19, x20, [sp, #(16 * 5)]
+ldp x21, x22, [sp, #(16 * 4)]
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 9)
+ret
+
+#endif

+ 566 - 0
source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt8FP16_smmla.S

@@ -0,0 +1,566 @@
+//
+//  MNNGemmHybridInt8FP16_smmla.S
+//  MNN
+//
+//  Created by MNN on 2023/11/09.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+
+.text
+.align 5
+
+.macro Int32ToFloat z0, z1, z2, z3
+    scvtf \z0\().4s, \z0\().4s
+    scvtf \z1\().4s, \z1\().4s
+    scvtf \z2\().4s, \z2\().4s
+    scvtf \z3\().4s, \z3\().4s
+.endm
+
+.macro MulScale d0, d1, d2, d3, s, idx0, idx1, alpha0, alpha1
+    fmul \d0\().4s, \d0\().4s, \s\().s[\idx0]
+    fmul \d1\().4s, \d1\().4s, \s\().s[\idx0]
+    fmul \d2\().4s, \d2\().4s, \s\().s[\idx1]
+    fmul \d3\().4s, \d3\().4s, \s\().s[\idx1]
+    fmul \d0\().4s, \d0\().4s, \alpha0\().4s
+    fmul \d1\().4s, \d1\().4s, \alpha1\().4s
+    fmul \d2\().4s, \d2\().4s, \alpha0\().4s
+    fmul \d3\().4s, \d3\().4s, \alpha1\().4s
+.endm
+
+.macro Float32ToHalf s0, s1, s2, s3, d0, d1
+    fcvtn \d0\().4h,  \s0\().4s
+    fcvtn2 \d0\().8h, \s1\().4s
+    fcvtn \d1\().4h,  \s2\().4s
+    fcvtn2 \d1\().8h, \s3\().4s
+.endm
+
+.macro Dequant c0, z0, b0, s0, idx
+    fmla \c0\().8h, \z0\().8h, \s0\().h[\idx]
+    fadd \c0\().8h, \c0\().8h, \b0\().8h
+.endm
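+// Per-output dequant performed by the macros above:
+//   fp16_out = (int32_acc * scales[batch]) * alpha[oc] + zero[oc] * sums[batch] + bias[oc]
+// Int32ToFloat + MulScale compute the first term in fp32, Float32ToHalf narrows it,
+// and Dequant adds the zero-point and bias terms in fp16.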
+
+asm_function MNNGemmHybridInt8FP16_smmla
+
+//struct QuanPostTreatParameters {
+//    const float* scale;
+//    const int32_t* bias;
+//    int32_t maxValue;
+//    int32_t minValue;
+//    int32_t useInt8;
+//};
+
+//void MNNGemmHybridInt8FP16_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
+
+
+// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
+// load from param x7: x8: alpha*, x9: zero*, x10: bias*, x11: sums*, x12: scales*
+stp d14, d15, [sp, #(-16 * 9)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+stp x21, x22, [sp, #(16 * 4)]
+stp x19, x20, [sp, #(16 * 5)]
+stp x23, x24, [sp, #(16 * 6)]
+stp x25, x26, [sp, #(16 * 7)]
+stp x27, x28, [sp, #(16 * 8)]
+
+ldr x8, [x7, #0]
+ldr x9, [x7, #8]
+ldr x10, [x7, #16]
+ldr x11, [x7, #24]
+ldr x12, [x7, #32]
+
+Start:
+lsl x13, x3, #6 // x13 = src_depth_quad * UNIT * UNIT_SRC / 1(int8) = src_depth_quad * 64  = src_depth_quad << 6
+cmp x6, #1
+beq TILE_EQ_1
+
+TILE_8:
+    cmp x6, #8
+    blt TILE_4
+    //mov x14, x4       // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    sub x14, x4, #64
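+    // x14 = dst_step - 64: each TILE_8 iteration stores 128 bytes with two st1,
+    // the first post-increments x17 by #64 and the second by x14, so x17 advances
+    // by exactly dst_step per dz step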
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_8:
+    // dequant info for batch
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+LoopSz_TILE_8:
+    // src    : 4 x [2 x 8] : v4-7
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 4 x 4 x [4] : v16-31
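+    // each smmla below multiplies a 2x8 int8 block of src by a 2x8 int8 block of
+    // weight (as src * weight^T) and accumulates a 2x2 int32 tile, i.e. every
+    // accumulator holds 2 batches x 2 output channels (split apart later by trn1/trn2)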
+    ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x24], x15   // src
+    .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b // batch=0,1, oc=0,1
+    .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b // batch=0,1, oc=2,3
+    .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b // batch=0,1, oc=4,5
+    .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b // batch=0,1, oc=6,7
+    .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b // batch=2,3, oc=0,1
+    .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b // batch=2,3, oc=2,3
+    .inst 0x4e82a4b6 // smmla v22.4s, v5.16b, v2.16b // batch=2,3, oc=4,5
+    .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b // batch=2,3, oc=6,7
+
+    .inst 0x4e80a4d8 // smmla v24.4s, v6.16b, v0.16b // batch=4,5, oc=0,1
+    .inst 0x4e81a4d9 // smmla v25.4s, v6.16b, v1.16b // batch=4,5, oc=2,3
+    .inst 0x4e82a4da // smmla v26.4s, v6.16b, v2.16b // batch=4,5, oc=4,5
+    .inst 0x4e83a4db // smmla v27.4s, v6.16b, v3.16b // batch=4,5, oc=6,7
+    .inst 0x4e80a4fc // smmla v28.4s, v7.16b, v0.16b // batch=6,7, oc=0,1
+    .inst 0x4e81a4fd // smmla v29.4s, v7.16b, v1.16b // batch=6,7, oc=2,3
+    .inst 0x4e82a4fe // smmla v30.4s, v7.16b, v2.16b // batch=6,7, oc=4,5
+    .inst 0x4e83a4ff // smmla v31.4s, v7.16b, v3.16b // batch=6,7, oc=6,7
+    subs x26, x26, #1
+    bne LoopSz_TILE_8
+
+LoopSzEnd_TILE_8:
+    add x18, x18, x13
+    sub x16, x16, #1
+    Int32ToFloat v16, v17, v18, v19
+    Int32ToFloat v20, v21, v22, v23
+    Int32ToFloat v24, v25, v26, v27
+    Int32ToFloat v28, v29, v30, v31
+    // using float scale dequant for precision
+    trn1 v8.2d,  v16.2d, v17.2d // batch=0,oc:0-3
+    trn1 v9.2d,  v18.2d, v19.2d // batch=0,oc:4-7
+    trn2 v10.2d, v16.2d, v17.2d // batch=1,oc:0-3
+    trn2 v11.2d, v18.2d, v19.2d // batch=1,oc:4-7
+    trn1 v12.2d, v20.2d, v21.2d // batch=2,oc:0-3
+    trn1 v13.2d, v22.2d, v23.2d // batch=2,oc:4-7
+    trn2 v14.2d, v20.2d, v21.2d // batch=3,oc:0-3
+    trn2 v15.2d, v22.2d, v23.2d // batch=3,oc:4-7
+
+    trn1 v0.2d, v24.2d, v25.2d // batch=4,oc:0-3
+    trn1 v1.2d, v26.2d, v27.2d // batch=4,oc:4-7
+    trn2 v2.2d, v24.2d, v25.2d // batch=5,oc:0-3
+    trn2 v3.2d, v26.2d, v27.2d // batch=5,oc:4-7
+    trn1 v4.2d, v28.2d, v29.2d // batch=6,oc:0-3
+    trn1 v5.2d, v30.2d, v31.2d // batch=6,oc:4-7
+    trn2 v6.2d, v28.2d, v29.2d // batch=7,oc:0-3
+    trn2 v7.2d, v30.2d, v31.2d // batch=7,oc:4-7
+
+    ld1 {v16.8h}, [x23]  // scales
+    ld1 {v17.8h}, [x19], #16  // alpha
+    fcvtl v18.4s, v17.4h // oc:0-3
+    fcvtl2 v19.4s, v17.8h // oc:4-7
+    fcvtl v28.4s, v16.4h // scales: batch 0,1,2,3
+    fcvtl2 v29.4s, v16.8h // scales: batch 4,5,6,7
+
+    MulScale v8, v9, v10, v11, v28, 0, 1, v18, v19
+    MulScale v12, v13, v14, v15, v28, 2, 3, v18, v19
+    Float32ToHalf v8, v9, v10, v11, v20, v21 // batch=0,1
+    Float32ToHalf v12, v13, v14, v15, v22, v23 // batch=2,3
+
+    MulScale v0, v1, v2, v3, v29, 0, 1, v18, v19
+    MulScale v4, v5, v6, v7, v29, 2, 3, v18, v19
+    Float32ToHalf v0, v1, v2, v3, v24, v25 // batch=4,5
+    Float32ToHalf v4, v5, v6, v7, v26, v27 // batch=6,7
+
+Tile8Dequant:
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.8h}, [x22]  // sums
+    // sum + (zero * sumx) + bias
+    Dequant v20, v1, v2, v3, 0
+    Dequant v21, v1, v2, v3, 1
+    Dequant v22, v1, v2, v3, 2
+    Dequant v23, v1, v2, v3, 3
+
+    Dequant v24, v1, v2, v3, 4
+    Dequant v25, v1, v2, v3, 5
+    Dequant v26, v1, v2, v3, 6
+    Dequant v27, v1, v2, v3, 7
+    st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x17], #64
+    st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_8
+Tile8End:
+    sub x6, x6, #8      // batch -= 8
+    add x0, x0, #128     // dst += 8 * 8 * sizeof(float16_t)
+    add x1, x1, #64     // src += 8 * 8 * sizeof(int8_t)
+    add x11, x11, #16    // sum += 8 * sizeof(float16_t)
+    add x12, x12, #16    // scale += 8 * sizeof(float16_t)
+    b TILE_8
+
+TILE_4:
+    cmp x6, #4
+    blt TILE_2
+    mov x14, x4       // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_4:
+    // dequant info for batch
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+LoopSz_TILE_4:
+    // src    : 2 x [2 x 8] : v4-5
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 2 x 4 x [4] : v16-23
+    ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v4.16b, v5.16b}, [x24], x15   // src
+    .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b // batch=0,1, oc=0,1
+    .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b // batch=0,1, oc=2,3
+    .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b // batch=0,1, oc=4,5
+    .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b // batch=0,1, oc=6,7
+    .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b // batch=2,3, oc=0,1
+    .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b // batch=2,3, oc=2,3
+    .inst 0x4e82a4b6 // smmla v22.4s, v5.16b, v2.16b // batch=2,3, oc=4,5
+    .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b // batch=2,3, oc=6,7
+    subs x26, x26, #1
+    bne LoopSz_TILE_4
+
+LoopSzEnd_TILE_4:
+    add x18, x18, x13
+    sub x16, x16, #1
+    Int32ToFloat v16, v17, v18, v19
+    Int32ToFloat v20, v21, v22, v23
+    // using float scale dequant for precision
+    ld1 {v4.d}[0], [x23]  // scales
+    ld1 {v31.8h}, [x19], #16  // alpha
+    fcvtl v29.4s, v31.4h // oc:0-3
+    fcvtl2 v30.4s, v31.8h // oc:4-7
+    trn1 v24.2d, v16.2d, v17.2d // batch=0,oc:0-3
+    trn1 v25.2d, v18.2d, v19.2d // batch=0,oc:4-7
+    trn2 v26.2d, v16.2d, v17.2d // batch=1,oc:0-3
+    trn2 v27.2d, v18.2d, v19.2d // batch=1,oc:4-7
+    trn1 v28.2d, v20.2d, v21.2d // batch=2,oc:0-3
+    trn1 v6.2d, v22.2d, v23.2d  // batch=2,oc:4-7
+    trn2 v7.2d, v20.2d, v21.2d  // batch=3,oc:0-3
+    trn2 v8.2d, v22.2d, v23.2d  // batch=3,oc:4-7
+
+    fcvtl v5.4s, v4.4h // scales: 4 batch
+
+    MulScale v24, v25, v26, v27, v5, 0, 1, v29, v30
+    MulScale v28, v6, v7, v8, v5, 2, 3, v29, v30
+    Float32ToHalf v24, v25, v26, v27, v12, v13
+    Float32ToHalf v28, v6, v7, v8, v14, v15
+Tile4Dequant:
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.d}[0], [x22]  // sums
+    // sum + (zero * sumx) + bias
+    Dequant v12, v1, v2, v3, 0
+    Dequant v13, v1, v2, v3, 1
+    Dequant v14, v1, v2, v3, 2
+    Dequant v15, v1, v2, v3, 3
+    st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_4
+Tile4End:
+    sub x6, x6, #4      // batch -= 4
+    add x0, x0, #64     // dst += 4 * 8 * sizeof(float16_t)
+    add x1, x1, #32     // src += 4 * 8 * sizeof(int8_t)
+    add x11, x11, #8    // sum += 4 * sizeof(float16_t)
+    add x12, x12, #8    // scale += 4 * sizeof(float16_t)
+    b TILE_4
+
+TILE_2:
+    cmp x6, #2
+    blt TILE_1
+    mov x14, x4       // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_2:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+LoopSz_TILE_2:
+    // src    : 1 x [2 x 8] : v4
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 1 x 4 x [4] : v16-19
+    ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v4.16b}, [x24], x15   // src
+    .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
+    .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
+    .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b
+    .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b
+    subs x26, x26, #1
+    bne LoopSz_TILE_2
+
+LoopSzEnd_TILE_2:
+    add x18, x18, x13
+    sub x16, x16, #1
+    uzp1 v13.2d, v16.2d, v17.2d
+    uzp1 v14.2d, v18.2d, v19.2d
+    uzp2 v15.2d, v16.2d, v17.2d
+    uzp2 v16.2d, v18.2d, v19.2d
+    Int32ToFloat v13, v14, v15, v16
+    // using float scale dequant for precision
+    ld1 {v4.s}[0], [x23]  // scales
+    ld1 {v0.8h}, [x19], #16  // alpha
+    fcvtl v5.4s, v4.4h
+    fcvtl v20.4s, v0.4h
+    fcvtl2 v21.4s, v0.8h
+    MulScale v13, v14, v15, v16, v5, 0, 1, v20, v21
+    fcvtn v11.4h,  v13.4s
+    fcvtn2 v11.8h, v14.4s
+    fcvtn v12.4h,  v15.4s
+    fcvtn2 v12.8h, v16.4s
+Tile2Dequant:
+    //ld1 {v0.8h}, [x19], #16  // alpha
+    ld1 {v1.8h}, [x20], #16  // zero
+    ld1 {v2.8h}, [x21], #16  // bias
+    ld1 {v3.s}[0], [x22]  // sums
+    // alpha * sum + (zero * sumx) + bias
+    Dequant v11, v1, v2, v3, 0
+    Dequant v12, v1, v2, v3, 1
+    st1 {v11.8h, v12.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_2
+Tile2End:
+    sub x6, x6, #2      // batch -= 2
+    add x0, x0, #32     // dst += 2 * 8 * sizeof(float16_t)
+    add x1, x1, #16     // src += 2 * 8 * sizeof(int8_t)
+    add x11, x11, #4    // sum += 2 * sizeof(float16_t)
+    add x12, x12, #4    // scale += 2 * sizeof(float16_t)
+    b TILE_2
+
+
+TILE_1:
+    cmp x6, #1
+    blt End
+    mov x14, x4       // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_1:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    ld1 {v29.8h}, [x20], #16  // zero
+    ld1 {v30.8h}, [x21], #16  // bias
+    ld1 {v8.h}[0], [x22]  // sums
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    fmla v30.8h, v29.8h, v8.h[0] // bias + zero * sum
+
+LoopSz_TILE_1:
+    // src    : 1 x [1 x 8] : v4
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 1 x 4 x [2] : v16-v19
+    ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v4.8b}, [x24], x15   // src
+    .inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b
+    .inst 0x4e84a431 // smmla v17.4s, v1.16b, v4.16b
+    .inst 0x4e84a452 // smmla v18.4s, v2.16b, v4.16b
+    .inst 0x4e84a473 // smmla v19.4s, v3.16b, v4.16b
+
+    subs x26, x26, #1
+    bne LoopSz_TILE_1
+
+LoopSzEnd_TILE_1:
+    add x18, x18, x13
+    sub x16, x16, #1
+    uzp1 v22.4s, v16.4s, v17.4s
+    uzp1 v23.4s, v18.4s, v19.4s
+    scvtf v22.4s, v22.4s
+    scvtf v23.4s, v23.4s
+    // using float scale dequant for precision
+    ld1 {v4.h}[0], [x23]  // scales
+    ld1 {v0.8h}, [x19], #16  // alpha
+    fcvtl v5.4s, v4.4h
+    fcvtl v20.4s, v0.4h
+    fcvtl2 v21.4s, v0.8h
+
+    fmul v22.4s, v22.4s, v5.s[0]
+    fmul v23.4s, v23.4s, v5.s[0]
+    fmul v22.4s, v22.4s, v20.4s
+    fmul v23.4s, v23.4s, v21.4s
+    fcvtn v17.4h,  v22.4s
+    fcvtn2 v17.8h, v23.4s
+Tile1Dequant:
+    // sum + (zero * sumx) + bias
+    fadd v30.8h, v30.8h, v17.8h
+    st1 {v30.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_1
+Tile1End:
+    sub x6, x6, #1      // batch -= 1
+    add x0, x0, #16     // dst += 1 * 8 * sizeof(float16_t)
+    add x1, x1, #8      // src += 1 * 8 * sizeof(int8_t)
+    add x11, x11, #2   // sum += 1 * sizeof(float16_t)
+    add x12, x12, #2   // scale += 1 * sizeof(float16_t)
+    b TILE_1
+b End
+TILE_EQ_1:
+
+    mov x14, x4       // dst_step
+    lsr x15, x4, #1   // src_step = dst_step / 2
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    ld1 {v29.8h}, [x20], #16  // zero
+    ld1 {v30.8h}, [x21], #16  // bias
+    ld1 {v8.h}[0], [x22]  // sums
+    // init
+    dup v14.4s, wzr
+    dup v15.4s, wzr
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    fmla v30.8h, v29.8h, v8.h[0] // bias + zero * sum
+
+
+L2:
+cmp x26, #2
+blt L1
+LoopSz_2:
+    ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x25], #64
+    ld1 {v8.16b}, [x24], #16  // src
+    sub x26, x26, #2
+
+    .inst 0x4e80a50e // smmla v14.4s, v8.16b, v0.16b // (N=0,OC=0) (N=0,OC=1) () ()
+    .inst 0x4e81a50f // smmla v15.4s, v8.16b, v1.16b // (N=0,OC=2) (N=0,OC=3) () ()
+    .inst 0x4e82a510 // smmla v16.4s, v8.16b, v2.16b // (N=0,OC=4) (N=0,OC=5) () ()
+    .inst 0x4e83a511 // smmla v17.4s, v8.16b, v3.16b // (N=0,OC=6) (N=0,OC=7) () ()
+    .inst 0x4e84a512 // smmla v18.4s, v8.16b, v4.16b
+    .inst 0x4e85a513 // smmla v19.4s, v8.16b, v5.16b
+    .inst 0x4e86a514 // smmla v20.4s, v8.16b, v6.16b
+    .inst 0x4e87a515 // smmla v21.4s, v8.16b, v7.16b
+    cmp x26, #2
+    bge LoopSz_2
+L1:
+cmp x26, #1
+blt LoopSzEnd
+LoopSz_1:
+    // src    : 1 x [1 x 8] : v4
+    // weight : 4 x [2 x 8] : v0-3
+    ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v4.8b}, [x24], x15   // src
+    .inst 0x4e80a48e // smmla v14.4s, v4.16b, v0.16b
+    .inst 0x4e81a48f // smmla v15.4s, v4.16b, v1.16b
+    .inst 0x4e82a490 // smmla v16.4s, v4.16b, v2.16b
+    .inst 0x4e83a491 // smmla v17.4s, v4.16b, v3.16b
+
+    subs x26, x26, #1
+    bne LoopSz_1
+
+LoopSzEnd:
+    add x18, x18, x13
+    sub x16, x16, #1
+
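+    // v14-v17 hold the even k-block products and v18-v21 the odd ones from the 2x
+    // unrolled loop; trn1/trn2 select the matching 2x2 rows so the adds below merge
+    // both halves into a single int32 sum per output channel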
+    trn1 v26.2d, v14.2d, v15.2d
+    trn1 v27.2d, v16.2d, v17.2d
+    trn2 v28.2d, v18.2d, v19.2d
+    trn2 v29.2d, v20.2d, v21.2d
+    add v26.4s, v26.4s, v28.4s
+    add v27.4s, v27.4s, v29.4s
+    scvtf v26.4s, v26.4s
+    scvtf v27.4s, v27.4s
+    // using float scale dequant for precision
+    ld1 {v4.h}[0], [x23]  // scales
+    ld1 {v0.8h}, [x19], #16  // alpha
+    fcvtl v5.4s, v4.4h
+    fcvtl v20.4s, v0.4h
+    fcvtl2 v21.4s, v0.8h
+
+    fmul v26.4s, v26.4s, v5.s[0]
+    fmul v27.4s, v27.4s, v5.s[0]
+    fmul v26.4s, v26.4s, v20.4s
+    fmul v27.4s, v27.4s, v21.4s
+    fcvtn v17.4h,  v26.4s
+    fcvtn2 v17.8h, v27.4s
+Int8ToFP16:
+    // sum + (zero * sumx) + bias
+    fadd v30.8h, v30.8h, v17.8h
+    st1 {v30.8h}, [x17], x14
+    cmp x16, #1
+    bge LoopDz
+
+End:
+ldp x27, x28, [sp, #(16 * 8)]
+ldp x25, x26, [sp, #(16 * 7)]
+ldp x23, x24, [sp, #(16 * 6)]
+ldp x19, x20, [sp, #(16 * 5)]
+ldp x21, x22, [sp, #(16 * 4)]
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 9)
+ret
+
+#endif

+ 189 - 0
source/backend/arm82/asm/arm64/low_memory/MNNQuantScaleFP16.S

@@ -0,0 +1,189 @@
+//
+//  MNNQuantScaleFP16.S
+//  MNN
+//
+//  Created by MNN on 2023/11/01.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+.text
+.align 5
+
+.macro Round z0, z1, z2, z3
+    fcvtas \z0\().8h, \z0\().8h
+    fcvtas \z1\().8h, \z1\().8h
+    fcvtas \z2\().8h, \z2\().8h
+    fcvtas \z3\().8h, \z3\().8h
+.endm
+
+//void MNNQuantScaleFP16(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch)
+asm_function MNNQuantScaleFP16
+
+// x0:absmax, x1:quant_scale, x2:dequant_scale, x3:thread, x4:batch
+stp d14, d15, [sp, #(-16 * 4)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+
+Start:
+mov w8, #1123942400  // 0x42FE0000 = 127.0f in IEEE-754 bits
+dup v0.4s, w8
+fcvtn v31.4h, v0.4s
+fcvtn2 v31.8h, v0.4s
+lsl x9, x4, #1 // src_step = batch * sizeof(float16_t)
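+// absmax is laid out as a [thread x batch] fp16 matrix; each tile below reduces it
+// over the thread axis with fmax, then stores quant_scale = 127 / absmax and
+// dequant_scale = absmax / 127 for every batch element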
+
+TILE_12:
+cmp x4, #12
+blt TILE_10
+sub x10, x9, #16
+mov x6, x0  // max_ptr
+mov x7, x3  // thread
+
+// absmax: v0, v1
+ld1 {v0.8h}, [x6], #16
+ld1 {v1.d}[0], [x6], x10
+subs x7, x7, #1
+beq Tile12End
+
+LoopSz_12:
+ld1 {v4.8h}, [x6], #16
+ld1 {v5.d}[0], [x6], x10
+
+// absmax = fmax(absmax, absmax[i])
+fmax v0.8h, v0.8h, v4.8h
+fmax v1.8h, v1.8h, v5.8h
+
+subs x7, x7, #1
+bne LoopSz_12
+
+Tile12End:
+sub x4, x4, #12
+add x0, x0, #24
+// quant_scale = 127 / absmax
+// dequant_scale = absmax / 127
+fdiv v4.8h, v31.8h, v0.8h
+fdiv v5.8h, v31.8h, v1.8h
+fdiv v6.8h, v0.8h, v31.8h
+fdiv v7.8h, v1.8h, v31.8h
+st1 {v4.8h}, [x1], #16
+st1 {v5.d}[0], [x1], #8
+st1 {v6.8h}, [x2], #16
+st1 {v7.d}[0], [x2], #8
+b TILE_12
+
+TILE_10:
+cmp x4, #10
+blt TILE_8
+sub x10, x9, #16
+mov x6, x0  // max_ptr
+mov x7, x3  // thread
+
+// absmax: v0, v1
+ld1 {v0.8h}, [x6], #16
+ld1 {v1.s}[0], [x6], x10
+subs x7, x7, #1
+beq Tile10End
+
+LoopSz_10:
+ld1 {v4.8h}, [x6], #16
+ld1 {v5.s}[0], [x6], x10
+
+// absmax = fmax(absmax, absmax[i])
+fmax v0.8h, v0.8h, v4.8h
+fmax v1.8h, v1.8h, v5.8h
+
+subs x7, x7, #1
+bne LoopSz_10
+
+Tile10End:
+sub x4, x4, #10
+add x0, x0, #20
+// quant_scale = 127 / absmax
+// dequant_scale = absmax / 127
+fdiv v4.8h, v31.8h, v0.8h
+fdiv v5.8h, v31.8h, v1.8h
+fdiv v6.8h, v0.8h, v31.8h
+fdiv v7.8h, v1.8h, v31.8h
+st1 {v4.8h}, [x1], #16
+st1 {v5.s}[0], [x1], #4
+st1 {v6.8h}, [x2], #16
+st1 {v7.s}[0], [x2], #4
+b TILE_10
+
+
+TILE_8:
+cmp x4, #8
+blt TILE_1
+mov x6, x0  // max_ptr
+mov x7, x3  // thread
+
+// absmax: v0
+ld1 {v0.8h}, [x6], x9
+subs x7, x7, #1
+beq Tile8End
+
+LoopSz_8:
+ld1 {v2.8h}, [x6], x9
+
+// absmax = fmax(absmax, absmax[i])
+fmax v0.8h, v0.8h, v2.8h
+
+subs x7, x7, #1
+bne LoopSz_8
+
+Tile8End:
+sub x4, x4, #8
+add x0, x0, #16
+// quant_scale = 127 / absmax
+// dequant_scale = absmax / 127
+fdiv v2.8h, v31.8h, v0.8h
+fdiv v3.8h, v0.8h, v31.8h
+st1 {v2.8h}, [x1], #16
+st1 {v3.8h}, [x2], #16
+b TILE_8
+
+
+TILE_1:
+cmp x4, #1
+blt End
+mov x6, x0  // absmax
+mov x7, x3  // thread
+
+// absmax: v0
+ld1 {v0.h}[0], [x6], x9
+subs x7, x7, #1
+beq Tile1End
+
+LoopSz_1:
+ld1 {v2.h}[0], [x6], x9
+
+// absmax = fmax(absmax, absmax[i])
+fmax h0, h0, h2
+
+subs x7, x7, #1
+bne LoopSz_1
+
+Tile1End:
+sub x4, x4, #1
+add x0, x0, #2
+// quant_scale = 127 / absmax
+// dequant_scale = absmax / 127
+fdiv h2, h31, h0
+fdiv h3, h0, h31
+st1 {v2.h}[0], [x1], #2
+st1 {v3.h}[0], [x2], #2
+b TILE_1
+
+
+End:
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 4)
+ret
+
+#endif
+

+ 106 - 0
source/backend/arm82/asm/arm64/low_memory/MNNQuantSumFP16.S

@@ -0,0 +1,106 @@
+//
+//  MNNQuantSumFP16.S
+//  MNN
+//
+//  Created by MNN on 2023/11/30.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+.text
+.align 5
+
+//void MNNQuantSumFP16(float* sum, const float* dequant_scale, size_t thread, size_t batch)
+asm_function MNNQuantSumFP16
+
+// x0: sum, x1:dequant_scale, x2:thread, x3:batch
+stp d14, d15, [sp, #(-16 * 4)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+
+Start:
+lsl x9, x3, #2 // src_step = batch * sizeof(int32_t)
+mov x10, #0
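+// the int32 sums are laid out as [thread x batch]; each tile reduces them over the
+// thread axis, scales by dequant_scale and writes the result back in place as fp16.
+// x10 holds the extra read offset caused by the fp16 results being half as wide as
+// the int32 inputs they overwrite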
+
+
+TILE_4:
+cmp x3, #4
+blt TILE_1
+add x6, x0, x10  // sum_ptr
+mov x7, x2  // thread
+
+// sum: v0
+ld1 {v0.4s}, [x6], x9
+subs x7, x7, #1
+beq Tile4End
+
+LoopSz_4:
+ld1 {v1.4s}, [x6], x9
+
+// sum += sum[i]
+add v0.4s, v0.4s, v1.4s
+
+subs x7, x7, #1
+bne LoopSz_4
+
+Tile4End:
+sub x3, x3, #4
+// load dequant_scale
+ld1 {v1.4h}, [x1], #8
+fcvtl v2.4s, v1.4h
+// sum_half = (half)((float)sum_int * dequant_scale)
+scvtf v3.4s, v0.4s
+fmul v4.4s, v3.4s, v2.4s
+fcvtn v5.4h, v4.4s
+st1 {v5.d}[0], [x0], #8
+add x10, x10, #8
+b TILE_4
+
+// x0: sum, x1:dequant_scale, x2:thread, x3:batch
+TILE_1:
+cmp x3, #1
+blt End
+add x6, x0, x10  // sum_ptr
+mov x7, x2  // thread
+
+// sum: v0
+ld1 {v0.s}[0], [x6], x9
+subs x7, x7, #1
+beq Tile1End
+
+LoopSz_1:
+ld1 {v1.s}[0], [x6], x9
+
+// sum += sum[i]
+// add s0, s0, s1
+add v0.4s, v0.4s, v1.4s
+
+subs x7, x7, #1
+bne LoopSz_1
+
+Tile1End:
+sub x3, x3, #1
+// load dequant_scale
+ld1 {v1.h}[0], [x1], #2
+fcvtl v2.4s, v1.4h
+// sum_half = (half)((float)sum_int * dequant_scale)
+scvtf s3, s0
+fmul s4, s3, s2
+fcvtn v5.4h, v4.4s
+st1 {v5.h}[0], [x0], #2
+add x10, x10, #2
+b TILE_1
+
+
+End:
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 4)
+ret
+
+#endif
+

+ 47 - 6
source/backend/coreml/execution/CoreMLConvolution.cpp

@@ -65,11 +65,41 @@ void CoreMLConvolution::addPadLayer(const Tensor * input, const Convolution2DCom
     if (top == 0 && left == 0 && bottom == 0 && right == 0) {
         return;
     }
-    
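+    // When the output size matches exact stride up/down-sampling, the layer can use
+    // CoreML SAME padding instead of an explicit pad layer; otherwise, for deconv,
+    // derive explicit pad amounts from the kernel/stride geometry below.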
+    if (isDeconv && outputWidth == inputWidth * common->strideX() && outputHeight == inputHeight * common->strideY()) {
+        isSamePadding = true;
+        return;
+    }
+    if (!isDeconv && outputWidth == UP_DIV(inputWidth, common->strideX()) && outputHeight == UP_DIV(inputHeight, common->strideY())) {
+        isSamePadding = true;
+        return;
+    }
+    if (isDeconv) {
+        int ky = common->kernelY();
+        int kx = common->kernelX();
+        int sy = common->strideY();
+        int sx = common->strideX();
+        int pad_out_height = (outputHeight - ky) / sy + 1;
+        int pad_out_width = (outputWidth - kx) / sx + 1;
+        top = (pad_out_height - inputHeight) / 2;
+        bottom = (pad_out_height - inputHeight) - top;
+        left = (pad_out_width - inputWidth) / 2;
+        right = (pad_out_width - inputWidth) - left;
+        
+        if (top < 0 || bottom < 0 || left < 0 || right < 0) {
+            isSamePadding = true;
+            pad_out_width = outputWidth / sx;
+            pad_out_height = outputHeight / sy;
+            bottom = 0;
+            top = pad_out_height - inputHeight;
+            right = 0;
+            left = pad_out_width - inputWidth;
+        }
+    }
+    std::string layerName = "ConvPadding-" + mConvInputName;
     auto paddingLayer = mCoreMLBackend->create<CoreML__Specification__NeuralNetworkLayer>();
     core_ml__specification__neural_network_layer__init(paddingLayer);
     paddingLayer->layer_case = CORE_ML__SPECIFICATION__NEURAL_NETWORK_LAYER__LAYER_PADDING;
-    mCoreMLBackend->setLayerName(paddingLayer, "ConvPadding");
+    mCoreMLBackend->setLayerName(paddingLayer, layerName.c_str());
     paddingLayer->padding = mCoreMLBackend->create<CoreML__Specification__PaddingLayerParams>();
     core_ml__specification__padding_layer_params__init(paddingLayer->padding);
     paddingLayer->padding->padding_type_case = CORE_ML__SPECIFICATION__PADDING_LAYER_PARAMS__PADDING_TYPE_CONSTANT;
@@ -97,6 +127,10 @@ void CoreMLConvolution::addPadLayer(const Tensor * input, const Convolution2DCom
 ErrorCode CoreMLConvolution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
     mConvInputName = mCoreMLBackend->getTensorName(inputs[0]);
     mConvOutputName = mCoreMLBackend->getTensorName(outputs[0]);
+    inputWidth = inputs[0]->width();
+    inputHeight = inputs[0]->height();
+    outputWidth = outputs[0]->width();
+    outputHeight = outputs[0]->height();
     loadWeightBias(inputs);
     auto conv2D      = mOp->main_as_Convolution2D();
     auto common      = conv2D->common();
@@ -135,10 +169,17 @@ ErrorCode CoreMLConvolution::onResize(const std::vector<Tensor *> &inputs, const
             break;
         case PadMode_CAFFE:
             addPadLayer(inputs[0], common);
-            mLayer_->convolution->convolution_padding_type_case = CORE_ML__SPECIFICATION__CONVOLUTION_LAYER_PARAMS__CONVOLUTION_PADDING_TYPE_VALID;
-            mLayer_->convolution->valid = mCoreMLBackend->create<CoreML__Specification__ValidPadding>();
-            core_ml__specification__valid_padding__init(mLayer_->convolution->valid);
-            break;
+            if (isSamePadding){
+                mLayer_->convolution->convolution_padding_type_case = CORE_ML__SPECIFICATION__CONVOLUTION_LAYER_PARAMS__CONVOLUTION_PADDING_TYPE_SAME;
+                mLayer_->convolution->same = mCoreMLBackend->create<CoreML__Specification__SamePadding>();
+                core_ml__specification__same_padding__init(mLayer_->convolution->same);
+                break;
+            } else {
+                mLayer_->convolution->convolution_padding_type_case = CORE_ML__SPECIFICATION__CONVOLUTION_LAYER_PARAMS__CONVOLUTION_PADDING_TYPE_VALID;
+                mLayer_->convolution->valid = mCoreMLBackend->create<CoreML__Specification__ValidPadding>();
+                core_ml__specification__valid_padding__init(mLayer_->convolution->valid);
+                break;
+            }
         default:
             break;
     }

+ 2 - 0
source/backend/coreml/execution/CoreMLConvolution.hpp

@@ -28,6 +28,8 @@ private:
     const float *weightPtr, *biasPtr;
     int weightSize, biasSize;
     bool isDeconv = false;
+    bool isSamePadding = false;
+    int outputHeight, outputWidth, inputHeight, inputWidth;
 };
 } // namespace MNN
 

+ 4 - 0
source/backend/cpu/CMakeLists.txt

@@ -2,7 +2,11 @@
 option(MNN_SUPPORT_BF16 "Enable MNN's bf16 op" OFF)
 option(MNN_LOW_MEMORY "Build MNN support low memory for weight quant model." OFF)
 
+if(MNN_SUPPORT_RENDER)
+FILE(GLOB MNN_CPU_SRC ${CMAKE_CURRENT_LIST_DIR}/* ${CMAKE_CURRENT_LIST_DIR}/compute/* ${CMAKE_CURRENT_LIST_DIR}/render/*)
+else()
 FILE(GLOB MNN_CPU_SRC ${CMAKE_CURRENT_LIST_DIR}/* ${CMAKE_CURRENT_LIST_DIR}/compute/*)
+endif()
 add_library(MNNCPU OBJECT ${MNN_CPU_SRC})
 if (MNN_SUPPORT_BF16)
     include(${CMAKE_CURRENT_LIST_DIR}/bf16/CMakeLists.txt)

+ 21 - 13
source/backend/cpu/CPUBackend.cpp

@@ -337,6 +337,26 @@ static OpType _getRealOpType(OpType opType) {
             return opType;
     }
 }
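+// Map a CPU tensor by handing out its host pointer directly when no byte-width or
+// dimension-format conversion is required; returning nullptr makes the caller fall
+// back to its copy path.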
+void* CPUBackend::onMapTensor(Tensor::MapType mtype, Tensor::DimensionType dtype, const Tensor* srcTensor) {
+    if (getBytes(this, srcTensor) != srcTensor->getType().bytes()) {
+        return nullptr;
+    }
+    if (OpCommonUtils::convertDimType(TensorUtils::getDescribe(srcTensor)->dimensionFormat) != dtype) {
+        return nullptr;
+    }
+    return srcTensor->host<void>();
+}
+
+bool CPUBackend::onUnmapTensor(Tensor::MapType mtype, Tensor::DimensionType dtype, const Tensor* dstTensor, void* mapPtr) {
+    if (getBytes(this, dstTensor) != dstTensor->getType().bytes()) {
+        return false;
+    }
+    if (OpCommonUtils::convertDimType(TensorUtils::getDescribe(dstTensor)->dimensionFormat) != dtype) {
+        return false;
+    }
+    return true;
+}
+
 size_t CPUBackend::getTensorSize(const Tensor* tensor, bool multiBytes) const {
     auto core = mCoreFunctions;
     size_t dataSize = 1;
@@ -448,19 +468,7 @@ void CPUBackend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor)
     }
     std::unique_ptr<Tensor> wrapTensor;
     if (getDataType(srcTensor) != getDataType(dstTensor)) {
-        auto dimType = Tensor::CAFFE;
-        switch (TensorUtils::getDescribe(srcTensor)->dimensionFormat) {
-            case MNN_DATA_FORMAT_NCHW:
-                break;
-            case MNN_DATA_FORMAT_NC4HW4:
-                dimType = Tensor::CAFFE_C4;
-                break;
-            case MNN_DATA_FORMAT_NHWC:
-                dimType = Tensor::TENSORFLOW;
-                break;
-            default:
-                break;
-        }
+        auto dimType = OpCommonUtils::convertDimType(TensorUtils::getDescribe(srcTensor)->dimensionFormat);
         auto convertType = CPUCastCreator::FlOAT_TO_INT8;
         if (getDataType(srcTensor) == DataType_DT_INT8) {
             convertType = CPUCastCreator::INT8_TO_FlOAT;

+ 9 - 1
source/backend/cpu/CPUBackend.hpp

@@ -33,7 +33,6 @@ public:
     void onConcurrencyEnd() const;
     virtual bool onCheckInfo(Backend::Info& info) const override;
 
-
 private:
     std::shared_ptr<EagerBufferAllocator> mStaticAllocator;
     int mThreadNumber;
@@ -89,6 +88,9 @@ public:
 
     virtual void onExecuteBegin() const override;
     virtual void onExecuteEnd() const override;
+    virtual void* onMapTensor(Tensor::MapType mtype, Tensor::DimensionType dtype, const Tensor* srcTensor) override;
+
+    virtual bool onUnmapTensor(Tensor::MapType mtype, Tensor::DimensionType dtype, const Tensor* dstTensor, void* mapPtr) override;
     
     virtual void onResizeBegin() override;
     virtual ErrorCode onResizeEnd() override;
@@ -181,6 +183,12 @@ private:
     }
 #endif
 
+#define REGISTER_CPU_OP_CREATOR_RENDER(name, opType)     \
+    void ___##name##__##opType##__() {            \
+        static name _temp;\
+        CPUBackend::addCreator(opType, &_temp); \
+    }
+
 } // namespace MNN
 
 #endif /* CPUBackend_hpp */

+ 1 - 0
source/backend/cpu/CPUBinaryInt8.cpp

@@ -85,6 +85,7 @@ ErrorCode CPUBinaryInt8::onExecute(const std::vector<Tensor*>& inputs, const std
         params.inputZeroPoint = mInputZeros.data();
         params.minValue = (ssize_t)TensorUtils::getDescribe(outputs[0])->quantAttr->min;
         params.maxValue = (ssize_t)TensorUtils::getDescribe(outputs[0])->quantAttr->max;
+
         int start = schedule.first * (int)tId;
         int realSize = schedule.first;
         if (tId == schedule.second -1 ) {

+ 4 - 1
source/backend/cpu/CPUCast.cpp

@@ -189,7 +189,7 @@ Execution *CPUCastCreator::onCreate(const std::vector<Tensor *> &inputs, const s
     if (dstT == MNN::DataType_DT_FLOAT && halide_type_of<int8_t>() == inputDataType) {
         return new CastDataType<int8_t, float>(backend);
     }
-    if (dstT == MNN::DataType_DT_FLOAT && halide_type_t(halide_type_float, 16) == inputDataType) {
+    if (dstT == MNN::DataType_DT_FLOAT && halide_type_t(halide_type_bfloat, 16) == inputDataType) {
         return new BF16ToFP32(backend);
     }
     if (dstT == MNN::DataType_DT_INT8 && halide_type_of<float>() == inputDataType) {
@@ -201,6 +201,9 @@ Execution *CPUCastCreator::onCreate(const std::vector<Tensor *> &inputs, const s
     if (dstT == MNN::DataType_DT_UINT8 && halide_type_of<int32_t>() == inputDataType) {
         return new CastDataType<int32_t, uint8_t>(backend);
     }
+    if (dstT == MNN::DataType_DT_UINT8 && halide_type_of<int8_t>() == inputDataType) {
+        return new CastDataType<int8_t, uint8_t>(backend);
+    }
     if (dstT == MNN::DataType_DT_INT32 && halide_type_of<uint8_t>() == inputDataType) {
         return new CastDataType<uint8_t, int32_t>(backend);
     }

+ 75 - 22
source/backend/cpu/CPUDeconvolution.cpp

@@ -39,12 +39,16 @@ ErrorCode CPUDeconvolutionBasic::onResize(const std::vector<Tensor*>& inputs, co
     return NO_ERROR;
 }
 
-CPUDeconvolutionCommon::CPUDeconvolutionCommon(const Tensor* input, const Op* convOp, Backend* b)
+CPUDeconvolutionCommon::CPUDeconvolutionCommon(const Tensor* input, const Op* convOp, Backend* b, bool dynamicWeight)
     : CPUDeconvolutionBasic(input, convOp, b) {
     auto conv2D     = convOp->main_as_Convolution2D();
     int outputCount = mCommon->outputCount();
     auto core = static_cast<CPUBackend*>(b)->functions();
+    mDynamicWeight = dynamicWeight;
     mBias.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputCount, core->pack) * core->pack}));
+    if (dynamicWeight) {
+        return;
+    }
     bool success = b->onAcquireBuffer(mBias.get(), Backend::STATIC);
     if (!success) {
         mValid = false;
@@ -78,7 +82,7 @@ CPUDeconvolutionCommon::CPUDeconvolutionCommon(const Tensor* input, const Op* co
 }
 
 CPUDeconvolutionCommon::~CPUDeconvolutionCommon() {
-    backend()->onReleaseBuffer(mBias.get(), Backend::STATIC);
+    // Do nothing
 }
 
 // Float Weight.
@@ -137,28 +141,45 @@ static void _reorderWeightInt8(Backend* bn, const Convolution2DCommon* common, c
         }
     }
 }
-CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backend* backend)
-    : MNN::CPUDeconvolutionCommon(input, convOp, backend) {
+CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backend* backend, bool dynamicWeight)
+    : MNN::CPUDeconvolutionCommon(input, convOp, backend, dynamicWeight) {
     auto core               = static_cast<CPUBackend*>(backend)->functions();
     auto coreInt8           = static_cast<CPUBackend*>(backend)->int8Functions();
     int eP, lP, hP;
     core->MNNGetMatMulPackMode(&eP, &lP, &hP);
     int UNIT, SRC_UNIT, DST_XUNIT;
     coreInt8->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
-    
+    bool ModeInt8        =  false;
+
     if (CPUBackend::getDataType(input) == DataType_DT_INT8 || input->getType().bytes() == 1) {
         eP = DST_XUNIT;
         lP = SRC_UNIT;
         hP = UNIT;
+        ModeInt8 = true;
     }
     auto conv2d                  = convOp->main_as_Convolution2D();
     auto layer                   = conv2d->common();
     int outputCount              = layer->outputCount();
     const auto outputChannleUp4  = UP_DIV(outputCount, hP) * hP;
+    int fw                  = layer->kernelX();
+    int fh                  = layer->kernelY();
+    int srcCount            = mSrcCount;
+    mParam.fh = fh;
+    mParam.fw = fw;
+    mParam.srcCount = srcCount;
+    mParam.outputCount = outputCount;
+    auto outputAlign = UP_DIV(layer->outputCount(), core->pack) * core->pack * fw * fh;
+    mWeight.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP}));
+    std::shared_ptr<Tensor> cache(Tensor::createDevice<float>({outputAlign * srcCount}));
+    if (dynamicWeight) {
+        mOrigin.reset(new CPUDeconvolutionOrigin(input, mWeight.get(), convOp, backend, ModeInt8));
+        mWeightTransformCache = cache;
+        return;
+    }
+
     const float* tempWeight      = nullptr;
     const int8_t* quanWeightInt8 = nullptr;
 
-    bool ModeInt8        =  false;
     int tempWeightSize   = 0;
     std::unique_ptr<Tensor> externalWeightTensor;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
@@ -180,22 +201,15 @@ CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backen
         OpCommonUtils::loadExternalData(backend, externalWeightTensor->host<char>(), conv2d->external()->Get(0), bytes);
         tempWeight = externalWeightTensor->host<float>();
     } else {
-        if (CPUBackend::getDataType(input) == DataType_DT_INT8 || input->getType().bytes() == 1) {
+        if (ModeInt8) {
             ConvolutionCommon::getConvInt8Parameters(conv2d, quanCommon, backend, quanWeightInt8, tempWeightSize, scalePtr, biasPtr);
-            ModeInt8 = true;
         } else {
             ConvolutionCommon::getConvParameters(&quanCommon, backend, conv2d, &tempWeight, &tempWeightSize);
         }
     }
-
-    int fw                  = layer->kernelX();
-    int fh                  = layer->kernelY();
-    int srcCount            = mSrcCount;
-    
-    auto outputAlign = UP_DIV(layer->outputCount(), core->pack) * core->pack * fw * fh;
     
-    std::shared_ptr<Tensor> cache(Tensor::createDevice<float>({outputAlign * srcCount}));
-    bool success =  backend->onAcquireBuffer(cache.get(), Backend::STATIC);
+    bool success = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC) &&
+                   backend->onAcquireBuffer(cache.get(), Backend::STATIC);
     if (!success) {
         mValid = false;
         return;
@@ -233,7 +247,45 @@ CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backen
 }
 
 CPUDeconvolution::~CPUDeconvolution() {
-    backend()->onReleaseBuffer(mWeight.get(), Backend::STATIC);
+    // Do nothing
+}
+ErrorCode CPUDeconvolution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
+    if (mDynamicWeight) {
+        auto core = static_cast<CPUBackend*>(backend())->functions();
+        _transformWeight(inputs[1]->host<uint8_t>(), mWeight->host<uint8_t>(), mParam.outputCount, mParam.srcCount, mParam.fh, mParam.fw, mWeightTransformCache->host<uint8_t>(), core);
+        ::memset(mBias->host<uint8_t>(), 0, mBias->length(0) * core->bytes);
+        if (inputs.size() >= 3) {
+            ::memcpy(mBias->host<uint8_t>(), inputs[2]->host<uint8_t>(), TensorUtils::getRawSize(inputs[2]) * core->bytes);
+        }
+    }
+    return mOrigin->onExecute(mTempInputs, outputs);
+}
+ErrorCode CPUDeconvolution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
+    if (mDynamicWeight) {
+        bool res = backend()->onAcquireBuffer(mWeight.get(), Backend::DYNAMIC);
+        if (!res) {
+            return OUT_OF_MEMORY;
+        }
+        res = backend()->onAcquireBuffer(mWeightTransformCache.get(), Backend::DYNAMIC);
+        if (!res) {
+            return OUT_OF_MEMORY;
+        }
+        res = backend()->onAcquireBuffer(mBias.get(), Backend::DYNAMIC);
+        if (!res) {
+            return OUT_OF_MEMORY;
+        }
+    }
+    mTempInputs = {inputs[0], mWeight.get(), mBias.get()};
+    auto code = mOrigin->onResize(mTempInputs, outputs);
+    if (NO_ERROR != code) {
+        return code;
+    }
+    if (mDynamicWeight) {
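+        // Note: releasing here follows the usual MNN dynamic-allocator pattern; the
+        // buffers remain valid while this op executes and only become reusable for
+        // later ops in the pipeline.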
+        backend()->onReleaseBuffer(mWeight.get(), Backend::DYNAMIC);
+        backend()->onReleaseBuffer(mWeightTransformCache.get(), Backend::DYNAMIC);
+        backend()->onReleaseBuffer(mBias.get(), Backend::DYNAMIC);
+    }
+    return NO_ERROR;
 }
 
 
@@ -274,8 +326,7 @@ ErrorCode CPUDeconvolutionOrigin::onResize(const std::vector<Tensor*>& inputs, c
     auto allocator = static_cast<CPUBackend*>(backend())->getBufferAllocator();
     //int zeroPoint = 0;
 
-    auto biasPtr      = inputs[2]->host<float>();
-    auto inputPtr  = input->host<float>();
+    auto biasTensor = inputs[2];
     
     // prepare for float2int8 if necessary.
     auto outputQuant = TensorUtils::getQuantInfo(outputs[0]);
@@ -323,9 +374,11 @@ ErrorCode CPUDeconvolutionOrigin::onResize(const std::vector<Tensor*>& inputs, c
     }
 
     mPostFunctions.emplace_back(std::make_pair([ocC4, width, height, kh, kw, padY, padX, dilateY, dilateX, strideY,
-                       strideX, threadNumber, src_width, src_height, plane, biasPtr, this, core, gcore, batch, outi8, scales,
+                       strideX, threadNumber, src_width, src_height, plane, input, biasTensor, this, core, gcore, batch, outi8, scales,
                        minValue, maxValue, zeroPoint, outputFp32Ptr](uint8_t* outputPtr, int tId) {
         auto colBufferPtr = mTempOutput->host<uint8_t>();
+        auto biasPtr      = biasTensor->host<float>();
+        auto inputPtr  = input->host<float>();
         auto unitBytes = core->pack * core->bytes;
         auto tempOutPtr = outputPtr;
         auto float2Int8_step = src_height * src_width * batch;
@@ -409,7 +462,7 @@ public:
                                 const MNN::Op* op, Backend* backend) const {
         auto convOp = op->main_as_Convolution2D();
         auto common = convOp->common();
-        if (backend->type() == MNN_FORWARD_CPU) {
+        if (backend->type() == MNN_FORWARD_CPU && inputs.size() == 1) {
             if (common->strideY() > 1 || common->strideX() > 1) {
                 if (common->dilateX() == 1 && common->dilateY() == 1) {
                     if (common->kernelX() / common->strideX() > 2 || common->kernelY() / common->strideY() > 2) {
@@ -418,7 +471,7 @@ public:
                 }
             }
         }
-        return new CPUDeconvolution(inputs[0], op, backend);
+        return new CPUDeconvolution(inputs[0], op, backend, inputs.size() > 1);
     }
 };
 

+ 13 - 10
source/backend/cpu/CPUDeconvolution.hpp

@@ -28,11 +28,12 @@ protected:
 
 class CPUDeconvolutionCommon : public CPUDeconvolutionBasic {
 public:
-    CPUDeconvolutionCommon(const Tensor *input, const Op *convOp, Backend *b);
+    CPUDeconvolutionCommon(const Tensor *input, const Op *convOp, Backend *b, bool dynamicWeight);
     virtual ~CPUDeconvolutionCommon();
 
 protected:
     std::shared_ptr<Tensor> mBias;
+    bool mDynamicWeight;
 };
 
 class CPUDeconvolutionOrigin : public CPUDeconvolutionBasic {
@@ -97,19 +98,21 @@ private:
 
 class CPUDeconvolution : public CPUDeconvolutionCommon {
 public:
-    CPUDeconvolution(const Tensor *input, const Op *convOp, Backend *b);
+    CPUDeconvolution(const Tensor *input, const Op *convOp, Backend *b, bool dynamicWeight);
     virtual ~CPUDeconvolution();
-    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override {
-        mOrigin->onExecute(mTempInputs, outputs);
-        return NO_ERROR;
-    }
-    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override {
-        mTempInputs = {inputs[0], mWeight.get(), mBias.get()};
-        return mOrigin->onResize(mTempInputs, outputs);
-    }
+    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
 
+    struct Param {
+        int outputCount;
+        int srcCount;
+        int fh;
+        int fw;
+    };
 private:
+    Param mParam;
     std::shared_ptr<Tensor> mWeight;
+    std::shared_ptr<Tensor> mWeightTransformCache;
     std::vector<Tensor *> mTempInputs;
     std::shared_ptr<CPUDeconvolutionOrigin> mOrigin;
 };

+ 1 - 1
source/backend/cpu/CPUDeconvolutionDepthwise.cpp

@@ -16,7 +16,7 @@
 
 namespace MNN {
 CPUDeconvolutionDepthwise::CPUDeconvolutionDepthwise(const Tensor* input, const Op* convOp, Backend* b)
-    : MNN::CPUDeconvolutionCommon(input, convOp, b) {
+    : MNN::CPUDeconvolutionCommon(input, convOp, b, false) {
     auto conv               = convOp->main_as_Convolution2D();
     auto layer              = convOp->main_as_Convolution2D()->common();
     int kw                  = layer->kernelX();

+ 87 - 0
source/backend/cpu/CPUDequantizeLinear.cpp

@@ -0,0 +1,87 @@
+//
+//  CPUDequantizeLinear.cpp
+//  MNN
+//
+//  Created by MNN on 2018/07/15.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+#include "backend/cpu/CPUBackend.hpp"
+#include "core/Concurrency.h"
+#include "backend/cpu/CPUDequantizeLinear.hpp"
+#include "core/TensorUtils.hpp"
+#include "compute/CommonOptFunction.h"
+
+namespace MNN {
+
+CPUDequantizeLinear::CPUDequantizeLinear(Backend *b, float* scale, int8_t* zeroPoints, int size, int axis, int inputBits) : MNN::Execution(b){
+    mSize = size;
+    mAxis = axis;
+    mInputBits = inputBits;
+}
+ErrorCode CPUDequantizeLinear::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
+    if (mInputBits == 8) {
+        mFunc = dequantizeFunc<int8_t>;
+    } else if (mInputBits == 16) {
+        mFunc = dequantizeFunc<int16_t>;
+    } else {
+        mFunc = dequantizeFunc<int32_t>;
+    }
+    float *scale = inputs[1]->host<float>();
+    int8_t *zero = nullptr;
+    if (inputs.size() > 2) {
+        zero = inputs[2]->host<int8_t>();
+    }
+    if (mSize == 1) {
+        mQuantScales.resize(4, *scale);
+        if (nullptr != zero) {
+            mQuantZeroPoints.resize(4, *zero);
+        } else {
+            mQuantZeroPoints.resize(4, 0);
+        }
+    } else {
+        mQuantScales.resize(mSize);
+        ::memcpy(mQuantScales.data(), scale, sizeof(float) * mSize);
+        if (nullptr != zero) {
+            mQuantZeroPoints.resize(mSize);
+            ::memcpy(mQuantZeroPoints.data(), zero, mSize);
+        } else {
+            mQuantZeroPoints.resize(mSize);
+        }
+    }
+    return NO_ERROR;
+}
+ErrorCode CPUDequantizeLinear::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs)  {
+    auto input = inputs[0];
+    int N = input->length(0);
+    ssize_t size = N;
+    auto core = static_cast<CPUBackend*>(backend())->int8Functions();
+    int UNIT, SRC_UNIT, DST_XUNIT;
+    core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
+    auto dst = outputs[0]->host<float>();
+    auto src = input->host<int8_t>();
+    mFunc(dst, src, input->dimensions(), input->size(), mSize, UNIT, mQuantScales.data(), mQuantZeroPoints.data(), core);
+    return NO_ERROR;
+}
+
+class CPUDequantizeLinearCreator : public CPUBackend::Creator {
+public:
+    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
+                                const MNN::Op *op, Backend *backend) const override {
+        auto dataType = inputs[0]->getType();
+        if (dataType.bits != 8 && dataType.bits != 16 && dataType.bits != 32) {
+            MNN_ERROR("Input of Dequantize must be int8/uint8/fp16/int32\n");
+            return nullptr;
+        }
+        int inputBits = dataType.bits;
+        int size = op->main_as_DequantizeLinear()->scaleSize();
+        int axis = op->main_as_DequantizeLinear()->scaleAxis();
+        if (inputs.size() > 2) {
+            return new CPUDequantizeLinear(backend, inputs[1]->host<float>(), inputs[2]->host<int8_t>(), size, axis, inputBits);
+        }
+        return new CPUDequantizeLinear(backend, inputs[1]->host<float>(), nullptr, size, axis, inputBits);
+    }
+};
+
+REGISTER_CPU_OP_CREATOR(CPUDequantizeLinearCreator, OpType_DequantizeLinear);
+
+} // namespace MNN

+ 81 - 0
source/backend/cpu/CPUDequantizeLinear.hpp

@@ -0,0 +1,81 @@
+//
+//  CPUDequantizeLinear.hpp
+//  MNN
+//
+//  Created by MNN on 2018/07/15.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifndef CPUDequantizeLinear_hpp
+#define CPUDequantizeLinear_hpp
+
+#include "core/AutoStorage.h"
+#include "core/Execution.hpp"
+#include "compute/Int8FunctionsOpt.h"
+
+namespace MNN {
+typedef void(*dequantFunc)(float* dst, const int8_t* source, int inputDim, int inputSize, int size, int UNIT, float* scales, int8_t* zeros, const CoreInt8Functions* core);
+class CPUDequantizeLinear : public Execution {
+public:
+    CPUDequantizeLinear(Backend *b, float* scales, int8_t* zeroPoints, int size = 1, int axis = 0, int inputBits = 8);
+    virtual ~CPUDequantizeLinear() = default;
+    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+private:
+    std::vector<float> mQuantScales;
+    std::vector<int8_t> mQuantZeroPoints;
+    int mSize = 1;
+    int mAxis = 0;
+    int mInputBits = 8;
+    dequantFunc mFunc;
+};
+
+template<typename T>
+void dequantizeFunc(float* dst, const int8_t* source, int inputDim, int inputSize, int size, int UNIT, float* scales, int8_t* zeros, const CoreInt8Functions* core) {
+#ifdef MNN_USE_SSE
+    auto src = (uint8_t*)source;
+    int offset = 128;
+#else
+    auto src = (int8_t*)source;
+    int offset = 0;
+#endif
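+    // with MNN_USE_SSE the int8 buffer is stored as uint8 shifted by +128, so the
+    // offset below removes that shift during dequantization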
+//    auto src = (T*)source;
+    if (inputDim == 1) {
+        for (int i = 0; i < size; ++i) {
+            dst[i] = static_cast<float>(src[i] - zeros[i] - offset) * scales[i];
+        }
+        return;
+    }
+    int chw = 1;
+    if (inputDim > 1) {
+        chw = inputSize / (size * sizeof(T));
+    }
+
+    if (size == 1) {
+        if (sizeof(T) == 1) {
+            core->MNNInt8ScaleToFloat(dst, (int8_t*)src, scales, chw / UNIT, zeros[0]);
+            int sizeDiv = (int)chw / UNIT;
+            for (int k = sizeDiv * UNIT; k < chw; ++k) {
+                dst[k] = static_cast<float>(src[k] - zeros[0] - offset) * scales[0];
+            }
+        } else {
+            for (int k = 0; k < chw; ++k) {
+                dst[k] = static_cast<float>(src[k] - zeros[0] - offset) * scales[0];
+            }
+        }
+        
+    } else {
+        for (int i = 0; i < size; ++i) {
+            std::vector<float> tmp(4, scales[i]);
+            //core->MNNInt8ScaleToFloat(dst, src, tmp.data(), sizeDiv, mQuantZeroPoints[i]);
+            for (int k = 0; k < chw; ++k) {
+                dst[k] = static_cast<float>(src[k] - zeros[i] - offset) * scales[i];
+            }
+            src += chw;
+            dst += chw;
+        }
+    }
+}
+} // namespace MNN
+
+#endif /* CPUDequantizeLinear_hpp */

+ 77 - 0
source/backend/cpu/CPUGridSample.cpp

@@ -128,6 +128,80 @@ ErrorCode CPUGridSample::onExecute(const std::vector<Tensor *> &inputs, const st
     return NO_ERROR;
 }
 
+class CPUGridSampleGrad : public CPUGridSample {
+public:
+    CPUGridSampleGrad(Backend *b, SampleMode mode, BorderMode paddingMode, bool alignCorners) : CPUGridSample(b, mode, paddingMode, alignCorners) {
+        // Do nothing
+    }
+
+    virtual ~CPUGridSampleGrad() = default;
+    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override {
+        int numberThread = static_cast<CPUBackend*>(backend())->threadNumber();
+        auto core = static_cast<CPUBackend*>(backend())->functions();
+        auto outputTensor = inputs[0];
+        int outD, outH, outW;
+        if (outputTensor->dimensions() == 4) {
+            outH = outputTensor->buffer().dim[2].extent;
+            outW = outputTensor->buffer().dim[3].extent;
+            mTempCordBuffer.reset(Tensor::createDevice<uint8_t>({1, outH * outW * 2 * core->bytes}));
+        } else {
+            outD = outputTensor->buffer().dim[2].extent;
+            outH = outputTensor->buffer().dim[3].extent;
+            outW = outputTensor->buffer().dim[4].extent;
+            mTempCordBuffer.reset(Tensor::createDevice<uint8_t>({1, outD * outH * outW * 3 * core->bytes}));
+        }
+        auto res = backend()->onAcquireBuffer(mTempCordBuffer.get(), Backend::DYNAMIC);
+        if (!res) {
+            return OUT_OF_MEMORY;
+        }
+        backend()->onReleaseBuffer(mTempCordBuffer.get(), Backend::DYNAMIC);
+        return NO_ERROR;
+    }
+    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override {
+        auto inputTensor = outputs[0];
+        ::memset(inputTensor->host<uint8_t>(), 0, static_cast<CPUBackend*>(backend())->getTensorSize(inputTensor, false) * static_cast<CPUBackend*>(backend())->functions()->bytes);
+        auto gridTensor = inputs[1];
+        auto outputTensor = inputs[0];
+        auto inputPtr = inputTensor->host<uint8_t>();
+        auto gridPtr = gridTensor->host<uint8_t>();
+        auto outputPtr = outputTensor->host<uint8_t>();
+        auto core = static_cast<CPUBackend*>(backend())->functions();
+        auto batches = inputTensor->buffer().dim[0].extent;
+        auto channels = inputTensor->buffer().dim[1].extent;
+        auto channelC4 = UP_DIV(channels, core->pack);
+        if (outputTensor->dimensions() != 4) {
+            return NOT_SUPPORT;
+        }
+        auto inH = inputTensor->buffer().dim[2].extent;
+        auto inW = inputTensor->buffer().dim[3].extent;
+        auto outH = outputTensor->buffer().dim[2].extent;
+        auto outW = outputTensor->buffer().dim[3].extent;
+        auto threadCount = static_cast<CPUBackend*>(backend())->threadNumber();
+        auto tileCount = outH;
+        auto inOffset  = batches * inH * inW * core->pack;
+        auto outOffset = batches * outH * outW * core->pack;
+        auto cordPtr = mTempCordBuffer->host<uint8_t>();
+        for (auto b = 0; b < batches; ++b) {
+            auto _inputPtr = inputPtr + b * inH * inW * core->pack * core->bytes;
+            auto _gridPtr = gridPtr + b * gridTensor->buffer().dim[0].stride * core->bytes;
+            auto _outputPtr = outputPtr + b * outH * outW * core->pack * core->bytes;
+            core->MNNGridSampleComputeCord((float *)cordPtr, (const float *)_gridPtr, inH, inW, outH, outW, gridTensor->buffer().dim[1].stride, mAlignCorners);
+            // Compute cord
+            for (int index=0; index < tileCount; index++) {
+                auto c = index / outH;
+                auto h = index % outH;
+                auto inputC = _inputPtr + c * inW * inH * batches * core->pack * core->bytes;
+                auto outputC = _outputPtr + c * outW * outH * batches * core->pack * core->bytes;
+                auto cordH = cordPtr + h * outW * 2 * core->bytes;
+                auto outputH = outputC + h * outW * core->pack * core->bytes;
+                core->MNNGridSampleInterpGrad((float *)outputH, (float *)inputC, (const float *)cordH, inH, inW, outW, channelC4, inOffset, outOffset, (mMode == SampleMode_NEAREST), (mPaddingMode == BorderMode_ZEROS));
+            }
+        }
+
+        return NO_ERROR;
+    }
+};
+
 class CPUGridSampleCreator : public CPUBackend::Creator {
 public:
     virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
@@ -141,6 +215,9 @@ public:
             MNN_ERROR("Don't has function for CPUGridSample\n");
             return nullptr;
         }
+        if (gridSampleParam->backward()) {
+            return new CPUGridSampleGrad(backend, mode, paddingMode, alignCorners);
+        }
         return new CPUGridSample(backend, mode, paddingMode, alignCorners);
     }
 };

+ 1 - 1
source/backend/cpu/CPUGridSample.hpp

@@ -20,7 +20,7 @@ public:
     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
     virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
 
-private:
+protected:
     SampleMode mMode;
     BorderMode mPaddingMode;
     bool mAlignCorners;

+ 0 - 10
source/backend/cpu/CPUImageProcess.cpp

@@ -11,16 +11,6 @@
 #include <string.h>
 #include <mutex>
 #include "core/Macro.h"
-#ifdef MNN_USE_NEON
-#include <arm_neon.h>
-#endif
-#ifdef MNN_USE_SSE
-#if defined(_MSC_VER)
-#include <intrin.h>
-#else
-#include <x86intrin.h>
-#endif
-#endif
 #include <map>
 #include <utility>
 

+ 4 - 12
source/backend/cpu/CPUMatMul.cpp

@@ -15,6 +15,7 @@
 #include "core/Concurrency.h"
 #include "core/BufferAllocator.hpp"
 #include "core/TensorUtils.hpp"
+#include "core/OpCommonUtils.hpp"
 #include "math/Vec.hpp"
 
 
@@ -63,21 +64,12 @@ ErrorCode CPUMatMul::onResize(const std::vector<Tensor*>& inputs, const std::vec
     const Tensor* A = inputs[0];
     const Tensor* B = inputs[1];
     Tensor* C       = outputs[0];
-    auto w0         = inputs[0]->length(1);
-    auto h0         = inputs[0]->length(0);
     auto core = static_cast<CPUBackend*>(backend())->functions();
     mPreFunctions.clear();
     mPostFunctions.clear();
-    auto e = A->length(0);
-    auto h = B->length(1);
-    auto l = A->length(1);
-    if (mTransposeA) {
-        l = A->length(0);
-        e = A->length(1);
-    }
-    if (mTransposeB) {
-        h = B->length(0);
-    }
+    int e, l, h;
+    OpCommonUtils::computeMatMulSize(mTransposeA, mTransposeB, A, B, e, l, h);
+
     // If encoded but resized as h=1/e=1, the computer should clear firstly
     mComputer->onReset();
     if (h == 1) {
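
The removed shape arithmetic is now centralized in OpCommonUtils::computeMatMulSize. Judging from the deleted lines, the convention is: e = rows of op(A), l = the shared reduction dimension, h = columns of op(B), with the transpose flags selecting which tensor length supplies each. A minimal sketch of that rule, written as a hypothetical standalone helper over plain ints (the real function takes Tensor pointers):

    // Shape rule recovered from the removed inline code: op(A) is e x l, op(B) is l x h.
    // Hypothetical standalone helper, not the OpCommonUtils signature.
    static void computeMatMulSizeSketch(bool transposeA, bool transposeB,
                                        int aRows, int aCols, int bRows, int bCols,
                                        int& e, int& l, int& h) {
        e = transposeA ? aCols : aRows; // output rows
        l = transposeA ? aRows : aCols; // shared (reduction) dimension
        h = transposeB ? bRows : bCols; // output columns
    }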

+ 15 - 0
source/backend/cpu/CPUOPRegister.cpp

@@ -66,7 +66,14 @@ extern void ___CPUSetDiff1DCreator__OpType_SetDiff1D__();
 extern void ___CPUEltwiseInt8Creator__OpType_EltwiseInt8__();
 extern void ___CPUSvdCreator__OpType_Svd__();
 extern void ___CPULayerNormCreator__OpType_LayerNorm__();
+extern void ___CPUQuantizeLinearCreator__OpType_QuantizeLinear__();
+extern void ___CPUDequantizeLinearCreator__OpType_DequantizeLinear__();
 
+#ifdef MNN_SUPPORT_RENDER
+extern void ___CPURasterAndInterpolateCreator__OpType_RasterAndInterpolate__();
+extern void ___CPURasterDiffCreator__OpType_RasterDiff__();
+extern void ___CPUTextureCreator__OpType_Texture__();
+#endif
 void registerCPUOps() {
 ___CPUCropAndResizeCreator__OpType_CropAndResize__();
 ___CPUArgMaxCreator__OpType_ArgMax__();
@@ -134,5 +141,13 @@ ___CPUSetDiff1DCreator__OpType_SetDiff1D__();
 ___CPUEltwiseInt8Creator__OpType_EltwiseInt8__();
 ___CPUSvdCreator__OpType_Svd__();
 ___CPULayerNormCreator__OpType_LayerNorm__();
+#ifdef MNN_SUPPORT_RENDER
+___CPURasterAndInterpolateCreator__OpType_RasterAndInterpolate__();
+___CPURasterDiffCreator__OpType_RasterDiff__();
+___CPUTextureCreator__OpType_Texture__();
+#endif
+___CPUQuantizeLinearCreator__OpType_QuantizeLinear__();
+___CPUDequantizeLinearCreator__OpType_DequantizeLinear__();
+//CPUQuantizeLinearCreator
 }
 }

+ 85 - 0
source/backend/cpu/CPUQuantizeLinear.cpp

@@ -0,0 +1,85 @@
+//
+//  CPUQuantizeLinear.cpp
+//  MNN
+//
+//  Created by MNN on 2018/07/15.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+#include "backend/cpu/CPUBackend.hpp"
+#include "core/Concurrency.h"
+#include "backend/cpu/CPUQuantizeLinear.hpp"
+#include "compute/CommonOptFunction.h"
+#include "core/TensorUtils.hpp"
+
+namespace MNN {
+
+CPUQuantizeLinear::CPUQuantizeLinear(Backend *b, int size, int axis) : MNN::Execution(b){
+    mSize = size;
+    mAxis = axis;
+}
+
+ErrorCode CPUQuantizeLinear::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
+    int size = mSize;
+    float* scale = inputs[1]->host<float>();
+    int8_t* zero = nullptr;
+    if (inputs.size() > 2) {
+        zero = inputs[2]->host<int8_t>();
+    }
+    if (mSize == 1) {
+        float s = scale[0] == 0 ? 0 : 1 / scale[0];
+        mQuantScales.resize(4, s);
+        if (nullptr != zero) {
+            int8_t z = *zero;
+            mQuantZeroPoints.resize(4, z);
+        } else {
+            mQuantZeroPoints.resize(4);
+        }
+    } else { // TODO scale: (1,D)
+        
+    }
+    return NO_ERROR;
+}
+ErrorCode CPUQuantizeLinear::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs)  {
+    auto input = inputs[0];
+    int N = input->length(0), C = input->length(1), H = input->length(2), W = input->length(3);
+    ssize_t size = N * C * H * W;
+    auto core = static_cast<CPUBackend*>(backend())->int8Functions();
+    int UNIT, SRC_UNIT, DST_XUNIT;
+    core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
+    int maxValue = 127;
+    int minValue = -128;
+#ifdef MNN_USE_SSE
+    auto dst = outputs[0]->host<uint8_t>();
+    int offset = 128;
+#else
+    auto dst = outputs[0]->host<int8_t>();
+    int offset = 0;
+#endif
+    if (mSize == 1) {
+        auto src = input->host<float>();
+        int sizeDiv = (int)size / UNIT;
+        core->MNNFloat2Int8(src, (int8_t*)dst, size / UNIT, mQuantScales.data(), -128, 127, mQuantZeroPoints[0]);
+        for (int i = sizeDiv * UNIT; i < size; ++i) {
+            int v = (int)roundf(src[i] * mQuantScales[0]) + mQuantZeroPoints[0] + offset;
+            v = std::max(minValue + offset, std::min(maxValue + offset, v));
+            dst[i] = v;
+        }
+    } else {
+        
+    }
+    return NO_ERROR;
+}
+
+class CPUQuantizeLinearCreator : public CPUBackend::Creator {
+public:
+    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
+                                const MNN::Op *op, Backend *backend) const override {
+        int size = op->main_as_QuantizeLinear()->scaleSize();
+        int axis = op->main_as_QuantizeLinear()->scaleAxis();
+        return new CPUQuantizeLinear(backend, size, axis);
+    }
+};
+
+REGISTER_CPU_OP_CREATOR(CPUQuantizeLinearCreator, OpType_QuantizeLinear);
+
+} // namespace MNN
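
In the single-scale path above, onResize caches 1/scale (guarding against a zero scale) together with the zero point, and onExecute quantizes each float as round(x / scale) + zeroPoint clamped to [-128, 127]; on SSE builds the result is stored as uint8 with a +128 offset. A scalar sketch of the per-element rule, assuming the same clamp range:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Scalar reference for the mSize == 1 path: q = clamp(round(x * (1/scale)) + zeroPoint).
    static int8_t quantizeLinearOne(float x, float scale, int8_t zeroPoint) {
        float inv = (scale == 0.f) ? 0.f : 1.f / scale;
        int v = static_cast<int>(std::round(x * inv)) + zeroPoint;
        return static_cast<int8_t>(std::max(-128, std::min(127, v)));
    }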

+ 31 - 0
source/backend/cpu/CPUQuantizeLinear.hpp

@@ -0,0 +1,31 @@
+//
+//  CPUQuantizeLinear.hpp
+//  MNN
+//
+//  Created by MNN on 2018/07/15.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifndef CPUQuantizeLinear_hpp
+#define CPUQuantizeLinear_hpp
+
+#include "core/AutoStorage.h"
+#include "core/Execution.hpp"
+
+namespace MNN {
+class CPUQuantizeLinear : public Execution {
+public:
+    CPUQuantizeLinear(Backend *b, int size = 1, int axis = 0);
+    virtual ~CPUQuantizeLinear() = default;
+    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+private:
+    std::vector<float> mQuantScales;
+    std::vector<int8_t> mQuantZeroPoints;
+    int mSize = 1;
+    int mAxis = 0;
+};
+
+} // namespace MNN
+
+#endif /* CPUQuantizeLinear_hpp */

+ 23 - 9
source/backend/cpu/CPURaster.cpp

@@ -235,7 +235,7 @@ ErrorCode CPURaster::onResize(const std::vector<Tensor *> &____inputs, const std
     }
     return NO_ERROR;
 }
-static void _transpose4Bit(int32_t* dstO, const int32_t* srcO, const Tensor::InsideDescribe::Region& region) {
+static void _transpose(int32_t* dstO, const int32_t* srcO, const Tensor::InsideDescribe::Region& region, int bytes) {
     int dims[4], keepDim = -1;
     for (int i = 0; i < 3; i++) {
         if (region.src.stride[i] == 1 && region.size[i] != 1) {
@@ -248,10 +248,23 @@ static void _transpose4Bit(int32_t* dstO, const int32_t* srcO, const Tensor::Ins
             keepDim = i;
         }
     }
-    for (int z=0; z<region.size[keepDim]; ++z) {
-        auto srcZ = srcO + region.src.stride[keepDim] * z;
-        auto dstZ = dstO + region.dst.stride[keepDim] * z;
-        MNNTranspose32Bit(dstZ, srcZ, dims);
+    if (bytes == 4) {
+        for (int z=0; z<region.size[keepDim]; ++z) {
+            auto srcZ = srcO + region.src.stride[keepDim] * z;
+            auto dstZ = dstO + region.dst.stride[keepDim] * z;
+            MNNTranspose32Bit(dstZ, srcZ, dims);
+        }
+        return;
+    }
+    if (bytes == 2) {
+        auto srcH = reinterpret_cast<const int16_t*>(srcO);
+        auto dstH = reinterpret_cast<int16_t*>(dstO);
+        for (int z = 0; z < region.size[keepDim]; ++z) {
+            auto srcZ = srcH + region.src.stride[keepDim] * z;
+            auto dstZ = dstH + region.dst.stride[keepDim] * z;
+            MNNTranspose16Bit(dstZ, srcZ, dims);
+        }
+        return;
     }
 }
 typedef void (*BlitProc)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds);
@@ -497,8 +510,9 @@ static void _blit(const Tensor::InsideDescribe::Region& slice, int bytes, const
         return;
     }
     int srcOne, dstOne;
-    if (OpCommonUtils::isTranspose(slice, srcOne, dstOne) && 4 == bytes) {
-        _transpose4Bit((int32_t*)dstPtr, (const int32_t*)srcPtr, slice);
+    if (OpCommonUtils::isTranspose(slice, srcOne, dstOne) && (4 == bytes || 2 == bytes)) {
+    // if (OpCommonUtils::isTranspose(slice, srcOne, dstOne) && 4 == bytes) {
+        _transpose((int32_t*)dstPtr, (const int32_t*)srcPtr, slice, bytes);
         return;
     }
     if (1 == slice.src.stride[2] && 1 == slice.dst.stride[2]) {
@@ -789,7 +803,7 @@ public:
                     auto inputSize = input->elementSize();
                     auto output = mStack[cmd->indexes()->data()[0]];
                     auto bytes = input->getType().bytes();
-                    if (halide_type_float == input->getType().code && bytes == 4) {
+                    if (halide_type_float == input->getType().code) {
                         bytes = cpubackend->functions()->bytes;
                     }
                     _blit(reg, bytes, input->host<uint8_t>(), output->host<uint8_t>());
@@ -827,7 +841,7 @@ public:
                 auto inputSize = input->elementSize();
                 auto output = mStack[cmd->indexes()->data()[0]];
                 auto bytes = input->getType().bytes();
-                if (halide_type_float == input->getType().code && bytes == 4) {
+                if (halide_type_float == input->getType().code) {
                     bytes = static_cast<CPUBackend*>(backend())->functions()->bytes;
                 }
                 auto step0 = cmd->steps()->data()[0];
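
The raster transpose fast path now also covers 2-byte elements, and float tensors always blit with the backend's float width (functions()->bytes), so FP16/BF16 backends reach the same optimized paths as FP32. A minimal sketch of the width rule and the widened dispatch, using hypothetical helper names:

    // Effective element width for raster blits after this change: float tensors follow the
    // backend's float width (e.g. 2 on an FP16/BF16 backend, 4 on FP32); other dtypes keep
    // their own size.
    static int effectiveBlitBytes(bool isFloatTensor, int typeBytes, int backendFloatBytes) {
        return isFloatTensor ? backendFloatBytes : typeBytes;
    }
    // The transpose fast path accepts both supported widths
    // (handled by MNNTranspose32Bit / MNNTranspose16Bit respectively).
    static bool transposeFastPathSupported(int bytes) {
        return bytes == 4 || bytes == 2;
    }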

+ 2 - 2
source/backend/cpu/CPUUnary.cpp

@@ -226,8 +226,8 @@ static void _SignInt8(void* out, const void* inp, int realSize, QuanPrePostParam
 #ifdef MNN_USE_NEON
     int8_t* outPtr = (int8_t*)out;
     int8_t* inPtr  = (int8_t*)inp;
-    int8x16_t one = vdupq_n_s8(1);
-    int8x16_t negone = vdupq_n_s8(-1);
+    int16x8_t one = vdupq_n_s8(1);
+    int16x8_t negone = vdupq_n_s8(-1);
     int16x8_t zero = vdupq_n_s16(0);
     int8x8_t inZeroPoint = vdup_n_s8(params->inputZeroPoint[0]);
     int8x8_t outZeroPoint = vdup_n_s8(params->outputZeroPoint[0]);

+ 124 - 0
source/backend/cpu/GridSampler.hpp

@@ -0,0 +1,124 @@
+static int MNNGridSampleComputeOffset(int h, int w, int height, int width, bool padMode) {
+    if (padMode == true) { //padMode == BorderMode_ZEROS
+        if (h < 0 || h >= height || w < 0 || w >= width) {
+            return -1;
+        }
+    } else {
+        // Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER
+        // For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1),
+        // the leftover reflections degrade to GridSamplePaddingMode_BORDER
+        h = h < 0 ? 0 : (h > (height - 1) ? (height - 1) : h);
+        w = w < 0 ? 0 : (w > (width - 1) ? (width - 1) : w);
+    }
+    return h * width * PACK + w * PACK;
+}
+
+static void MNNGridSampleInterp(FLOAT* outputPtr, const FLOAT* inputPtr, const FLOAT* cordPtr, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode) {
+    for (auto ow = 0; ow < outW; ++ow) {
+        auto w_ = cordPtr[2 * ow + 0];
+        auto h_ = cordPtr[2 * ow + 1];
+        float w = (float)(w_);
+        float h = (float)(h_);
+        Vec interp;
+
+        if (sampleMode == true) { //sampleMode == SampleMode_NEAREST
+            int nh = ::floor(h + 0.5f);
+            int nw = ::floor(w + 0.5f);
+            int ns = MNNGridSampleComputeOffset(nh, nw, inH, inW, padMode);
+            for (int k = 0; k < channelCUnit; ++k) {
+                interp = ns == -1 ? Vec(0.f) : Vec::load(inputPtr + k * inOffset + ns);
+                Vec::save(outputPtr + k * outOffset + PACK * ow, interp);
+            }
+        } else { //sampleMode == GridSampleMode_BILINEAR
+            int w0_h = ::floor(h);
+            int w0_w = ::floor(w);
+            int w1_h = ::ceil(h);
+            int w1_w = ::ceil(w);
+            auto oneV = Vec(1.0f);
+
+            auto f0 = Vec((FLOAT)w1_w - w_);
+            auto f1 = oneV - f0;
+            auto h0 = Vec((FLOAT)w1_h - h_);
+            auto h1 = oneV - h0;
+
+            int s00 = MNNGridSampleComputeOffset(w0_h, w0_w, inH, inW, padMode);
+            int s01 = MNNGridSampleComputeOffset(w0_h, w1_w, inH, inW, padMode);
+            int s10 = MNNGridSampleComputeOffset(w1_h, w0_w, inH, inW, padMode);
+            int s11 = MNNGridSampleComputeOffset(w1_h, w1_w, inH, inW, padMode);
+
+            for (int k = 0; k < channelCUnit; ++k) {
+                Vec i00 = s00 == -1 ? Vec(0.f) : Vec::load(inputPtr + k * inOffset + s00);
+                Vec i01 = s01 == -1 ? Vec(0.f) : Vec::load(inputPtr + k * inOffset + s01);
+                Vec i10 = s10 == -1 ? Vec(0.f) : Vec::load(inputPtr + k * inOffset + s10);
+                Vec i11 = s11 == -1 ? Vec(0.f) : Vec::load(inputPtr + k * inOffset + s11);
+
+                Vec i0 = i00 * f0 + i01 * f1;
+                Vec i1 = i10 * f0 + i11 * f1;
+
+                interp = i0 * h0 + i1 * h1;
+                Vec::save(outputPtr + k * outOffset + PACK * ow, interp);
+            }
+        }
+    }
+}
+static void MNNGridSampleInterpGrad(FLOAT* outputPtr, FLOAT* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode) {
+    const int pack = PACK;
+    for (auto ow = 0; ow < outW; ++ow) {
+        auto w = cordPtr[2 * ow + 0];
+        auto h = cordPtr[2 * ow + 1];
+        Vec interp;
+
+        if (sampleMode == true) { //sampleMode == SampleMode_NEAREST
+            int nh = ::floor(h + 0.5f);
+            int nw = ::floor(w + 0.5f);
+            int ns = MNNGridSampleComputeOffset(nh, nw, inH, inW, padMode);
+            if (ns != -1) {
+                for (int k = 0; k < channelCUnit; ++k) {
+                    auto o = Vec::load(outputPtr + k * outOffset + pack * ow);
+                    auto i = Vec::load(inputPtr + k * inOffset + ns);
+                    Vec::save(inputPtr + k * inOffset + ns, i + o);
+                }
+            }
+        } else { //sampleMode == GridSampleMode_BILINEAR
+            int w0_h = ::floor(h);
+            int w0_w = ::floor(w);
+            int w1_h = ::ceil(h);
+            int w1_w = ::ceil(w);
+            auto oneV = Vec(1.0f);
+
+            auto f0 = Vec((float)w1_w - w);
+            auto f1 = oneV - f0;
+            auto h0 = Vec((float)w1_h - h);
+            auto h1 = oneV - h0;
+
+            int s00 = MNNGridSampleComputeOffset(w0_h, w0_w, inH, inW, padMode);
+            int s01 = MNNGridSampleComputeOffset(w0_h, w1_w, inH, inW, padMode);
+            int s10 = MNNGridSampleComputeOffset(w1_h, w0_w, inH, inW, padMode);
+            int s11 = MNNGridSampleComputeOffset(w1_h, w1_w, inH, inW, padMode);
+
+            for (int k = 0; k < channelCUnit; ++k) {
+                auto o = Vec::load(outputPtr + k * outOffset + pack * ow);
+                if (s00 != -1) {
+                    auto i = Vec::load(inputPtr + k * inOffset + s00);
+                    auto diff = o * h0 * f0;
+                    Vec::save(inputPtr + k * inOffset + s00, diff + i);
+                }
+                if (s01 != -1) {
+                    auto i = Vec::load(inputPtr + k * inOffset + s01);
+                    auto diff = o * h0 * f1;
+                    Vec::save(inputPtr + k * inOffset + s01, diff + i);
+                }
+                if (s10 != -1) {
+                    auto i = Vec::load(inputPtr + k * inOffset + s10);
+                    auto diff = o * h1 * f0;
+                    Vec::save(inputPtr + k * inOffset + s10, diff + i);
+                }
+                if (s11 != -1) {
+                    auto i = Vec::load(inputPtr + k * inOffset + s11);
+                    auto diff = o * h1 * f1;
+                    Vec::save(inputPtr + k * inOffset + s11, diff + i);
+                }
+            }
+        }
+    }
+}
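
MNNGridSampleInterpGrad mirrors MNNGridSampleInterp: instead of gathering from the four bilinear neighbours it scatters the incoming output gradient onto them with the same weights, skipping taps that fall outside the input under ZEROS padding. A scalar sketch of that accumulation for a single channel, ignoring the PACK layout:

    #include <cmath>

    // Scalar sketch of the bilinear scatter in MNNGridSampleInterpGrad (one channel, no PACK).
    // grad is the incoming output gradient for the sample at (h, w); inputGrad is H x W.
    static void scatterBilinearGrad(float* inputGrad, int H, int W, float h, float w, float grad) {
        int h0 = (int)std::floor(h), w0 = (int)std::floor(w);
        int h1 = (int)std::ceil(h),  w1 = (int)std::ceil(w);
        float fw0 = (float)w1 - w, fw1 = 1.f - fw0; // width weights, as f0/f1 above
        float fh0 = (float)h1 - h, fh1 = 1.f - fh0; // height weights, as h0/h1 above
        auto add = [&](int y, int x, float v) {
            if (y >= 0 && y < H && x >= 0 && x < W) { // ZEROS padding drops out-of-range taps
                inputGrad[y * W + x] += v;
            }
        };
        add(h0, w0, grad * fh0 * fw0);
        add(h0, w1, grad * fh0 * fw1);
        add(h1, w0, grad * fh1 * fw0);
        add(h1, w1, grad * fh1 * fw1);
    }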

+ 28 - 28
source/backend/cpu/UnaryUtils.hpp

@@ -16,21 +16,21 @@ static void _unaryOp(void* outputPtr, const void* inputPtr, int elementSize) {
 }
 
 template <typename T>
-struct UnarySquare : std::unary_function<T, T> {
+struct UnarySquare {
     T operator()(const T &x) const {
         return x * x;
     }
 };
 
 template <typename T>
-struct UnaryRsqrt : std::unary_function<T, T> {
+struct UnaryRsqrt {
     T operator()(const T &x) const {
         return 1.f / sqrtf(x);
     }
 };
 
 template <typename T>
-struct UnarySqrt : std::unary_function<T, T> {
+struct UnarySqrt {
     T operator()(const T &x) const {
         return sqrtf(x);
     }
@@ -44,77 +44,77 @@ struct UnaryNeg {
 };
 
 template <typename T>
-struct UnaryExp : std::unary_function<T, T> {
+struct UnaryExp {
     T operator()(const T &x) const {
         return expf(x);
     }
 };
 
 template <typename T>
-struct UnaryAbs : std::unary_function<T, T> {
+struct UnaryAbs {
     T operator()(const T &x) const {
         return fabsf((float)x);
     }
 };
 
 template <typename T>
-struct UnaryCeil : std::unary_function<T, T> {
+struct UnaryCeil {
     T operator()(const T &x) const {
         return ceilf(x);
     }
 };
 template <typename T>
-struct UnaryRecipocal : std::unary_function<T, T> {
+struct UnaryRecipocal {
     T operator()(const T &x) const {
         return (T)1 / (x);
     }
 };
 template <typename T>
-struct UnaryLog1p : std::unary_function<T, T> {
+struct UnaryLog1p {
     T operator()(const T &x) const {
         return (T)logf((T)1 + (x));
     }
 };
 template <typename T>
-struct UnaryLog : std::unary_function<T, T> {
+struct UnaryLog {
     T operator()(const T &x) const {
         return (T)logf((T)(x));
     }
 };
 template <typename T>
-struct UnaryCos : std::unary_function<T, T> {
+struct UnaryCos {
     T operator()(const T &x) const {
         return (T)cosf((T)(x));
     }
 };
 template <typename T>
-struct UnarySin : std::unary_function<T, T> {
+struct UnarySin {
     T operator()(const T &x) const {
         return (T)sinf((T)(x));
     }
 };
 template <typename T>
-struct UnaryTan : std::unary_function<T, T> {
+struct UnaryTan {
     T operator()(const T &x) const {
         return (T)tanf((T)(x));
     }
 };
 template <typename T>
-struct UnaryATan : std::unary_function<T, T> {
+struct UnaryATan {
     T operator()(const T &x) const {
         return (T)atanf((T)(x));
     }
 };
 
 template <typename T>
-struct UnaryFloor : std::unary_function<T, T> {
+struct UnaryFloor {
     T operator()(const T &x) const {
         return (T)floor((T)(x));
     }
 };
 
 template <typename T>
-struct UnarySign : std::unary_function<T, T> {
+struct UnarySign {
     T operator()(const T &x) const {
         if (x > 0) {
             return 1;
@@ -127,7 +127,7 @@ struct UnarySign : std::unary_function<T, T> {
 };
 
 template <typename T>
-struct UnaryBNLL : std::unary_function<T, T> {
+struct UnaryBNLL {
     T operator()(const T &x) const {
         float r = x > 0 ? (x + log(1. + exp(-x))) : log(1. + exp(x));
         return (T)r;
@@ -135,41 +135,41 @@ struct UnaryBNLL : std::unary_function<T, T> {
 };
 
 template <typename T>
-struct UnaryAcosh : std::unary_function<T, T> {
+struct UnaryAcosh {
     T operator()(const T &x) const {
         return (T)acoshf((T)(x));
     }
 };
 
 template <typename T>
-struct UnarySinh : std::unary_function<T, T> {
+struct UnarySinh {
     T operator()(const T &x) const {
         return (T)sinhf((T)(x));
     }
 };
 
 template <typename T>
-struct UnaryAsinh : std::unary_function<T, T> {
+struct UnaryAsinh {
     T operator()(const T &x) const {
         return (T)asinhf((T)(x));
     }
 };
 
 template <typename T>
-struct UnaryAtanh : std::unary_function<T, T> {
+struct UnaryAtanh {
     T operator()(const T &x) const {
         return (T)atanhf((T)(x));
     }
 };
 template <typename T>
-struct UnaryRound : std::unary_function<T, T> {
+struct UnaryRound {
     T operator()(const T &x) const {
         return (T)roundf((T)(x));
     }
 };
 
 template <typename T>
-struct UnaryCosh : std::unary_function<T, T> {
+struct UnaryCosh {
     T operator()(const T &x) const {
         return (T)coshf((T)(x));
     }
@@ -177,21 +177,21 @@ struct UnaryCosh : std::unary_function<T, T> {
 
 
 template <typename T>
-struct UnaryErf : std::unary_function<T, T> {
+struct UnaryErf {
     T operator()(const T &x) const {
         return erff(x);
     }
 };
 
 template <typename T>
-struct UnaryErfc : std::unary_function<T, T> {
+struct UnaryErfc {
     T operator()(const T &x) const {
         return erfc(x);
     }
 };
 
 template <typename T>
-struct UnaryErfinv : std::unary_function<T, T> {
+struct UnaryErfinv {
     // referenced from tensorflow
     const int kDegree = 9;
     const std::vector<float> w_less_than_5_constants = {
@@ -235,21 +235,21 @@ struct UnaryErfinv : std::unary_function<T, T> {
 };
 
 template <typename T>
-struct UnaryExpm1 : std::unary_function<T, T> {
+struct UnaryExpm1 {
     T operator()(const T &x) const {
         return (T)expm1((T)(x));
     }
 };
 
 template <typename T>
-struct UnaryAsin : std::unary_function<T, T> {
+struct UnaryAsin {
     T operator()(const T &x) const {
         return (T)asin((T)(x));
     }
 };
 
 template <typename T>
-struct UnaryAcos : std::unary_function<T, T> {
+struct UnaryAcos {
     T operator()(const T &x) const {
         return (T)acos((T)(x));
     }
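
The functors above drop the std::unary_function<T, T> base. std::unary_function was deprecated in C++11 and removed in C++17, and nothing here relied on its typedefs; the element-wise driver only needs operator(). A minimal sketch of the pattern, with hypothetical names:

    // A plain functor with a const operator() is all the element-wise driver needs.
    template <typename T>
    struct UnaryExampleSquare {
        T operator()(const T& x) const {
            return x * x;
        }
    };

    // Usage sketch mirroring _unaryOp: apply the functor over a buffer.
    template <typename Func, typename T>
    void applyUnary(T* dst, const T* src, int n) {
        Func f;
        for (int i = 0; i < n; ++i) {
            dst[i] = f(src[i]);
        }
    }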

+ 34 - 0
source/backend/cpu/arm/CommonOptFunctionNeon.cpp

@@ -7,6 +7,7 @@
 
 extern "C" {
 void MNNTranspose32Bit4x4(int32_t* dstO, const int32_t* srcO, int32_t* dim);
+void MNNTranspose16Bit8x8(int16_t* dstO, const int16_t* srcO, int32_t* dim);
 }
 void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim) {
     int w = dim[0];
@@ -40,6 +41,39 @@ void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim) {
     }
 }
 
+void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int32_t* dim) {
+    int w = dim[0];
+    int h = dim[1];
+    auto wC8 = w / 8;
+    auto hC8 = h / 8;
+    int srcStride = dim[2];
+    int dstStride = dim[3];
+    if (wC8 > 0 && hC8 > 0) {
+        MNNTranspose16Bit8x8(dstO, srcO, dim);
+    }
+
+    // Down
+    for (int i = hC8 * 8; i < h; ++i) {
+        auto si = srcO + i;
+        auto di = dstO + i * dstStride;
+        for (int j = 0; j < w; ++j) {
+            auto sj = si + j * srcStride;
+            auto dj = di + j;
+            *dj = *sj;
+        }
+    }
+    // Right
+    for (int i = 0; i < hC8 * 8; ++i) {
+        auto si = srcO + i;
+        auto di = dstO + i * dstStride;
+        for (int j = wC8 * 8; j < w; ++j) {
+            auto sj = si + j * srcStride;
+            auto dj = di + j;
+            *dj = *sj;
+        }
+    }
+}
+
 #ifndef MNN_USE_NEON
 
 void MNNPackedSparseMatMulEpx1(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap) {
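
MNNTranspose16Bit follows the same dim convention as MNNTranspose32Bit: dim = { w, h, srcStride, dstStride } in elements. The NEON kernel covers the full 8x8 tiles, and the two tail loops copy the leftover rows ("Down") and columns ("Right") element by element. A scalar reference of the whole operation under that convention, as a sketch:

    #include <cstdint>

    // Scalar reference for the dim convention used by MNNTranspose16Bit:
    // dim = { w, h, srcStride, dstStride }, src indexed as src[x * srcStride + y],
    // dst as dst[y * dstStride + x] (matching the tail loops above).
    static void transpose16BitReference(int16_t* dst, const int16_t* src, const int32_t* dim) {
        int w = dim[0], h = dim[1];
        int srcStride = dim[2], dstStride = dim[3];
        for (int y = 0; y < h; ++y) {
            for (int x = 0; x < w; ++x) {
                dst[y * dstStride + x] = src[x * srcStride + y];
            }
        }
    }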

+ 56 - 31
source/backend/cpu/arm/arm32/MNNBinaryAddInt8.S

@@ -41,10 +41,12 @@ ldr r7, [sp, #44]
 
 vpush {q4-q7}
 
-ldr r12, [r3]
-vdup.s32 q13, r12     // scale
-ldr r4, [r3, #4]
-vdup.s32 q14, r4
+ldr r12, [r4]
+vdup.f32 q13, r12     // scale
+ldr lr, [r4, #4]
+vdup.f32 q14, lr
+ldr lr, [r4, #8]
+vdup.f32 q15, lr
 
 ldr lr, [r5, #8]
 ldr r8, [lr, #0]
@@ -104,31 +106,50 @@ L4Loop:
 
     vmovl.s16 q3, d12
     vmovl.s16 q12, d13
-    vmovl.s16 q15, d14
+    vmovl.s16 q1, d14
     vmovl.s16 q4, d15
 
-    vmulq.s32 q8, q8, q13
-    vmulq.s32 q9, q9, q13
-    vmulq.s32 q10, q10, q13
-    vmulq.s32 q11, q11, q13
-
-    vmulq.s32 q3, q3, q14
-    vmulq.s32 q12, q12, q14
-    vmulq.s32 q15, q15, q14
-    vmulq.s32 q4, q4, q14
-
-    vaddq.s32 q8, q8, q3
-    vaddq.s32 q9, q9, q12
-    vaddq.s32 q10, q10, q15
-    vaddq.s32 q11, q11, q4
-
+    vcvt.f32.s32 q8, q8
+    vcvt.f32.s32 q9, q9
+    vcvt.f32.s32 q10, q10
+    vcvt.f32.s32 q11, q11
+
+    vcvt.f32.s32 q3, q3
+    vcvt.f32.s32 q12, q12
+    vcvt.f32.s32 q1, q1
+    vcvt.f32.s32 q4, q4
+
+    vmul.f32 q8, q8, q13
+    vmul.f32 q9, q9, q13
+    vmul.f32 q10, q10, q13
+    vmul.f32 q11, q11, q13
+
+    vmul.f32 q3, q3, q14
+    vmul.f32 q12, q12, q14
+    vmul.f32 q1, q1, q14
+    vmul.f32 q4, q4, q14
+
+    vadd.f32 q8, q8, q3
+    vadd.f32 q9, q9, q12
+    vadd.f32 q10, q10, q1
+    vadd.f32 q11, q11, q4
+
+    vmul.f32 q8, q8, q15
+    vmul.f32 q9, q9, q15
+    vmul.f32 q10, q10, q15
+    vmul.f32 q11, q11, q15
+
+    vcvt.s32.f32 q8, q8
+    vcvt.s32.f32 q9, q9
+    vcvt.s32.f32 q10, q10
+    vcvt.s32.f32 q11, q11
+
+    vqmovn.s32 d6, q8
+    vqmovn.s32 d7, q9
+    vqmovn.s32 d8, q10
+    vqmovn.s32 d9, q11
     vdup.8 q12, r3
-    vdup.8 q15, r11
-
-    vqshrn.s32 d6, q8, #16
-    vqshrn.s32 d7, q9, #16
-    vqshrn.s32 d8, q10, #16
-    vqshrn.s32 d9, q11, #16
+    vdup.8 q1, r11
 
     vaddw.s8 q3, q3, d4
     vaddw.s8 q4, q4, d4
@@ -136,7 +157,7 @@ L4Loop:
     vqmovn.s16 d12, q3
     vqmovn.s16 d13, q4
     vmax.s8 q6, q6, q12
-    vmin.s8 q6, q6, q15
+    vmin.s8 q6, q6, q1
     cmp r6, #4
     vst1.32 {q6}, [r0]!
     bge L4Loop
@@ -174,15 +195,19 @@ L1Loop:
     vmovl.s8 q3, d6
     vsubw.s8 q3, q3, d0
     vmovl.s16 q3, d6
-    vmulq.s32 q3, q3, q13
+    vcvt.f32.s32 q3, q3
+    vmul.f32 q3, q3, q13
 
     vmovl.s8 q5, d8
     vsubw.s8 q5, q5, d1
     vmovl.s16 q6, d10
-    vmulq.s32 q6, q6, q14
+    vcvt.f32.s32 q6, q6
+    vmul.f32 q6, q6, q14
 
-    vaddq.s32 q3, q3, q6
-    vqshrn.s32 d6, q3, #16
+    vadd.f32 q3, q3, q6
+    vmul.f32 q3, q3, q15
+    vcvt.s32.f32 q3, q3
+    vqmovn.s32 d6, q3
     vaddw.s8 q3, q3, d4
     vqmovn.s16 d6, q3
     vmax.s8 d6, d6, d20
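
The int8 add/sub kernels switch from the fixed-point path (s32 multiply plus a shift-right-narrow by 16) to a float path: subtract the input zero points, convert to float, scale each operand, combine, apply the third value now loaded from the scale pointer (apparently the output-side scale), round, add the output zero point and clamp. A scalar sketch of the per-element computation, assuming the zero-point layout suggested by the comments in the arm64 version:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Scalar sketch of the reworked int8 add (sub is identical with a - instead of +).
    // scale[0]/scale[1] are the input scales, scale[2] the output-side scale, matching the
    // three floats the kernels now load from the scale pointer.
    static int8_t binaryAddInt8Sketch(int8_t x, int8_t y, const float scale[3],
                                      int8_t zeroIn0, int8_t zeroIn1, int8_t zeroOut,
                                      int8_t minValue, int8_t maxValue) {
        float a = (float)(x - zeroIn0) * scale[0];
        float b = (float)(y - zeroIn1) * scale[1];
        int v = (int)std::round((a + b) * scale[2]) + zeroOut;
        v = std::max((int)minValue, std::min((int)maxValue, v));
        return (int8_t)v;
    }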

+ 2 - 2
source/backend/cpu/arm/arm32/MNNBinaryMulInt8.S

@@ -149,7 +149,7 @@ L4Loop:
     vqmovn.s32 d8, q10
     vqmovn.s32 d9, q11
     vdup.8 q12, r3
-    vdup.8 q15, r11
+    vdup.8 q1, r11
 
     vaddw.s8 q3, q3, d4
     vaddw.s8 q4, q4, d4
@@ -157,7 +157,7 @@ L4Loop:
     vqmovn.s16 d12, q3
     vqmovn.s16 d13, q4
     vmax.s8 q6, q6, q12
-    vmin.s8 q6, q6, q15
+    vmin.s8 q6, q6, q1
     cmp r6, #4
     vst1.32 {q6}, [r0]!
     bge L4Loop

+ 56 - 30
source/backend/cpu/arm/arm32/MNNBinarySubInt8.S

@@ -41,10 +41,12 @@ ldr r7, [sp, #44]
 
 vpush {q4-q7}
 
-ldr r12, [r3]
-vdup.s32 q13, r12     // scale
-ldr r4, [r3, #4]
-vdup.s32 q14, r4
+ldr r12, [r4]
+vdup.f32 q13, r12     // scale
+ldr lr, [r4, #4]
+vdup.f32 q14, lr
+ldr lr, [r4, #8]
+vdup.f32 q15, lr
 
 ldr lr, [r5, #8]
 ldr r8, [lr, #0]
@@ -104,30 +106,50 @@ L4Loop:
 
     vmovl.s16 q3, d12
     vmovl.s16 q12, d13
-    vmovl.s16 q15, d14
+    vmovl.s16 q1, d14
     vmovl.s16 q4, d15
 
-    vmulq.s32 q8, q8, q13
-    vmulq.s32 q9, q9, q13
-    vmulq.s32 q10, q10, q13
-    vmulq.s32 q11, q11, q13
-
-    vmulq.s32 q3, q3, q14
-    vmulq.s32 q12, q12, q14
-    vmulq.s32 q15, q15, q14
-    vmulq.s32 q4, q4, q14
-
-    vsub.s32 q8, q8, q3
-    vsub.s32 q9, q9, q12
-    vsub.s32 q10, q10, q15
-    vsub.s32 q11, q11, q4
+    vcvt.f32.s32 q8, q8
+    vcvt.f32.s32 q9, q9
+    vcvt.f32.s32 q10, q10
+    vcvt.f32.s32 q11, q11
+
+    vcvt.f32.s32 q3, q3
+    vcvt.f32.s32 q12, q12
+    vcvt.f32.s32 q1, q1
+    vcvt.f32.s32 q4, q4
+
+    vmul.f32 q8, q8, q13
+    vmul.f32 q9, q9, q13
+    vmul.f32 q10, q10, q13
+    vmul.f32 q11, q11, q13
+
+    vmul.f32 q3, q3, q14
+    vmul.f32 q12, q12, q14
+    vmul.f32 q1, q1, q14
+    vmul.f32 q4, q4, q14
+
+    vsub.f32 q8, q8, q3
+    vsub.f32 q9, q9, q12
+    vsub.f32 q10, q10, q1
+    vsub.f32 q11, q11, q4
+
+    vmul.f32 q8, q8, q15
+    vmul.f32 q9, q9, q15
+    vmul.f32 q10, q10, q15
+    vmul.f32 q11, q11, q15
+
+    vcvt.s32.f32 q8, q8
+    vcvt.s32.f32 q9, q9
+    vcvt.s32.f32 q10, q10
+    vcvt.s32.f32 q11, q11
+
+    vqmovn.s32 d6, q8
+    vqmovn.s32 d7, q9
+    vqmovn.s32 d8, q10
+    vqmovn.s32 d9, q11
     vdup.8 q12, r3
-    vdup.8 q15, r11
-
-    vqshrn.s32 d6, q8, #16
-    vqshrn.s32 d7, q9, #16
-    vqshrn.s32 d8, q10, #16
-    vqshrn.s32 d9, q11, #16
+    vdup.8 q1, r11
 
     vaddw.s8 q3, q3, d4
     vaddw.s8 q4, q4, d4
@@ -135,7 +157,7 @@ L4Loop:
     vqmovn.s16 d12, q3
     vqmovn.s16 d13, q4
     vmax.s8 q6, q6, q12
-    vmin.s8 q6, q6, q15
+    vmin.s8 q6, q6, q1
     cmp r6, #4
     vst1.32 {q6}, [r0]!
     bge L4Loop
@@ -173,15 +195,19 @@ L1Loop:
     vmovl.s8 q3, d6
     vsubw.s8 q3, q3, d0
     vmovl.s16 q3, d6
-    vmulq.s32 q3, q3, q13
+    vcvt.f32.s32 q3, q3
+    vmul.f32 q3, q3, q13
 
     vmovl.s8 q5, d8
     vsubw.s8 q5, q5, d1
     vmovl.s16 q6, d10
-    vmulq.s32 q6, q6, q14
+    vcvt.f32.s32 q6, q6
+    vmul.f32 q6, q6, q14
 
-    vsub.s32 q3, q3, q6
-    vqshrn.s32 d6, q3, #16
+    vsub.f32 q3, q3, q6
+    vmul.f32 q3, q3, q15
+    vcvt.s32.f32 q3, q3
+    vqmovn.s32 d6, q3
     vaddw.s8 q3, q3, d4
     vqmovn.s16 d6, q3
     vmax.s8 d6, d6, d20

+ 5 - 2
source/backend/cpu/arm/arm32/MNNGelu.S

@@ -45,6 +45,9 @@ vdup.32 q10, r8        //q10: [28.f]x4
 vdup.32 q9, r10        //q9: [3150.f]x4
 vdup.32 q8, r11        //q8: [62370.f]x4
 
+mov r4, #5.0
+mov r5, #-5.0
+
 GeluZLoop:
 
 vld1.32 {q0, q1}, [r1]!   // q0, q1: fp32x4
@@ -63,8 +66,8 @@ vmul.f32 q2, q2, q14 // value
 vmul.f32 q3, q3, q14 // value
 
 // if value > 5, then value=5; if value<-5, then value=-5
-vmov.f32 q7, #5.0
-vmov.f32 q6, #-5.0
+vdup.32 q7, r4
+vdup.32 q6, r5
 vmax.f32 q2, q2, q6
 vmax.f32 q3, q3, q6
 vmin.f32 q2, q2, q7
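
The functional change in the GELU kernels is only how the ±5 clamp constants are materialized (general-purpose register plus vdup instead of a vmov.f32 immediate). The constants 28, 3150 and 62370 loaded above presumably belong to the rational polynomial used to approximate tanh, and the clamp keeps its argument inside [-5, 5]. A hedged scalar sketch of the tanh-form GELU these kernels appear to implement, using libm tanh in place of the polynomial and assuming the standard 0.044715 / sqrt(2/pi) constants (which are not visible in this diff):

    #include <algorithm>
    #include <cmath>

    // Scalar sketch of tanh-approximated GELU with the argument clamp seen in the kernel.
    // The NEON code evaluates tanh via a rational polynomial (28, 3150, 62370, ... above);
    // std::tanh is used here for clarity, and the leading constants are assumptions.
    static float geluSketch(float x) {
        float v = 0.7978845608f * (x + 0.044715f * x * x * x); // sqrt(2/pi) * (x + 0.044715 x^3)
        v = std::max(-5.0f, std::min(5.0f, v));                // clamp as in the assembly
        return 0.5f * x * (1.0f + std::tanh(v));
    }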

+ 135 - 0
source/backend/cpu/arm/arm32/MNNTranspose16Bit8x8.S

@@ -0,0 +1,135 @@
+//
+//  MNNTranspose16Bit8x8.S
+//  MNN
+//
+//  Created by MNN on 2023/11/09.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+#ifdef __arm__
+#ifndef __aarch64__
+
+#include "MNNAsmGlobal.h"
+
+.text
+.align 5
+asm_function MNNTranspose16Bit8x8
+//void MNNTranspose16Bit8x8(int16_t* dstO, const int16_t* srcO, int* dim)
+//Auto: r0: dstO, r1:srcO, r2: dim
+
+push {r4-r8, lr} // avoid to touch platform-register r-9
+ldr r4, [r2, #0]
+ldr r5, [r2, #4]
+ldr r6, [r2, #8]
+ldr r7, [r2, #12]
+
+// r4, r5 -> wC8, hC8
+lsr r4, r4, #3
+lsr r5, r5, #3
+
+// r6, r7 -> srcStride * sizeof(half), dstStride * sizeof(half)
+lsl r6, r6, #1
+lsl r7, r7, #1
+
+
+LoopY:
+    mov r2, r4
+    mov r8, r0
+    mov lr, r1
+    LoopX:
+        /*
+        after vld1.16
+        [ 0,  1,  2,  3,  4,  5,  6,  7]
+        [ 8,  9, 10, 11, 12, 13, 14, 15]
+        [16, 17, 18, 19, 20, 21, 22, 23]
+        [24, 25, 26, 27, 28, 29, 30, 31]
+        [32, 33, 34, 35, 36, 37, 38, 39]
+        [40, 41, 42, 43, 44, 45, 46, 47]
+        [48, 49, 50, 51, 52, 53, 54, 55]
+        [56, 57, 58, 59, 60, 61, 62, 63]
+        */
+        vld1.16 {q0}, [r1], r6
+        vld1.16 {q1}, [r1], r6
+        vld1.16 {q2}, [r1], r6
+        vld1.16 {q3}, [r1], r6
+        vld1.16 {q4}, [r1], r6
+        vld1.16 {q5}, [r1], r6
+        vld1.16 {q6}, [r1], r6
+        vld1.16 {q7}, [r1], r6
+
+        /*
+        after vtrn.16
+        [ 0,  8,  2, 10,  4, 12,  6, 14]
+        [ 1,  9,  3, 11,  5, 13,  7, 15]
+        [16, 24, 18, 26, 20, 28, 22, 30]
+        [17, 25, 19, 27, 21, 29, 23, 31]
+        [32, 40, 34, 42, 36, 44, 38, 46]
+        [33, 41, 35, 43, 37, 45, 39, 47]
+        [48, 56, 50, 58, 52, 60, 54, 62]
+        [49, 57, 51, 59, 53, 61, 55, 63]
+        */
+        vtrn.16 q0, q1
+        vtrn.16 q2, q3
+        vtrn.16 q4, q5
+        vtrn.16 q6, q7
+
+        /*
+        after vtrn.32
+        [ 0,  8, 16, 24,  4, 12,  20, 28]
+        [ 1,  9, 17, 25,  5, 13, 21, 29]
+        [ 2, 10, 18, 26,  6, 14, 22, 30]
+        [ 3, 11, 19, 27,  7, 15, 23, 31]
+        [32, 40, 48, 56, 36, 44, 52, 60]
+        [33, 41, 49, 57, 37, 45, 53, 61]
+        [34, 42, 50, 58, 38, 46, 54, 62]
+        [35, 43, 51, 59, 39, 47, 55, 63]
+        */
+        vtrn.32 q0, q2
+        vtrn.32 q1, q3
+        vtrn.32 q4, q6
+        vtrn.32 q5, q7
+
+        /*
+        after vswp
+        [ 0,  8, 16, 24, 32, 40, 48, 56]
+        [ 1,  9, 17, 25, 33, 41, 49, 57]
+        [ 2, 10, 18, 26, 34, 42, 50, 58]
+        [ 3, 11, 19, 27, 35, 43, 51, 59]
+        [ 4, 12, 20, 28, 36, 44, 52, 60]
+        [ 5, 13, 21, 29, 37, 45, 53, 61]
+        [ 6, 14, 22, 30, 38, 46, 54, 62]
+        [ 7, 15, 23, 31, 39, 47, 55, 63]
+        */
+        vswp d1, d8
+        vswp d3, d10
+        vswp d5, d12
+        vswp d7, d14
+
+        mov r12, r0
+
+        vst1.16 {q0}, [r12], r7
+        vst1.16 {q1}, [r12], r7
+        vst1.16 {q2}, [r12], r7
+        vst1.16 {q3}, [r12], r7
+        vst1.16 {q4}, [r12], r7
+        vst1.16 {q5}, [r12], r7
+        vst1.16 {q6}, [r12], r7
+        vst1.16 {q7}, [r12], r7
+
+        add r0, r0, #16 // 4 * sizeof(float)
+
+        subs r2, r2, #1
+        bne LoopX
+
+
+    lsl r12, r7, #3
+    subs r5, r5, #1
+    add r1, lr, #16 // 8 * sizeof(half)
+    add r0, r8, r12
+    bne LoopY
+
+End:
+
+pop {r4-r8, pc}
+
+#endif
+#endif

+ 5 - 2
source/backend/cpu/arm/arm32/bf16/MNNGelu_BF16.S

@@ -45,6 +45,9 @@ vdup.32 q10, r8        //q10: [28.f]x4
 vdup.32 q9, r10        //q9: [3150.f]x4
 vdup.32 q8, r11        //q8: [62370.f]x4
 
+mov r4, #5.0
+mov r5, #-5.0
+
 GeluZLoop:
 
 vld1.16 q0, [r1]!   // q0: 8* sizeof(int16_t)
@@ -65,8 +68,8 @@ vadd.f32 q3, q3, q1
 vmul.f32 q2, q2, q14
 vmul.f32 q3, q3, q14
 
-vmov.f32 q7, #5.0
-vmov.f32 q6, #-5.0
+vdup.32 q7, r4
+vdup.32 q6, r5
 vmax.f32 q2, q2, q6
 vmax.f32 q3, q3, q6
 vmin.f32 q2, q2, q7

+ 123 - 60
source/backend/cpu/arm/arm64/MNNBinaryAddInt8.S

@@ -37,10 +37,12 @@ stp d8,  d9,  [sp, #48]
 cmp x6, #0
 beq End
 
-ldr w4, [x3, #8]
-ldr w3, [x3]
+ldr w3, [x4]
+ldr w10, [x4, #8]
+ldr w4, [x4, #4]
 mov v0.s[0], w3
 mov v0.s[1], w4
+mov v0.s[2], w10
 
 ldr x8, [x5, #16]
 ldr x9, [x8, #8]   // input1 zeroPoint
@@ -128,43 +130,79 @@ L8Loop:
     sxtl  v29.4s, v14.4h
     sxtl2 v30.4s, v14.8h
 
-    mul v15.4s, v15.4s, v0.s[0]
-    mul v16.4s, v16.4s, v0.s[0]
-    mul v17.4s, v17.4s, v0.s[0]
-    mul v18.4s, v18.4s, v0.s[0]
-    mul v19.4s, v19.4s, v0.s[0]
-    mul v20.4s, v20.4s, v0.s[0]
-    mul v21.4s, v21.4s, v0.s[0]
-    mul v22.4s, v22.4s, v0.s[0]
-
-    mul v23.4s, v23.4s, v0.s[1]
-    mul v24.4s, v24.4s, v0.s[1]
-    mul v25.4s, v25.4s, v0.s[1]
-    mul v26.4s, v26.4s, v0.s[1]
-    mul v27.4s, v27.4s, v0.s[1]
-    mul v28.4s, v28.4s, v0.s[1]
-    mul v29.4s, v29.4s, v0.s[1]
-    mul v30.4s, v30.4s, v0.s[1]
+    scvtf v15.4s, v15.4s 
+    scvtf v16.4s, v16.4s 
+    scvtf v17.4s, v17.4s 
+    scvtf v18.4s, v18.4s 
+    scvtf v19.4s, v19.4s 
+    scvtf v20.4s, v20.4s 
+    scvtf v21.4s, v21.4s 
+    scvtf v22.4s, v22.4s 
+
+    scvtf v23.4s, v23.4s    
+    scvtf v24.4s, v24.4s
+    scvtf v25.4s, v25.4s
+    scvtf v26.4s, v26.4s
+    scvtf v27.4s, v27.4s
+    scvtf v28.4s, v28.4s
+    scvtf v29.4s, v29.4s
+    scvtf v30.4s, v30.4s
+
+    fmul v15.4s, v15.4s, v0.s[0]
+    fmul v16.4s, v16.4s, v0.s[0]
+    fmul v17.4s, v17.4s, v0.s[0]
+    fmul v18.4s, v18.4s, v0.s[0]
+    fmul v19.4s, v19.4s, v0.s[0]
+    fmul v20.4s, v20.4s, v0.s[0]
+    fmul v21.4s, v21.4s, v0.s[0]
+    fmul v22.4s, v22.4s, v0.s[0]
+
+    fmul v23.4s, v23.4s, v0.s[1]
+    fmul v24.4s, v24.4s, v0.s[1]
+    fmul v25.4s, v25.4s, v0.s[1]
+    fmul v26.4s, v26.4s, v0.s[1]
+    fmul v27.4s, v27.4s, v0.s[1]
+    fmul v28.4s, v28.4s, v0.s[1]
+    fmul v29.4s, v29.4s, v0.s[1]
+    fmul v30.4s, v30.4s, v0.s[1]
     dup v11.16b, w11
     dup v12.16b, w12
 
-    add v15.4s, v15.4s, v23.4s
-    add v16.4s, v16.4s, v24.4s 
-    add v17.4s, v17.4s, v25.4s 
-    add v18.4s, v18.4s, v26.4s 
-    add v19.4s, v19.4s, v27.4s 
-    add v20.4s, v20.4s, v28.4s 
-    add v21.4s, v21.4s, v29.4s 
-    add v22.4s, v22.4s, v30.4s 
-
-    sqrshrn  v1.4h, v15.4s, #16
-    sqrshrn2 v1.8h, v16.4s, #16
-    sqrshrn  v2.4h, v17.4s, #16
-    sqrshrn2 v2.8h, v18.4s, #16
-    sqrshrn  v3.4h, v19.4s, #16
-    sqrshrn2 v3.8h, v20.4s, #16
-    sqrshrn  v4.4h, v21.4s, #16
-    sqrshrn2 v4.8h, v22.4s, #16
+    fadd v15.4s, v15.4s, v23.4s
+    fadd v16.4s, v16.4s, v24.4s 
+    fadd v17.4s, v17.4s, v25.4s 
+    fadd v18.4s, v18.4s, v26.4s 
+    fadd v19.4s, v19.4s, v27.4s 
+    fadd v20.4s, v20.4s, v28.4s 
+    fadd v21.4s, v21.4s, v29.4s 
+    fadd v22.4s, v22.4s, v30.4s
+
+    fmul v15.4s, v15.4s, v0.s[2]
+    fmul v16.4s, v16.4s, v0.s[2]
+    fmul v17.4s, v17.4s, v0.s[2]
+    fmul v18.4s, v18.4s, v0.s[2]
+    fmul v19.4s, v19.4s, v0.s[2]
+    fmul v20.4s, v20.4s, v0.s[2]
+    fmul v21.4s, v21.4s, v0.s[2]
+    fmul v22.4s, v22.4s, v0.s[2]
+
+    fcvtas v15.4s, v15.4s
+    fcvtas v16.4s, v16.4s
+    fcvtas v17.4s, v17.4s
+    fcvtas v18.4s, v18.4s
+    fcvtas v19.4s, v19.4s
+    fcvtas v20.4s, v20.4s
+    fcvtas v21.4s, v21.4s
+    fcvtas v22.4s, v22.4s
+
+    sqxtn v1.4h, v15.4s
+    sqxtn2 v1.8h, v16.4s
+    sqxtn v2.4h, v17.4s
+    sqxtn2 v2.8h, v18.4s
+    sqxtn v3.4h, v19.4s
+    sqxtn2 v3.8h, v20.4s
+    sqxtn v4.4h, v21.4s
+    sqxtn2 v4.8h, v22.4s
 
     cmp w10, #0
     beq SQXTN_S8
@@ -248,25 +286,45 @@ L4Loop:
     sxtl  v25.4s, v12.4h
     sxtl2 v26.4s, v12.8h
 
-    mul v15.4s, v15.4s, v0.s[0]
-    mul v16.4s, v16.4s, v0.s[0]
-    mul v17.4s, v17.4s, v0.s[0]
-    mul v18.4s, v18.4s, v0.s[0]
-
-    mul v23.4s, v23.4s, v0.s[1]
-    mul v24.4s, v24.4s, v0.s[1]
-    mul v25.4s, v25.4s, v0.s[1]
-    mul v26.4s, v26.4s, v0.s[1]
-
-    add v15.4s, v15.4s, v23.4s
-    add v16.4s, v16.4s, v24.4s 
-    add v17.4s, v17.4s, v25.4s 
-    add v18.4s, v18.4s, v26.4s
-
-    sqrshrn  v1.4h, v15.4s, #16
-    sqrshrn2 v1.8h, v16.4s, #16
-    sqrshrn  v2.4h, v17.4s, #16
-    sqrshrn2 v2.8h, v18.4s, #16
+    scvtf v15.4s, v15.4s 
+    scvtf v16.4s, v16.4s 
+    scvtf v17.4s, v17.4s 
+    scvtf v18.4s, v18.4s
+
+    scvtf v23.4s, v23.4s    
+    scvtf v24.4s, v24.4s
+    scvtf v25.4s, v25.4s
+    scvtf v26.4s, v26.4s
+
+    fmul v15.4s, v15.4s, v0.s[0]
+    fmul v16.4s, v16.4s, v0.s[0]
+    fmul v17.4s, v17.4s, v0.s[0]
+    fmul v18.4s, v18.4s, v0.s[0]
+
+    fmul v23.4s, v23.4s, v0.s[1]
+    fmul v24.4s, v24.4s, v0.s[1]
+    fmul v25.4s, v25.4s, v0.s[1]
+    fmul v26.4s, v26.4s, v0.s[1]
+
+    fadd v15.4s, v15.4s, v23.4s
+    fadd v16.4s, v16.4s, v24.4s 
+    fadd v17.4s, v17.4s, v25.4s 
+    fadd v18.4s, v18.4s, v26.4s
+
+    fmul v15.4s, v15.4s, v0.s[2]
+    fmul v16.4s, v16.4s, v0.s[2]
+    fmul v17.4s, v17.4s, v0.s[2]
+    fmul v18.4s, v18.4s, v0.s[2]
+
+    fcvtas v15.4s, v15.4s
+    fcvtas v16.4s, v16.4s
+    fcvtas v17.4s, v17.4s
+    fcvtas v18.4s, v18.4s
+
+    sqxtn v1.4h, v15.4s
+    sqxtn2 v1.8h, v16.4s
+    sqxtn v2.4h, v17.4s
+    sqxtn2 v2.8h, v18.4s
 
     cmp w10, #0
     beq L4_SQXTN_S8
@@ -330,12 +388,17 @@ L1Loop:
     sxtl  v15.4s, v7.4h
     sxtl  v23.4s, v11.4h
 
-    mul v15.4s, v15.4s, v0.s[0]
-    mul v23.4s, v23.4s, v0.s[1]
+    scvtf v15.4s, v15.4s
+    scvtf v23.4s, v23.4s
 
-    add v15.4s, v15.4s, v23.4s
+    fmul v15.4s, v15.4s, v0.s[0]
+    fmul v23.4s, v23.4s, v0.s[1]
 
-    sqrshrn  v1.4h, v15.4s, #16
+    fadd v15.4s, v15.4s, v23.4s
+    fmul v15.4s, v15.4s, v0.s[2]
+    fcvtas v15.4s, v15.4s
+    sqxtn v1.4h, v15.4s
+    
 
     cmp w10, #0
     beq L1_SQXTN_S8
@@ -345,7 +408,7 @@ L1Loop:
     L1_SQXTN_S8:
     sqxtn v5.8b, v1.8h
     smax v5.8b, v5.8b, v30.8b
-    smin v6.8b, v6.8b, v31.8b
+    smin v5.8b, v5.8b, v31.8b
     st1 {v5.s}[0], [x0], #4
 
     subs x6, x6, #1

+ 1 - 1
source/backend/cpu/arm/arm64/MNNBinaryMaxInt8.S

@@ -345,7 +345,7 @@ L1Loop:
     L1_SQXTN_S8:
     sqxtn v5.8b, v1.8h
     smax v5.8b, v5.8b, v30.8b
-    smin v6.8b, v6.8b, v31.8b
+    smin v5.8b, v5.8b, v31.8b
     st1 {v5.s}[0], [x0], #4
 
     subs x6, x6, #1

+ 1 - 1
source/backend/cpu/arm/arm64/MNNBinaryMinInt8.S

@@ -344,7 +344,7 @@ L1Loop:
     L1_SQXTN_S8:
     sqxtn v5.8b, v1.8h
     smax v5.8b, v5.8b, v30.8b
-    smin v6.8b, v6.8b, v31.8b
+    smin v5.8b, v5.8b, v31.8b
     st1 {v5.s}[0], [x0], #4
 
     subs x6, x6, #1

+ 4 - 1
source/backend/cpu/arm/arm64/MNNBinaryMulInt8.S

@@ -194,6 +194,9 @@ L8Loop:
     fcvtas v21.4s, v21.4s
     fcvtas v22.4s, v22.4s
 
+    dup v11.16b, w11
+    dup v12.16b, w12
+
     sqxtn v1.4h, v15.4s
     sqxtn2 v1.8h, v16.4s
     sqxtn v2.4h, v17.4s
@@ -408,7 +411,7 @@ L1Loop:
     L1_SQXTN_S8:
     sqxtn v5.8b, v1.8h
     smax v5.8b, v5.8b, v30.8b
-    smin v6.8b, v6.8b, v31.8b
+    smin v5.8b, v5.8b, v31.8b
     st1 {v5.s}[0], [x0], #4
 
     subs x6, x6, #1

+ 1 - 1
source/backend/cpu/arm/arm64/MNNBinarySqdInt8.S

@@ -425,7 +425,7 @@ L1Loop:
     L1_SQXTN_S8:
     sqxtn v5.8b, v1.8h
     smax v5.8b, v5.8b, v30.8b
-    smin v6.8b, v6.8b, v31.8b
+    smin v5.8b, v5.8b, v31.8b
     st1 {v5.s}[0], [x0], #4
 
     subs x6, x6, #1

+ 125 - 60
source/backend/cpu/arm/arm64/MNNBinarySubInt8.S

@@ -36,10 +36,12 @@ stp d8,  d9,  [sp, #48]
 cmp x6, #0
 beq End
 
-ldr w4, [x3, #8]
-ldr w3, [x3]
+ldr w3, [x4]
+ldr w10, [x4, #8]
+ldr w4, [x4, #4]
 mov v0.s[0], w3
 mov v0.s[1], w4
+mov v0.s[2], w10
 
 ldr x8, [x5, #16]
 ldr x9, [x8, #8]   // input1 zeroPoint
@@ -127,43 +129,79 @@ L8Loop:
     sxtl  v29.4s, v14.4h
     sxtl2 v30.4s, v14.8h
 
-    mul v15.4s, v15.4s, v0.s[0]
-    mul v16.4s, v16.4s, v0.s[0]
-    mul v17.4s, v17.4s, v0.s[0]
-    mul v18.4s, v18.4s, v0.s[0]
-    mul v19.4s, v19.4s, v0.s[0]
-    mul v20.4s, v20.4s, v0.s[0]
-    mul v21.4s, v21.4s, v0.s[0]
-    mul v22.4s, v22.4s, v0.s[0]
-
-    mul v23.4s, v23.4s, v0.s[1]
-    mul v24.4s, v24.4s, v0.s[1]
-    mul v25.4s, v25.4s, v0.s[1]
-    mul v26.4s, v26.4s, v0.s[1]
-    mul v27.4s, v27.4s, v0.s[1]
-    mul v28.4s, v28.4s, v0.s[1]
-    mul v29.4s, v29.4s, v0.s[1]
-    mul v30.4s, v30.4s, v0.s[1]
+    scvtf v15.4s, v15.4s 
+    scvtf v16.4s, v16.4s 
+    scvtf v17.4s, v17.4s 
+    scvtf v18.4s, v18.4s 
+    scvtf v19.4s, v19.4s 
+    scvtf v20.4s, v20.4s 
+    scvtf v21.4s, v21.4s 
+    scvtf v22.4s, v22.4s 
+
+    scvtf v23.4s, v23.4s    
+    scvtf v24.4s, v24.4s
+    scvtf v25.4s, v25.4s
+    scvtf v26.4s, v26.4s
+    scvtf v27.4s, v27.4s
+    scvtf v28.4s, v28.4s
+    scvtf v29.4s, v29.4s
+    scvtf v30.4s, v30.4s
+
+    fmul v15.4s, v15.4s, v0.s[0]
+    fmul v16.4s, v16.4s, v0.s[0]
+    fmul v17.4s, v17.4s, v0.s[0]
+    fmul v18.4s, v18.4s, v0.s[0]
+    fmul v19.4s, v19.4s, v0.s[0]
+    fmul v20.4s, v20.4s, v0.s[0]
+    fmul v21.4s, v21.4s, v0.s[0]
+    fmul v22.4s, v22.4s, v0.s[0]
+
+    fmul v23.4s, v23.4s, v0.s[1]
+    fmul v24.4s, v24.4s, v0.s[1]
+    fmul v25.4s, v25.4s, v0.s[1]
+    fmul v26.4s, v26.4s, v0.s[1]
+    fmul v27.4s, v27.4s, v0.s[1]
+    fmul v28.4s, v28.4s, v0.s[1]
+    fmul v29.4s, v29.4s, v0.s[1]
+    fmul v30.4s, v30.4s, v0.s[1]
     dup v11.16b, w11
     dup v12.16b, w12
 
-    sqsub v15.4s, v15.4s, v23.4s
-    sqsub v16.4s, v16.4s, v24.4s 
-    sqsub v17.4s, v17.4s, v25.4s 
-    sqsub v18.4s, v18.4s, v26.4s 
-    sqsub v19.4s, v19.4s, v27.4s 
-    sqsub v20.4s, v20.4s, v28.4s 
-    sqsub v21.4s, v21.4s, v29.4s 
-    sqsub v22.4s, v22.4s, v30.4s 
-
-    sqrshrn  v1.4h, v15.4s, #16
-    sqrshrn2 v1.8h, v16.4s, #16
-    sqrshrn  v2.4h, v17.4s, #16
-    sqrshrn2 v2.8h, v18.4s, #16
-    sqrshrn  v3.4h, v19.4s, #16
-    sqrshrn2 v3.8h, v20.4s, #16
-    sqrshrn  v4.4h, v21.4s, #16
-    sqrshrn2 v4.8h, v22.4s, #16
+    fsub v15.4s, v15.4s, v23.4s
+    fsub v16.4s, v16.4s, v24.4s 
+    fsub v17.4s, v17.4s, v25.4s 
+    fsub v18.4s, v18.4s, v26.4s 
+    fsub v19.4s, v19.4s, v27.4s 
+    fsub v20.4s, v20.4s, v28.4s 
+    fsub v21.4s, v21.4s, v29.4s 
+    fsub v22.4s, v22.4s, v30.4s 
+
+    fmul v15.4s, v15.4s, v0.s[2]
+    fmul v16.4s, v16.4s, v0.s[2]
+    fmul v17.4s, v17.4s, v0.s[2]
+    fmul v18.4s, v18.4s, v0.s[2]
+    fmul v19.4s, v19.4s, v0.s[2]
+    fmul v20.4s, v20.4s, v0.s[2]
+    fmul v21.4s, v21.4s, v0.s[2]
+    fmul v22.4s, v22.4s, v0.s[2]
+
+    fcvtas v15.4s, v15.4s
+    fcvtas v16.4s, v16.4s
+    fcvtas v17.4s, v17.4s
+    fcvtas v18.4s, v18.4s
+    fcvtas v19.4s, v19.4s
+    fcvtas v20.4s, v20.4s
+    fcvtas v21.4s, v21.4s
+    fcvtas v22.4s, v22.4s
+
+    sqxtn v1.4h, v15.4s
+    sqxtn2 v1.8h, v16.4s
+    sqxtn v2.4h, v17.4s
+    sqxtn2 v2.8h, v18.4s
+    sqxtn v3.4h, v19.4s
+    sqxtn2 v3.8h, v20.4s
+    sqxtn v4.4h, v21.4s
+    sqxtn2 v4.8h, v22.4s
 
     cmp w10, #0
     beq SQXTN_S8
@@ -247,25 +285,45 @@ L4Loop:
     sxtl  v25.4s, v12.4h
     sxtl2 v26.4s, v12.8h
 
-    mul v15.4s, v15.4s, v0.s[0]
-    mul v16.4s, v16.4s, v0.s[0]
-    mul v17.4s, v17.4s, v0.s[0]
-    mul v18.4s, v18.4s, v0.s[0]
-
-    mul v23.4s, v23.4s, v0.s[1]
-    mul v24.4s, v24.4s, v0.s[1]
-    mul v25.4s, v25.4s, v0.s[1]
-    mul v26.4s, v26.4s, v0.s[1]
-
-    sqsub v15.4s, v15.4s, v23.4s
-    sqsub v16.4s, v16.4s, v24.4s 
-    sqsub v17.4s, v17.4s, v25.4s 
-    sqsub v18.4s, v18.4s, v26.4s
-
-    sqrshrn  v1.4h, v15.4s, #16
-    sqrshrn2 v1.8h, v16.4s, #16
-    sqrshrn  v2.4h, v17.4s, #16
-    sqrshrn2 v2.8h, v18.4s, #16
+    scvtf v15.4s, v15.4s 
+    scvtf v16.4s, v16.4s 
+    scvtf v17.4s, v17.4s 
+    scvtf v18.4s, v18.4s
+
+    scvtf v23.4s, v23.4s    
+    scvtf v24.4s, v24.4s
+    scvtf v25.4s, v25.4s
+    scvtf v26.4s, v26.4s
+
+    fmul v15.4s, v15.4s, v0.s[0]
+    fmul v16.4s, v16.4s, v0.s[0]
+    fmul v17.4s, v17.4s, v0.s[0]
+    fmul v18.4s, v18.4s, v0.s[0]
+
+    fmul v23.4s, v23.4s, v0.s[1]
+    fmul v24.4s, v24.4s, v0.s[1]
+    fmul v25.4s, v25.4s, v0.s[1]
+    fmul v26.4s, v26.4s, v0.s[1]
+
+    fsub v15.4s, v15.4s, v23.4s
+    fsub v16.4s, v16.4s, v24.4s 
+    fsub v17.4s, v17.4s, v25.4s 
+    fsub v18.4s, v18.4s, v26.4s
+
+    fmul v15.4s, v15.4s, v0.s[2]
+    fmul v16.4s, v16.4s, v0.s[2]
+    fmul v17.4s, v17.4s, v0.s[2]
+    fmul v18.4s, v18.4s, v0.s[2]
+
+    fcvtas v15.4s, v15.4s
+    fcvtas v16.4s, v16.4s
+    fcvtas v17.4s, v17.4s
+    fcvtas v18.4s, v18.4s
+
+    sqxtn v1.4h, v15.4s
+    sqxtn2 v1.8h, v16.4s
+    sqxtn v2.4h, v17.4s
+    sqxtn2 v2.8h, v18.4s
 
     cmp w10, #0
     beq L4_SQXTN_S8
@@ -329,12 +387,19 @@ L1Loop:
     sxtl  v15.4s, v7.4h
     sxtl  v23.4s, v11.4h
 
-    mul v15.4s, v15.4s, v0.s[0]
-    mul v23.4s, v23.4s, v0.s[1]
+    scvtf v15.4s, v15.4s
+    scvtf v23.4s, v23.4s
+
+    fmul v15.4s, v15.4s, v0.s[0]
+    fmul v23.4s, v23.4s, v0.s[1]
+
+    fsub v15.4s, v15.4s, v23.4s
+
+    fmul v15.4s, v15.4s, v0.s[2]
 
-    sqsub v15.4s, v15.4s, v23.4s
+    fcvtas v15.4s, v15.4s
 
-    sqrshrn  v1.4h, v15.4s, #16
+    sqxtn v1.4h, v15.4s
 
     cmp w10, #0
     beq L1_SQXTN_S8
@@ -344,7 +409,7 @@ L1Loop:
     L1_SQXTN_S8:
     sqxtn v5.8b, v1.8h
     smax v5.8b, v5.8b, v30.8b
-    smin v6.8b, v6.8b, v31.8b
+    smin v5.8b, v5.8b, v31.8b
     st1 {v5.s}[0], [x0], #4
 
     subs x6, x6, #1

+ 5 - 2
source/backend/cpu/arm/arm64/MNNGelu.S

@@ -45,8 +45,11 @@ dup v10.4s, w9        // v10: [28.f]x4
 dup v9.4s, w10        // v9: [3150.f]x4
 dup v8.4s, w11        // v8: [62370.f]x4
 
-fmov v30.4s, #5
-fmov v31.4s, #-5
+mov w4, #5.0
+mov w5, #-5.0
+
+dup v30.4s, w4
+dup v31.4s, w5
 
 GeluZLoop:
 

+ 1 - 1
source/backend/cpu/arm/arm64/MNNSoftmax.S

@@ -76,7 +76,7 @@ Loop_8:
     movk    w12, #15658, lsl #16
     movk    w10, #15914, lsl #16
     dup     v2.4s, v0.s[0]
-    movi    v1.2d, #0
+    movi    v1.16b, #0
     fmov    v3.4s, #1.0
     dup     v16.4s, w11
     dup     v17.4s, w12

+ 115 - 0
source/backend/cpu/arm/arm64/MNNTranspose16Bit8x8.S

@@ -0,0 +1,115 @@
+//
+//  MNNTranspose16Bit8x8.S
+//  MNN
+//
+//  Created by MNN on 2023/11/08.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+
+.text
+.align 5
+
+.macro TRANSPOSE_8x8 s0, s1, s2, s3, s4, s5, s6, s7, d0, d1, d2, d3, d4, d5, d6, d7, t0, t1, t2, t3, t4, t5, t6, t7
+    zip1 \t0\().8h, \s0\().8h, \s1\().8h
+    zip2 \t1\().8h, \s0\().8h, \s1\().8h
+    zip1 \t2\().8h, \s2\().8h, \s3\().8h
+    zip2 \t3\().8h, \s2\().8h, \s3\().8h
+    zip1 \t4\().8h, \s4\().8h, \s5\().8h
+    zip2 \t5\().8h, \s4\().8h, \s5\().8h
+    zip1 \t6\().8h, \s6\().8h, \s7\().8h
+    zip2 \t7\().8h, \s6\().8h, \s7\().8h
+    zip1 \s0\().4s, \t0\().4s, \t2\().4s
+    zip2 \s1\().4s, \t0\().4s, \t2\().4s
+    zip1 \s2\().4s, \t1\().4s, \t3\().4s
+    zip2 \s3\().4s, \t1\().4s, \t3\().4s
+    zip1 \s4\().4s, \t4\().4s, \t6\().4s
+    zip2 \s5\().4s, \t4\().4s, \t6\().4s
+    zip1 \s6\().4s, \t5\().4s, \t7\().4s
+    zip2 \s7\().4s, \t5\().4s, \t7\().4s
+    zip1 \d0\().2d, \s0\().2d, \s4\().2d
+    zip2 \d1\().2d, \s0\().2d, \s4\().2d
+    zip1 \d2\().2d, \s1\().2d, \s5\().2d
+    zip2 \d3\().2d, \s1\().2d, \s5\().2d
+    zip1 \d4\().2d, \s2\().2d, \s6\().2d
+    zip2 \d5\().2d, \s2\().2d, \s6\().2d
+    zip1 \d6\().2d, \s3\().2d, \s7\().2d
+    zip2 \d7\().2d, \s3\().2d, \s7\().2d
+.endm
+
+asm_function MNNTranspose16Bit8x8
+//void MNNTranspose16Bit8x8(int16_t* dstO, const int16_t* srcO, int* dim)
+//Auto: x0: dstO, x1:srcO, x2: dim
+stp d14, d15, [sp, #(-16 * 4)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+
+mov x4, #0
+mov x5, #0
+mov x6, #0
+mov x7, #0
+ldr w4, [x2, #0]
+ldr w5, [x2, #4]
+ldr w6, [x2, #8]
+ldr w7, [x2, #12]
+
+// x4, x5 -> wC8, hC8
+lsr x4, x4, #3
+lsr x5, x5, #3
+
+// x6, x7 -> srcStride * sizeof(half), dstStride * sizeof(half)
+lsl x6, x6, #1
+lsl x7, x7, #1
+
+LoopY:
+    mov x2, x4
+    mov x8, x0
+    mov x9, x1
+    LoopX:
+        ld1 {v0.8h}, [x1], x6
+        ld1 {v1.8h}, [x1], x6
+        ld1 {v2.8h}, [x1], x6
+        ld1 {v3.8h}, [x1], x6
+        ld1 {v4.8h}, [x1], x6
+        ld1 {v5.8h}, [x1], x6
+        ld1 {v6.8h}, [x1], x6
+        ld1 {v7.8h}, [x1], x6
+
+        TRANSPOSE_8x8  v0,  v1,  v2,  v3,  v4,  v5,  v6,  v7, \
+                       v8,  v9, v10, v11, v12, v13, v14, v15, \
+                      v16, v17, v18, v19, v20, v21, v22, v23
+
+        mov x12, x0
+
+        st1 {v8.8h}, [x12], x7
+        st1 {v9.8h}, [x12], x7
+        st1 {v10.8h}, [x12], x7
+        st1 {v11.8h}, [x12], x7
+        st1 {v12.8h}, [x12], x7
+        st1 {v13.8h}, [x12], x7
+        st1 {v14.8h}, [x12], x7
+        st1 {v15.8h}, [x12], x7
+
+        add x0, x0, #16 // 8 * sizeof(half)
+
+        subs x2, x2, #1
+        bne LoopX
+
+
+    lsl x12, x7, #3 // 8 * dstStride
+    subs x5, x5, #1
+    add x1, x9, #16 // 8 * sizeof(half)
+    add x0, x8, x12
+    bne LoopY
+
+End:
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 4)
+ret
+
+#endif

+ 5 - 2
source/backend/cpu/arm/arm64/bf16/MNNGelu_BF16.S

@@ -45,8 +45,11 @@ dup v10.4s, w9        // v10: [28.f]x4
 dup v9.4s, w10        // v9: [3150.f]x4
 dup v8.4s, w11        // v8: [62370.f]x4
 
-fmov v30.4s, #5
-fmov v31.4s, #-5
+mov w4, #5.0
+mov w5, #-5.0
+
+dup v30.4s, w4
+dup v31.4s, w5
 
 GeluZLoop:
 

+ 173 - 0
source/backend/cpu/arm/arm64/low_memory/MNNAbsMaxFP32.S

@@ -0,0 +1,173 @@
+//
+//  MNNAbsMaxFP32.S
+//
+//  Created by MNN on 2023/10/31.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+.text
+.align 5
+
+.macro Add d0, d1, d2, d3, z0, z1, z2, z3
+    fadd \d0\().4s, \d0\().4s, \z0\().4s
+    fadd \d1\().4s, \d1\().4s, \z1\().4s
+    fadd \d2\().4s, \d2\().4s, \z2\().4s
+    fadd \d3\().4s, \d3\().4s, \z3\().4s
+.endm
+
+.macro Abs z0, z1, z2, z3
+    fabs \z0\().4s, \z0\().4s
+    fabs \z1\().4s, \z1\().4s
+    fabs \z2\().4s, \z2\().4s
+    fabs \z3\().4s, \z3\().4s
+.endm
+
+.macro Max d0, d1, d2, d3, z0, z1, z2, z3
+    fmax \d0\().4s, \d0\().4s, \z0\().4s
+    fmax \d1\().4s, \d1\().4s, \z1\().4s
+    fmax \d2\().4s, \d2\().4s, \z2\().4s
+    fmax \d3\().4s, \d3\().4s, \z3\().4s
+.endm
+
+.macro TRANSPOSE_8x8 s0, s1, s2, s3, s4, s5, s6, s7, d0, d1, d2, d3, d4, d5, d6, d7, t0, t1, t2, t3, t4, t5, t6, t7
+    trn1 \t0\().4s, \s0\().4s, \s1\().4s
+    trn2 \t1\().4s, \s0\().4s, \s1\().4s
+    trn1 \t2\().4s, \s2\().4s, \s3\().4s
+    trn2 \t3\().4s, \s2\().4s, \s3\().4s
+    trn1 \t4\().4s, \s4\().4s, \s5\().4s
+    trn2 \t5\().4s, \s4\().4s, \s5\().4s
+    trn1 \t6\().4s, \s6\().4s, \s7\().4s
+    trn2 \t7\().4s, \s6\().4s, \s7\().4s
+
+    trn1 \d0\().2d, \t0\().2d, \t2\().2d
+    trn2 \d2\().2d, \t0\().2d, \t2\().2d
+    trn1 \d1\().2d, \t1\().2d, \t3\().2d
+    trn2 \d3\().2d, \t1\().2d, \t3\().2d
+    trn1 \d4\().2d, \t4\().2d, \t6\().2d
+    trn2 \d6\().2d, \t4\().2d, \t6\().2d
+    trn1 \d5\().2d, \t5\().2d, \t7\().2d
+    trn2 \d7\().2d, \t5\().2d, \t7\().2d
+.endm
+
+.macro ReduceSum d0, d1, z0, z1, z2, z3, z4, z5, z6, z7
+    fadd \d0\().4s, \z0\().4s, \z1\().4s
+    fadd \d0\().4s, \d0\().4s, \z2\().4s
+    fadd \d0\().4s, \d0\().4s, \z3\().4s
+    fadd \d1\().4s, \z4\().4s, \z5\().4s
+    fadd \d1\().4s, \d1\().4s, \z6\().4s
+    fadd \d1\().4s, \d1\().4s, \z7\().4s
+.endm
+
+.macro ReduceMax d0, d1, z0, z1, z2, z3, z4, z5, z6, z7
+    fmax \d0\().4s, \z0\().4s, \z1\().4s
+    fmax \d0\().4s, \d0\().4s, \z2\().4s
+    fmax \d0\().4s, \d0\().4s, \z3\().4s
+    fmax \d1\().4s, \z7\().4s, \z4\().4s
+    fmax \d1\().4s, \d1\().4s, \z5\().4s
+    fmax \d1\().4s, \d1\().4s, \z6\().4s
+.endm
+//void MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack)
+asm_function MNNAbsMaxFP32
+
+// x0: source, x1:absmax, x2:src_depth_quad, x3:realSize
+stp d14, d15, [sp, #(-16 * 4)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+
+Start:
+lsl x6, x3, #4 // src_step = batch * 4 * sizeof(float32_t) = batch << 4
+
+TILE_8:
+cmp x3, #8
+blt TILE_1
+mov x5, x2  // src_depth_quad
+mov x7, x0  // src
+sub x8, x6, #64 // src_step
+
+//    sum: v0-7
+// absmax: v8-15
+ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x7], #64
+ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x7], x8
+fabs v8.4s, v0.4s
+fabs v9.4s, v1.4s
+fabs v10.4s, v2.4s
+fabs v11.4s, v3.4s
+fabs v12.4s, v4.4s
+fabs v13.4s, v5.4s
+fabs v14.4s, v6.4s
+fabs v15.4s, v7.4s
+subs x5, x5, #1
+beq Tile8End
+
+LoopSz_8:
+ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x7], #64
+ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x7], x8
+
+// absmax = fmax(absmax, abs(x))
+Abs v16, v17, v18, v19
+Abs v20, v21, v22, v23
+Max v8, v9, v10, v11, v16, v17, v18, v19
+Max v12, v13, v14, v15, v20, v21, v22, v23
+
+subs x5, x5, #1
+bne LoopSz_8
+
+Tile8End:
+
+// [v0 - v7] --transpose--> [v16, v23], tmp:[v24-31]
+TRANSPOSE_8x8 v8,   v9, v10, v11, v12, v13, v14, v15, \
+              v16, v17, v18, v19, v20, v21, v22, v23, \
+              v24, v25, v26, v27, v28, v29, v30, v31
+ReduceMax v2, v3, v16, v17, v18, v19, v20, v21, v22, v23
+st1 {v2.4s, v3.4s}, [x1], #32
+sub x3, x3, #8
+add x0, x0, #128 // src += 8 * 4 * 4
+b TILE_8
+
+
+TILE_1:
+cmp x3, #1
+blt End
+mov x5, x2  // src_depth_quad
+mov x7, x0  // src
+
+//    sum: v0
+// absmax: v8
+ld1 {v0.4s}, [x7], x6
+fabs v8.4s, v0.4s
+subs x5, x5, #1
+beq Tile1End
+
+LoopSz_1:
+ld1 {v16.4s}, [x7], x6
+
+// absmax = fmax(absmax, abs(x))
+fabs v16.4s, v16.4s
+fmax v8.4s, v8.4s, v16.4s
+
+subs x5, x5, #1
+bne LoopSz_1
+
+Tile1End:
+// reduce max
+mov v1.d[0], v8.d[1]
+fmax v8.4s, v8.4s, v1.4s
+mov v5.s[0], v8.s[1]
+fmax v8.4s, v5.4s, v8.4s
+st1 {v8.s}[0], [x1], #4
+subs x3, x3, #1
+add x0, x0, #16 // src += 1 * 4(pack) * 4(sizeof(float32_t))
+bne TILE_1
+
+End:
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 4)
+ret
+
+#endif
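
MNNAbsMaxFP32 produces, for each of the realSize columns of a [src_depth_quad, realSize, pack] FP32 block, the maximum absolute value over the depth and pack dimensions; the TILE_8 path transposes eight columns at a time so the final reduction becomes a vector max. A scalar reference of the reduction, assuming the layout implied by the strides in the assembly:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>

    // Scalar reference for MNNAbsMaxFP32: absMax[i] = max over all depth steps and the
    // pack lanes of |source|, for each of the realSize columns.
    // Layout assumed from the assembly: source[(z * realSize + i) * pack + k].
    static void absMaxReference(const float* source, float* absMax,
                                size_t srcDepthQuad, size_t realSize, int pack /* = 4 */) {
        for (size_t i = 0; i < realSize; ++i) {
            float m = 0.f;
            for (size_t z = 0; z < srcDepthQuad; ++z) {
                const float* p = source + (z * realSize + i) * pack;
                for (int k = 0; k < pack; ++k) {
                    m = std::max(m, std::fabs(p[k]));
                }
            }
            absMax[i] = m;
        }
    }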

+ 155 - 0
source/backend/cpu/arm/arm64/low_memory/MNNDynamicQuantFP32.S

@@ -0,0 +1,155 @@
+//
+//  MNNDynamicQuantFP32.S
+//  MNN
+//
+//  Created by MNN on 2023/10/31.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+.text
+.align 5
+
+.macro Round z0, z1, z2, z3
+    fcvtas \z0\().4s, \z0\().4s
+    fcvtas \z1\().4s, \z1\().4s
+    fcvtas \z2\().4s, \z2\().4s
+    fcvtas \z3\().4s, \z3\().4s
+.endm
+
+.macro Transpose z0, z1, z2, z3, t0, t1, t2, t3
+    trn1 \t0\().4s, \z0\().4s, \z1\().4s
+    trn1 \t1\().4s, \z2\().4s, \z3\().4s
+    trn2 \t2\().4s, \z0\().4s, \z1\().4s
+    trn2 \t3\().4s, \z2\().4s, \z3\().4s
+
+    trn1 \z0\().2d, \t0\().2d, \t1\().2d
+    trn1 \z1\().2d, \t2\().2d, \t3\().2d
+    trn2 \z2\().2d, \t0\().2d, \t1\().2d
+    trn2 \z3\().2d, \t2\().2d, \t3\().2d
+.endm
+
+.macro Add d0, d1, d2, d3
+    add \d0\().4s, \d1\().4s, \d0\().4s
+    add \d2\().4s, \d3\().4s, \d2\().4s
+    add \d0\().4s, \d0\().4s, \d2\().4s
+.endm
+
+//void MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize)
+asm_function MNNDynamicQuantFP32
+
+// x0: src, x1:dst, x2:scale, x3: sum, x4:src_depth_quad, x5:realSize
+stp d14, d15, [sp, #(-16 * 4)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+
+Start:
+lsl x6, x5, #2  // dst_step = batch * unit * sizeof(int8_t) = batch * 4 = batch << 2
+lsl x7, x6, #2  // src_step = dst_step * 4 (sizeof(float32_t)) = dst_step << 2
+
+TILE_4:
+cmp x5, #4
+blt TILE_1
+mov x9, x0   // src
+mov x10, x1  // dst
+//mov x11, x2  // scale
+mov x12, x4  // src_depth_quad
+
+// quant_scale: v8, 4(batch)*sizeof(float32_t)
+ld1 {v8.4s}, [x2], #16
+
+// int8 sum
+movi v10.4s, #0
+
+LoopSz_4:
+ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x9], x7
+
+// float32_t x = x * quant_scale
+fmul v0.4s, v0.4s, v8.s[0]
+fmul v1.4s, v1.4s, v8.s[1]
+fmul v2.4s, v2.4s, v8.s[2]
+fmul v3.4s, v3.4s, v8.s[3]
+
+// int32_t x = round(x)
+Round v0, v1, v2, v3
+
+// y = (int8_t)x
+sqxtn v4.4h, v0.4s
+sqxtn2 v4.8h, v1.4s
+sqxtn v5.4h, v2.4s
+sqxtn2 v5.8h, v3.4s
+
+sqxtn v6.8b, v4.8h
+sqxtn2 v6.16b, v5.8h
+
+st1 {v6.16b}, [x10], x6
+// sum
+Transpose v0, v1, v2, v3, v14, v15, v16, v17
+Add v0, v1, v2, v3
+add v10.4s, v0.4s, v10.4s
+
+subs x12, x12, #1
+bne LoopSz_4
+
+Tile4End:
+sub x5, x5, #4    // batch -= 4
+add x0, x0, #64  // src += 4 * 4 * sizeof(float32_t)
+add x1, x1, #16   // dst += 4 * 4 * sizeof(int8_t)
+//add x2, x2, #16   // scale += 4 * sizeof(float32_t)
+st1 {v10.4s}, [x3], #16
+b TILE_4
+
+TILE_1:
+cmp x5, #1
+blt End
+mov x9, x0   // src
+mov x10, x1  // dst
+mov x12, x4  // src_depth_quad
+
+// quant_scale: v8
+ld1 {v8.s}[0], [x2], #4
+movi v4.4s, #0
+LoopSz_1:
+ld1 {v0.4s}, [x9], x7
+
+// float32_t x = x * quant_scale
+fmul v0.4s, v0.4s, v8.s[0]
+// int32_t x = round(x)
+fcvtas v0.4s, v0.4s
+
+dup v1.4s, v0.s[1]
+dup v2.4s, v0.s[2]
+dup v3.4s, v0.s[3]
+
+// y = (int8_t)x
+sqxtn v7.4h, v0.4s
+sqxtn v7.8b, v7.8h
+// sum
+
+Add v0, v1, v2, v3
+add v4.4s, v0.4s, v4.4s
+
+st1 {v7.s}[0], [x10], x6
+
+subs x12, x12, #1
+bne LoopSz_1
+
+st1 {v4.s}[0], [x3], #4
+Tile1End:
+subs x5, x5, #1    // batch -= 1
+add x0, x0, #16    // src += 1 * 4 * sizeof(float32_t)
+add x1, x1, #4    // dst += 1 * 4 * sizeof(int8_t)
+//add x2, x2, #4    // scale += 1 * sizeof(float32_t)
+bne TILE_1
+
+End:
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 4)
+ret
+
+#endif
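
The per-column sums this kernel writes through x3 are raw int32 accumulators stored into a float buffer (the portable fallback later in this patch does the same via ((int32_t*)sum)[i] = acc); they only become real floats once MNNQuantSumFP32 reinterprets the bits and multiplies by the dequant scale. A small sketch of that hand-off, with illustrative helper names:

#include <cstdint>
#include <cstring>

// Sketch (not MNN API): store an int32 accumulator bit-for-bit into a float
// slot, then later convert it into a float sum the way MNNQuantSumFP32 does:
// (float)acc * dequant_scale.
static void StoreRawIntSum(float* sumSlot, int32_t acc) {
    std::memcpy(sumSlot, &acc, sizeof(acc));
}
static float FinalizeSum(const float* sumSlot, float dequantScale) {
    int32_t acc;
    std::memcpy(&acc, sumSlot, sizeof(acc));
    return (float)acc * dequantScale;
}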

+ 232 - 0
source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32_sdot.S

@@ -0,0 +1,232 @@
+//
+//  MNNGemmHybridInt4FP32_sdot.S
+//  MNN
+//
+//  Created by MNN on 2023/11/09.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+
+.text
+.align 5
+
+.macro Int32ToFloat z0, z1, z2, z3
+    scvtf \z0\().4s, \z0\().4s
+    scvtf \z1\().4s, \z1\().4s
+    scvtf \z2\().4s, \z2\().4s
+    scvtf \z3\().4s, \z3\().4s
+.endm
+
+.macro MulScale d0, d1, d2, d3, s
+    fmul \d0\().4s, \d0\().4s, \s\().s[0]
+    fmul \d1\().4s, \d1\().4s, \s\().s[1]
+    fmul \d2\().4s, \d2\().4s, \s\().s[2]
+    fmul \d3\().4s, \d3\().4s, \s\().s[3]
+.endm
+
+.macro Dequant c0, a0, z0, b0, s0, idx
+    fmul \c0\().4s, \c0\().4s, \a0\().4s
+    fmla \c0\().4s, \z0\().4s, \s0\().s[\idx]
+    fadd \c0\().4s, \c0\().4s, \b0\().4s
+.endm
+
+asm_function MNNGemmHybridInt4FP32_sdot
+
+//struct QuanPostTreatParameters {
+//    const float* scale;
+//    const int32_t* bias;
+//    int32_t maxValue;
+//    int32_t minValue;
+//    int32_t useInt8;
+//};
+
+//void MNNGemmHybridInt4FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); 
+
+
+// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
+// load from param: x8: alpha*, x9: zero*, x10: bias*, x11: sums*, x12: scales*
+stp d14, d15, [sp, #(-16 * 9)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+stp x21, x22, [sp, #(16 * 4)]
+stp x19, x20, [sp, #(16 * 5)]
+stp x23, x24, [sp, #(16 * 6)]
+stp x25, x26, [sp, #(16 * 7)]
+stp x27, x28, [sp, #(16 * 8)]
+
+ldr x8, [x7, #0]
+ldr x9, [x7, #8]
+ldr x10, [x7, #16]
+ldr x11, [x7, #24]
+ldr x12, [x7, #32]
+
+Start:
+lsl x13, x3, #3 // x13 = src_depth_quad * UNIT * UNIT_SRC / 2(int4) = src_depth_quad * 8  = src_depth_quad << 3
+
+TILE_4:
+    cmp x6, #4
+    blt TILE_1
+    mov x14, x4       // dst_step
+    lsr x15, x4, #2   // src_step = dst_step / 4
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_4:
+    // dequant info for batch
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    // mask
+    mov w27, #0x0f
+    dup v14.16b, w27
+    // offset
+    mov w27, #8
+    dup v15.16b, w27
+LoopSz_TILE_4:
+    // src    : 4(batch) x [1 x 4] : v4
+    // weight : 4(oc) x [1 x 4] : v0
+    // dst    : 4 x 4 x [1] : v16-v19
+    ld1 {v0.8b}, [x25], #8    // weight
+    ld1 {v4.16b}, [x24], x15   // src
+    // int4->int8
+    ushr v8.16b, v0.16b, #4
+    and v9.16b, v0.16b, v14.16b
+    sub v8.16b, v8.16b, v15.16b
+    sub v9.16b, v9.16b, v15.16b
+    zip1 v0.16b, v8.16b, v9.16b
+    .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0
+    .inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1
+    .inst 0x4f84e812 // sdot v18.4s, v0.16b, v4.4b[2] // batch2
+    .inst 0x4fa4e813 // sdot v19.4s, v0.16b, v4.4b[3] // batch3
+    subs x26, x26, #1
+    bne LoopSz_TILE_4
+
+LoopSzEnd_TILE_4:
+    add x18, x18, x13
+    sub x16, x16, #1
+    Int32ToFloat v16, v17, v18, v19
+    // Int32ToFloat v20, v21, v22, v23
+    // using float scale dequant for precision
+    ld1 {v5.4s}, [x23]  // scales, 4 batches, so 4 scales
+
+    MulScale v16, v17, v18, v19, v5
+
+Tile4Dequant:
+    ld1 {v0.4s}, [x19], #16  // alpha
+    ld1 {v1.4s}, [x20], #16  // zero
+    ld1 {v2.4s}, [x21], #16  // bias
+    ld1 {v3.4s}, [x22]  // sums
+    // alpha * sum + (zero * sums) + bias
+    Dequant v16, v0, v1, v2, v3, 0
+    Dequant v17, v0, v1, v2, v3, 1
+    Dequant v18, v0, v1, v2, v3, 2
+    Dequant v19, v0, v1, v2, v3, 3
+    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_4
+Tile4End:
+    sub x6, x6, #4      // batch -= 4
+    add x0, x0, #64     // dst += 4 * 4 * sizeof(float32_t)
+    add x1, x1, #16     // src += 4 * 4 * sizeof(int8_t)
+    add x11, x11, #16    // sum += 4 * sizeof(float32_t)
+    add x12, x12, #16    // scale += 4 * sizeof(float32_t)
+    b TILE_4
+
+TILE_1:
+    cmp x6, #1
+    blt End
+    mov x14, x4       // dst_step
+    lsr x15, x4, #2   // src_step = dst_step / 4, sizeof(float32_t)/4=sizeof(int8_t)
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_1:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    // mask
+    mov w27, #0x0f
+    dup v14.16b, w27
+    // offset
+    mov w27, #8
+    dup v15.16b, w27
+LoopSz_TILE_1:
+    // src    : 1(batch) x [1 x 4] : v4
+    // weight : 4(oc) x [1 x 4] : v0
+    // dst    : 1 x 4 x [1] : v16
+    ld1 {v0.8b}, [x25], #8    // weight pack*pack*0.5
+    ld1 {v4.s}[0], [x24], x15   // src
+    // int4->int8
+    ushr v8.16b, v0.16b, #4
+    and v9.16b, v0.16b, v14.16b
+    sub v8.16b, v8.16b, v15.16b
+    sub v9.16b, v9.16b, v15.16b
+    zip1 v0.16b, v8.16b, v9.16b
+    
+    .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0]
+
+    subs x26, x26, #1
+    bne LoopSz_TILE_1
+
+LoopSzEnd_TILE_1:
+    add x18, x18, x13
+    sub x16, x16, #1
+    scvtf v16.4s, v16.4s
+    // using float scale dequant for precision
+    ld1 {v4.s}[0], [x23]  // scales
+    fmul v16.4s, v16.4s, v4.s[0]
+Tile1Dequant:
+    ld1 {v0.4s}, [x19], #16  // alpha
+    ld1 {v1.4s}, [x20], #16  // zero
+    ld1 {v2.4s}, [x21], #16  // bias
+    ld1 {v3.s}[0], [x22]  // sums
+    // alpha * sum + (zero * sumx) + bias
+    fmla v2.4s, v0.4s, v16.4s
+    fmla v2.4s, v1.4s, v3.s[0]
+    st1 {v2.4s}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_1
+Tile1End:
+    subs x6, x6, #1      // batch -= 1
+    add x0, x0, #16     // dst += 1 * 4 * sizeof(float32_t)
+    add x1, x1, #4      // src += 1 * 4 * sizeof(int8_t)
+    add x11, x11, #4   // sum += 1 * sizeof(float32_t)
+    add x12, x12, #4   // scale += 1 * sizeof(float32_t)
+    bne TILE_1
+
+End:
+ldp x27, x28, [sp, #(16 * 8)]
+ldp x25, x26, [sp, #(16 * 7)]
+ldp x23, x24, [sp, #(16 * 6)]
+ldp x19, x20, [sp, #(16 * 5)]
+ldp x21, x22, [sp, #(16 * 4)]
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 9)
+ret
+
+#endif
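
The ushr/and/sub/zip1 sequence above expands each packed weight byte into two signed int8 values in [-8, 7], high nibble first and interleaved with the low nibble so consecutive values line up with consecutive input channels. A scalar sketch of that decode (DecodeInt4 is an illustrative name, not an MNN API):

#include <cstdint>

// Sketch: decode 8 packed int4 weight bytes into 16 int8 values, mirroring
// the ushr #4 / and 0x0f / sub #8 / zip1 sequence in the kernel above.
static void DecodeInt4(const uint8_t packed[8], int8_t out[16]) {
    for (int i = 0; i < 8; ++i) {
        out[2 * i]     = (int8_t)((packed[i] >> 4) & 0x0f) - 8; // high nibble
        out[2 * i + 1] = (int8_t)(packed[i] & 0x0f) - 8;        // low nibble
    }
}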

+ 373 - 0
source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32_smmla.S

@@ -0,0 +1,373 @@
+//
+//  MNNGemmHybridInt4FP32_smmla.S
+//  MNN
+//
+//  Created by MNN on 2023/11/09.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+
+.text
+.align 5
+
+.macro Int32ToFloat z0, z1, z2, z3
+    scvtf \z0\().4s, \z0\().4s
+    scvtf \z1\().4s, \z1\().4s
+    scvtf \z2\().4s, \z2\().4s
+    scvtf \z3\().4s, \z3\().4s
+.endm
+
+.macro MulScale d0, d1, d2, d3, s, idx0, idx1
+    fmul \d0\().4s, \d0\().4s, \s\().s[\idx0]
+    fmul \d1\().4s, \d1\().4s, \s\().s[\idx0]
+    fmul \d2\().4s, \d2\().4s, \s\().s[\idx1]
+    fmul \d3\().4s, \d3\().4s, \s\().s[\idx1]
+.endm
+
+.macro Dequant c0, a0, z0, b0, s0, idx
+    fmul \c0\().4s, \c0\().4s, \a0\().4s
+    fmla \c0\().4s, \z0\().4s, \s0\().s[\idx]
+    fadd \c0\().4s, \c0\().4s, \b0\().4s
+.endm
+
+asm_function MNNGemmHybridInt4FP32_smmla
+
+//struct QuanPostTreatParameters {
+//    const float* scale;
+//    const int32_t* bias;
+//    int32_t maxValue;
+//    int32_t minValue;
+//    int32_t useInt8;
+//};
+
+//void MNNGemmHybridInt4FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); 
+
+
+// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
+// load from param: x8: alpha*, x9: zero*, x10: bias*, x11: sums*, x12: scales*
+stp d14, d15, [sp, #(-16 * 9)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+stp x21, x22, [sp, #(16 * 4)]
+stp x19, x20, [sp, #(16 * 5)]
+stp x23, x24, [sp, #(16 * 6)]
+stp x25, x26, [sp, #(16 * 7)]
+stp x27, x28, [sp, #(16 * 8)]
+
+ldr x8, [x7, #0]
+ldr x9, [x7, #8]
+ldr x10, [x7, #16]
+ldr x11, [x7, #24]
+ldr x12, [x7, #32]
+
+Start:
+lsl x13, x3, #5 // x13 = src_depth_quad * UNIT * UNIT_SRC / 2(int4) = src_depth_quad * 32  = src_depth_quad << 5
+
+TILE_4:
+    cmp x6, #4
+    blt TILE_2
+    mov x14, x4       // dst_step
+    lsr x15, x4, #2   // src_step = dst_step / 4
+    sub x14, x14, #64
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_4:
+    // dequant info for batch
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    // mask
+    mov w27, #0x0f
+    dup v10.16b, w27
+    // offset
+    mov w27, #8
+    dup v11.16b, w27
+LoopSz_TILE_4:
+    // src    : 2 x [2 x 8] : v4-5
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 2 x 4 x [4] : v16-23
+    //ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v0.16b, v1.16b}, [x25], #32    // weight
+    // int4 to int8: v0, v1, v2, v3
+    ushr v4.16b, v0.16b, #4
+    and v5.16b, v0.16b, v10.16b
+    sub v4.16b, v4.16b, v11.16b
+    sub v5.16b, v5.16b, v11.16b
+    ushr v6.16b, v1.16b, #4
+    and v7.16b, v1.16b, v10.16b
+    sub v6.16b, v6.16b, v11.16b
+    sub v7.16b, v7.16b, v11.16b
+    zip1 v0.16b, v4.16b, v5.16b
+    zip2 v1.16b, v4.16b, v5.16b
+    zip1 v2.16b, v6.16b, v7.16b
+    zip2 v3.16b, v6.16b, v7.16b
+    ld1 {v4.16b, v5.16b}, [x24], x15   // src
+    .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
+    .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
+    .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b
+    .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b
+    .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b
+    .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b
+    .inst 0x4e82a4b6 // smmla v22.4s, v5.16b, v2.16b
+    .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b
+    subs x26, x26, #1
+    bne LoopSz_TILE_4
+
+LoopSzEnd_TILE_4:
+    add x18, x18, x13
+    sub x16, x16, #1
+
+    trn1 v24.2d, v16.2d, v17.2d // batch:0 oc:0-3
+    trn1 v25.2d, v18.2d, v19.2d // batch:0 oc:4-7
+    trn2 v26.2d, v16.2d, v17.2d // batch:1 oc:0-3
+    trn2 v27.2d, v18.2d, v19.2d // batch:1 oc:4-7
+    trn1 v28.2d, v20.2d, v21.2d // batch:2 oc:0-3
+    trn1 v29.2d, v22.2d, v23.2d // batch:2 oc:4-7
+    trn2 v30.2d, v20.2d, v21.2d // batch:3 oc:0-3
+    trn2 v31.2d, v22.2d, v23.2d // batch:3 oc:4-7
+    Int32ToFloat v24, v25, v26, v27
+    Int32ToFloat v28, v29, v30, v31
+    // using float scale dequant for precision
+    ld1 {v5.4s}, [x23]  // scales
+    MulScale v24, v25, v26, v27, v5, 0, 1
+    MulScale v28, v29, v30, v31, v5, 2, 3
+Tile4Dequant:
+    ld1 {v0.4s, v1.4s}, [x19], #32  // alpha
+    ld1 {v2.4s, v3.4s}, [x20], #32  // zero
+    ld1 {v8.4s, v9.4s}, [x21], #32  // bias
+    ld1 {v6.4s}, [x22]  // sums
+    // alpha * sum + (zero * sums) + bias
+    Dequant v24, v0, v2, v8, v6, 0 // Batch0
+    Dequant v25, v1, v3, v9, v6, 0
+    Dequant v26, v0, v2, v8, v6, 1 // Batch1
+    Dequant v27, v1, v3, v9, v6, 1
+    Dequant v28, v0, v2, v8, v6, 2 // Batch2
+    Dequant v29, v1, v3, v9, v6, 2
+    Dequant v30, v0, v2, v8, v6, 3 // Batch3
+    Dequant v31, v1, v3, v9, v6, 3
+    st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x17], #64
+    st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_4
+Tile4End:
+    sub x6, x6, #4      // batch -= 4
+    add x0, x0, #128     // dst += 4 * 8 * sizeof(float32_t)
+    add x1, x1, #32     // src += 4 * 8 * sizeof(int8_t)
+    add x11, x11, #16    // sum += 4 * sizeof(float32_t)
+    add x12, x12, #16   // scale += 4 * sizeof(float32_t)
+    b TILE_4
+
+TILE_2:
+    cmp x6, #2
+    blt TILE_1
+    mov x14, x4       // dst_step
+    lsr x15, x4, #2   // src_step = dst_step / 4
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_2:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    // mask
+    mov w27, #0x0f
+    dup v14.16b, w27
+    // offset
+    mov w27, #8
+    dup v15.16b, w27
+LoopSz_TILE_2:
+    // src    : 1 x [2 x 8] : v4
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 1 x 4 x [4] : v16-19
+    //ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v0.16b, v1.16b}, [x25], #32    // weight
+    // int4 to int8: v0, v1, v2, v3
+    ushr v8.16b, v0.16b, #4
+    and v9.16b, v0.16b, v14.16b
+    sub v8.16b, v8.16b, v15.16b
+    sub v9.16b, v9.16b, v15.16b
+    ushr v10.16b, v1.16b, #4
+    and v11.16b, v1.16b, v14.16b
+    sub v10.16b, v10.16b, v15.16b
+    sub v11.16b, v11.16b, v15.16b
+    zip1 v0.16b, v8.16b, v9.16b
+    zip2 v1.16b, v8.16b, v9.16b
+    zip1 v2.16b, v10.16b, v11.16b
+    zip2 v3.16b, v10.16b, v11.16b
+    ld1 {v4.16b}, [x24], x15   // src
+    .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
+    .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
+    .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b
+    .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b
+    subs x26, x26, #1
+    bne LoopSz_TILE_2
+
+LoopSzEnd_TILE_2:
+    add x18, x18, x13
+    sub x16, x16, #1
+    trn1 v20.2d, v16.2d, v17.2d
+    trn1 v21.2d, v18.2d, v19.2d
+    trn2 v22.2d, v16.2d, v17.2d
+    trn2 v23.2d, v18.2d, v19.2d
+    Int32ToFloat v20, v21, v22, v23
+    // using float scale dequant for precision
+    ld1 {v5.d}[0], [x23]  // scales
+    fmul v20.4s, v20.4s, v5.s[0]
+    fmul v21.4s, v21.4s, v5.s[0]
+    fmul v22.4s, v22.4s, v5.s[1]
+    fmul v23.4s, v23.4s, v5.s[1]
+Tile2Dequant:
+    ld1 {v0.4s, v1.4s}, [x19], #32  // alpha
+    ld1 {v2.4s, v3.4s}, [x20], #32  // zero
+    ld1 {v8.4s, v9.4s}, [x21], #32  // bias
+    ld1 {v10.d}[0], [x22]  // sums
+    // alpha * sum + (zero * sumx) + bias
+    Dequant v20, v0, v2, v8, v10, 0
+    Dequant v21, v1, v3, v9, v10, 0
+    Dequant v22, v0, v2, v8, v10, 1
+    Dequant v23, v1, v3, v9, v10, 1
+    st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_2
+Tile2End:
+    sub x6, x6, #2      // batch -= 2
+    add x0, x0, #64     // dst += 2 * 8 * sizeof(float32_t)
+    add x1, x1, #16     // src += 2 * 8 * sizeof(int8_t)
+    add x11, x11, #8    // sum += 2 * sizeof(float32_t)
+    add x12, x12, #8    // scale += 2 * sizeof(float32_t)
+    b TILE_2
+
+TILE_1:
+    cmp x6, #1
+    blt End
+    mov x14, x4       // dst_step
+    lsr x15, x4, #2   // src_step = dst_step / 4, sizeof(float32_t)/4=sizeof(int8_t)
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_1:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    // mask
+    mov w27, #0x0f
+    dup v14.16b, w27
+    // offset
+    mov w27, #8
+    dup v15.16b, w27
+
+LoopSz_TILE_1:
+    // src    : 1 x [1 x 8] : v4
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 1 x 4 x [2] : v16-v19
+    //ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v0.16b, v1.16b}, [x25], #32    // weight
+    // int4 to int8: v0, v1, v2, v3
+    ushr v8.16b, v0.16b, #4
+    and v9.16b, v0.16b, v14.16b
+    sub v8.16b, v8.16b, v15.16b
+    sub v9.16b, v9.16b, v15.16b
+    ushr v10.16b, v1.16b, #4
+    and v11.16b, v1.16b, v14.16b
+    sub v10.16b, v10.16b, v15.16b
+    sub v11.16b, v11.16b, v15.16b
+    zip1 v0.16b, v8.16b, v9.16b
+    zip2 v1.16b, v8.16b, v9.16b
+    zip1 v2.16b, v10.16b, v11.16b
+    zip2 v3.16b, v10.16b, v11.16b
+    ld1 {v4.8b}, [x24], x15   // src
+    .inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b
+    .inst 0x4e84a431 // smmla v17.4s, v1.16b, v4.16b
+    .inst 0x4e84a452 // smmla v18.4s, v2.16b, v4.16b
+    .inst 0x4e84a473 // smmla v19.4s, v3.16b, v4.16b
+
+    subs x26, x26, #1
+    bne LoopSz_TILE_1
+
+LoopSzEnd_TILE_1:
+    add x18, x18, x13
+    sub x16, x16, #1
+    uzp1 v20.4s, v16.4s, v17.4s
+    uzp1 v21.4s, v18.4s, v19.4s
+    scvtf v20.4s, v20.4s
+    scvtf v21.4s, v21.4s
+    // using float scale dequant for precision
+    ld1 {v4.s}[0], [x23]  // scales
+    fmul v20.4s, v20.4s, v4.s[0]
+    fmul v21.4s, v21.4s, v4.s[0]
+Tile1Dequant:
+    ld1 {v0.4s, v1.4s}, [x19], #32  // alpha
+    ld1 {v2.4s, v3.4s}, [x20], #32  // zero
+    ld1 {v12.4s, v13.4s}, [x21], #32  // bias
+    ld1 {v6.s}[0], [x22]  // sums
+    // alpha * sum + (zero * sumx) + bias
+    fmla v12.4s, v20.4s, v0.4s
+    fmla v13.4s, v21.4s, v1.4s
+    fmla v12.4s, v2.4s, v6.s[0]
+    fmla v13.4s, v3.4s, v6.s[0]
+    st1 {v12.4s, v13.4s}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_1
+Tile1End:
+    sub x6, x6, #1      // batch -= 1
+    add x0, x0, #32     // dst += 1 * 8 * sizeof(float32_t)
+    add x1, x1, #8      // src += 1 * 8 * sizeof(int8_t)
+    add x11, x11, #4   // sum += 1 * sizeof(float32_t)
+    add x12, x12, #4   // scale += 1 * sizeof(float32_t)
+    b TILE_1
+
+End:
+ldp x27, x28, [sp, #(16 * 8)]
+ldp x25, x26, [sp, #(16 * 7)]
+ldp x23, x24, [sp, #(16 * 6)]
+ldp x19, x20, [sp, #(16 * 5)]
+ldp x21, x22, [sp, #(16 * 4)]
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 9)
+ret
+
+#endif
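
All of these hybrid kernels finish with the same per-element dequantization that the Dequant macro encodes; spelled out for one batch and one output channel (DequantOne is an illustrative scalar helper, the real code works on 4-lane vectors):

// Sketch (not MNN API):
// acc   : int32 dot-product accumulator for (batch, output channel)
// scale : per-batch activation dequant scale, sums: per-batch input sum
// alpha : per-channel weight scale, zero: per-channel zero point, bias: bias
static inline float DequantOne(int32_t acc, float scale, float sums,
                               float alpha, float zero, float bias) {
    return bias + (float)acc * scale * alpha + zero * sums;
}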

+ 209 - 0
source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32_sdot.S

@@ -0,0 +1,209 @@
+//
+//  MNNGemmHybridInt8FP32_sdot.S
+//  MNN
+//
+//  Created by MNN on 2023/11/09.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+
+.text
+.align 5
+
+.macro Int32ToFloat z0, z1, z2, z3
+    scvtf \z0\().4s, \z0\().4s
+    scvtf \z1\().4s, \z1\().4s
+    scvtf \z2\().4s, \z2\().4s
+    scvtf \z3\().4s, \z3\().4s
+.endm
+
+.macro MulScale d0, d1, d2, d3, s
+    fmul \d0\().4s, \d0\().4s, \s\().s[0]
+    fmul \d1\().4s, \d1\().4s, \s\().s[1]
+    fmul \d2\().4s, \d2\().4s, \s\().s[2]
+    fmul \d3\().4s, \d3\().4s, \s\().s[3]
+.endm
+
+.macro Dequant c0, a0, z0, b0, s0, idx
+    fmul \c0\().4s, \c0\().4s, \a0\().4s
+    fmla \c0\().4s, \z0\().4s, \s0\().s[\idx]
+    fadd \c0\().4s, \c0\().4s, \b0\().4s
+.endm
+
+asm_function MNNGemmHybridInt8FP32_sdot
+
+//struct QuanPostTreatParameters {
+//    const float* scale;
+//    const int32_t* bias;
+//    int32_t maxValue;
+//    int32_t minValue;
+//    int32_t useInt8;
+//};
+
+//void MNNGemmHybridInt8FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); 
+
+
+// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
+// load from param: x8: alpha*, x9: zero*, x10: bias*, x11: sums*, x12: scales*
+stp d14, d15, [sp, #(-16 * 9)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+stp x21, x22, [sp, #(16 * 4)]
+stp x19, x20, [sp, #(16 * 5)]
+stp x23, x24, [sp, #(16 * 6)]
+stp x25, x26, [sp, #(16 * 7)]
+stp x27, x28, [sp, #(16 * 8)]
+
+ldr x8, [x7, #0]
+ldr x9, [x7, #8]
+ldr x10, [x7, #16]
+ldr x11, [x7, #24]
+ldr x12, [x7, #32]
+
+Start:
+lsl x13, x3, #4 // x13 = src_depth_quad * UNIT * UNIT_SRC / 1(int8) = src_depth_quad * 16  = src_depth_quad << 4
+
+TILE_4:
+    cmp x6, #4
+    blt TILE_1
+    mov x14, x4       // dst_step
+    lsr x15, x4, #2   // src_step = dst_step / 4
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_4:
+    // dequant info for batch
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+
+LoopSz_TILE_4:
+    // src    : 4(batch) x [1 x 4] : v4
+    // weight : 4(oc) x [1 x 4] : v0
+    // dst    : 4 x 4 x [1] : v16-v19
+    ld1 {v0.16b}, [x25], #16    // weight
+    ld1 {v4.16b}, [x24], x15   // src
+    .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0
+    .inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1
+    .inst 0x4f84e812 // sdot v18.4s, v0.16b, v4.4b[2] // batch2
+    .inst 0x4fa4e813 // sdot v19.4s, v0.16b, v4.4b[3] // batch3
+    subs x26, x26, #1
+    bne LoopSz_TILE_4
+
+LoopSzEnd_TILE_4:
+    add x18, x18, x13
+    sub x16, x16, #1
+    Int32ToFloat v16, v17, v18, v19
+    // Int32ToFloat v20, v21, v22, v23
+    // using float scale dequant for precision
+    ld1 {v5.4s}, [x23]  // scales, 4 batches, so 4 scales
+
+    MulScale v16, v17, v18, v19, v5
+
+Tile4Dequant:
+    ld1 {v0.4s}, [x19], #16  // alpha
+    ld1 {v1.4s}, [x20], #16  // zero
+    ld1 {v2.4s}, [x21], #16  // bias
+    ld1 {v3.4s}, [x22]  // sums
+    // alpha * sum + (zero * sums) + bias
+    Dequant v16, v0, v1, v2, v3, 0
+    Dequant v17, v0, v1, v2, v3, 1
+    Dequant v18, v0, v1, v2, v3, 2
+    Dequant v19, v0, v1, v2, v3, 3
+    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_4
+Tile4End:
+    sub x6, x6, #4      // batch -= 4
+    add x0, x0, #64     // dst += 4 * 4 * sizeof(float32_t)
+    add x1, x1, #16     // src += 4 * 4 * sizeof(int8_t)
+    add x11, x11, #16    // sum += 4 * sizeof(float32_t)
+    add x12, x12, #16    // scale += 4 * sizeof(float32_t)
+    b TILE_4
+
+TILE_1:
+    cmp x6, #1
+    blt End
+    mov x14, x4       // dst_step
+    lsr x15, x4, #2   // src_step = dst_step / 4, sizeof(float32_t)/4=sizeof(int8_t)
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_1:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+
+LoopSz_TILE_1:
+    // src    : 1(batch) x [1 x 4] : v4
+    // weight : 4(oc) x [1 x 4] : v0
+    // dst    : 1 x 4 x [1] : v16
+    ld1 {v0.16b}, [x25], #16    // weight
+    ld1 {v4.s}[0], [x24], x15   // src
+    .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0]
+
+    subs x26, x26, #1
+    bne LoopSz_TILE_1
+
+LoopSzEnd_TILE_1:
+    add x18, x18, x13
+    sub x16, x16, #1
+    scvtf v16.4s, v16.4s
+    // using float scale dequant for precision
+    ld1 {v4.s}[0], [x23]  // scales
+    fmul v16.4s, v16.4s, v4.s[0]
+Tile1Dequant:
+    ld1 {v0.4s}, [x19], #16  // alpha
+    ld1 {v1.4s}, [x20], #16  // zero
+    ld1 {v2.4s}, [x21], #16  // bias
+    ld1 {v3.s}[0], [x22]  // sums
+    // alpha * sum + (zero * sumx) + bias
+    fmla v2.4s, v0.4s, v16.4s
+    fmla v2.4s, v1.4s, v3.s[0]
+    st1 {v2.4s}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_1
+Tile1End:
+    sub x6, x6, #1      // batch -= 1
+    add x0, x0, #16     // dst += 1 * 4 * sizeof(float32_t)
+    add x1, x1, #4      // src += 1 * 4 * sizeof(int8_t)
+    add x11, x11, #4   // sum += 1 * sizeof(float32_t)
+    add x12, x12, #4   // scale += 1 * sizeof(float32_t)
+    b TILE_1
+
+End:
+ldp x27, x28, [sp, #(16 * 8)]
+ldp x25, x26, [sp, #(16 * 7)]
+ldp x23, x24, [sp, #(16 * 6)]
+ldp x19, x20, [sp, #(16 * 5)]
+ldp x21, x22, [sp, #(16 * 4)]
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 9)
+ret
+
+#endif
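
The .inst-encoded instructions in LoopSz_TILE_4 above are by-element sdot ops: each one accumulates four independent 4-way int8 dot products into the 32-bit lanes of the destination. A rough emulation of one such instruction (SdotByElement only illustrates the semantics):

#include <cstdint>

// Sketch: emulate sdot vd.4s, vn.16b, vm.4b[idx].
// Lane l of vd accumulates the dot product of bytes vn[4l..4l+3] with the
// idx-th 4-byte group of vm.
static void SdotByElement(int32_t vd[4], const int8_t vn[16],
                          const int8_t vm[16], int idx) {
    for (int lane = 0; lane < 4; ++lane) {
        for (int k = 0; k < 4; ++k) {
            vd[lane] += (int32_t)vn[4 * lane + k] * (int32_t)vm[4 * idx + k];
        }
    }
}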

+ 314 - 0
source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32_smmla.S

@@ -0,0 +1,314 @@
+//
+//  MNNGemmHybridInt8FP32_smmla.S
+//  MNN
+//
+//  Created by MNN on 2023/11/09.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+
+.text
+.align 5
+
+.macro Int32ToFloat z0, z1, z2, z3
+    scvtf \z0\().4s, \z0\().4s
+    scvtf \z1\().4s, \z1\().4s
+    scvtf \z2\().4s, \z2\().4s
+    scvtf \z3\().4s, \z3\().4s
+.endm
+
+.macro MulScale d0, d1, d2, d3, s, idx0, idx1
+    fmul \d0\().4s, \d0\().4s, \s\().s[\idx0]
+    fmul \d1\().4s, \d1\().4s, \s\().s[\idx0]
+    fmul \d2\().4s, \d2\().4s, \s\().s[\idx1]
+    fmul \d3\().4s, \d3\().4s, \s\().s[\idx1]
+.endm
+
+.macro Dequant c0, a0, z0, b0, s0, idx
+    fmul \c0\().4s, \c0\().4s, \a0\().4s
+    fmla \c0\().4s, \z0\().4s, \s0\().s[\idx]
+    fadd \c0\().4s, \c0\().4s, \b0\().4s
+.endm
+
+asm_function MNNGemmHybridInt8FP32_smmla
+
+//struct QuanPostTreatParameters {
+//    const float* scale;
+//    const int32_t* bias;
+//    int32_t maxValue;
+//    int32_t minValue;
+//    int32_t useInt8;
+//};
+
+//void MNNGemmHybridInt8FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); 
+
+
+// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
+// load from param: x8: alpha*, x9: zero*, x10: bias*, x11: sums*, x12: scales*
+stp d14, d15, [sp, #(-16 * 9)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+stp x21, x22, [sp, #(16 * 4)]
+stp x19, x20, [sp, #(16 * 5)]
+stp x23, x24, [sp, #(16 * 6)]
+stp x25, x26, [sp, #(16 * 7)]
+stp x27, x28, [sp, #(16 * 8)]
+
+ldr x8, [x7, #0]
+ldr x9, [x7, #8]
+ldr x10, [x7, #16]
+ldr x11, [x7, #24]
+ldr x12, [x7, #32]
+
+Start:
+lsl x13, x3, #6 // x13 = src_depth_quad * UNIT * UNIT_SRC / 1(int8) = src_depth_quad * 64  = src_depth_quad << 6
+
+TILE_4:
+    cmp x6, #4
+    blt TILE_2
+    mov x14, x4       // dst_step
+    lsr x15, x4, #2   // src_step = dst_step / 4
+    sub x14, x14, #64
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_4:
+    // dequant info for batch
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+LoopSz_TILE_4:
+    // src    : 2 x [2 x 8] : v4-5
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 2 x 4 x [4] : v16-23
+    ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v4.16b, v5.16b}, [x24], x15   // src
+    .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
+    .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
+    .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b
+    .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b
+    .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b
+    .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b
+    .inst 0x4e82a4b6 // smmla v22.4s, v5.16b, v2.16b
+    .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b
+    subs x26, x26, #1
+    bne LoopSz_TILE_4
+
+LoopSzEnd_TILE_4:
+    add x18, x18, x13
+    sub x16, x16, #1
+
+    trn1 v24.2d, v16.2d, v17.2d // batch:0 oc:0-3
+    trn1 v25.2d, v18.2d, v19.2d // batch:0 oc:4-7
+    trn2 v26.2d, v16.2d, v17.2d // batch:1 oc:0-3
+    trn2 v27.2d, v18.2d, v19.2d // batch:1 oc:4-7
+    trn1 v28.2d, v20.2d, v21.2d // batch:2 oc:0-3
+    trn1 v29.2d, v22.2d, v23.2d // batch:2 oc:4-7
+    trn2 v30.2d, v20.2d, v21.2d // batch:3 oc:0-3
+    trn2 v31.2d, v22.2d, v23.2d // batch:3 oc:4-7
+    Int32ToFloat v24, v25, v26, v27
+    Int32ToFloat v28, v29, v30, v31
+    // using float scale dequant for precision
+    ld1 {v5.4s}, [x23]  // scales
+    MulScale v24, v25, v26, v27, v5, 0, 1
+    MulScale v28, v29, v30, v31, v5, 2, 3
+Tile4Dequant:
+    ld1 {v0.4s, v1.4s}, [x19], #32  // alpha
+    ld1 {v2.4s, v3.4s}, [x20], #32  // zero
+    ld1 {v8.4s, v9.4s}, [x21], #32  // bias
+    ld1 {v6.4s}, [x22]  // sums
+    // alpha * sum + (zero * sums) + bias
+    Dequant v24, v0, v2, v8, v6, 0 // Batch0
+    Dequant v25, v1, v3, v9, v6, 0
+    Dequant v26, v0, v2, v8, v6, 1 // Batch1
+    Dequant v27, v1, v3, v9, v6, 1
+    Dequant v28, v0, v2, v8, v6, 2 // Batch2
+    Dequant v29, v1, v3, v9, v6, 2
+    Dequant v30, v0, v2, v8, v6, 3 // Batch3
+    Dequant v31, v1, v3, v9, v6, 3
+    st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x17], #64
+    st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_4
+Tile4End:
+    sub x6, x6, #4      // batch -= 4
+    add x0, x0, #128     // dst += 4 * 8 * sizeof(float32_t)
+    add x1, x1, #32     // src += 4 * 8 * sizeof(int8_t)
+    add x11, x11, #16    // sum += 4 * sizeof(float32_t)
+    add x12, x12, #16   // scale += 4 * sizeof(float32_t)
+    b TILE_4
+
+TILE_2:
+    cmp x6, #2
+    blt TILE_1
+    mov x14, x4       // dst_step
+    lsr x15, x4, #2   // src_step = dst_step / 4
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_2:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+LoopSz_TILE_2:
+    // src    : 1 x [2 x 8] : v4
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 1 x 4 x [4] : v16-19
+    ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v4.16b}, [x24], x15   // src
+    .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
+    .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
+    .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b
+    .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b
+    subs x26, x26, #1
+    bne LoopSz_TILE_2
+
+LoopSzEnd_TILE_2:
+    add x18, x18, x13
+    sub x16, x16, #1
+    trn1 v20.2d, v16.2d, v17.2d
+    trn1 v21.2d, v18.2d, v19.2d
+    trn2 v22.2d, v16.2d, v17.2d
+    trn2 v23.2d, v18.2d, v19.2d
+    Int32ToFloat v20, v21, v22, v23
+    // using float scale dequant for precision
+    ld1 {v5.d}[0], [x23]  // scales
+    fmul v20.4s, v20.4s, v5.s[0]
+    fmul v21.4s, v21.4s, v5.s[0]
+    fmul v22.4s, v22.4s, v5.s[1]
+    fmul v23.4s, v23.4s, v5.s[1]
+Tile2Dequant:
+    ld1 {v0.4s, v1.4s}, [x19], #32  // alpha
+    ld1 {v2.4s, v3.4s}, [x20], #32  // zero
+    ld1 {v8.4s, v9.4s}, [x21], #32  // bias
+    ld1 {v10.d}[0], [x22]  // sums
+    // alpha * sum + (zero * sumx) + bias
+    Dequant v20, v0, v2, v8, v10, 0
+    Dequant v21, v1, v3, v9, v10, 0
+    Dequant v22, v0, v2, v8, v10, 1
+    Dequant v23, v1, v3, v9, v10, 1
+    st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_2
+Tile2End:
+    sub x6, x6, #2      // batch -= 2
+    add x0, x0, #64     // dst += 2 * 8 * sizeof(float32_t)
+    add x1, x1, #16     // src += 2 * 8 * sizeof(int8_t)
+    add x11, x11, #8    // sum += 2 * sizeof(float32_t)
+    add x12, x12, #8    // scale += 2 * sizeof(float32_t)
+    b TILE_2
+
+TILE_1:
+    
+    cmp x6, #1
+    blt End
+    mov x14, x4       // dst_step
+    lsr x15, x4, #2   // src_step = dst_step / 4, sizeof(float32_t)/4=sizeof(int8_t)
+    mov x16, x5 // dst_depth_quad
+    mov x17, x0 // dst
+    mov x18, x2 // weight
+    // dequant info
+    mov x19, x8 // alpha
+    mov x20, x9 // zero
+    mov x21, x10 // bias
+LoopDz_TILE_1:
+    mov x22, x11 // sums
+    mov x23, x12 // scales
+    mov x24, x1  // src
+    mov x25, x18 // weight
+    mov x26, x3  // src_depth_quad
+    // init
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+
+LoopSz_TILE_1:
+    // src    : 1 x [1 x 8] : v4
+    // weight : 4 x [2 x 8] : v0-3
+    // dst    : 1 x 4 x [2] : v16-v19
+    ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64    // weight
+    ld1 {v4.8b}, [x24], x15   // src
+    .inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b
+    .inst 0x4e84a431 // smmla v17.4s, v1.16b, v4.16b
+    .inst 0x4e84a452 // smmla v18.4s, v2.16b, v4.16b
+    .inst 0x4e84a473 // smmla v19.4s, v3.16b, v4.16b
+
+    subs x26, x26, #1
+    bne LoopSz_TILE_1
+
+LoopSzEnd_TILE_1:
+    add x18, x18, x13
+    sub x16, x16, #1
+    uzp1 v20.4s, v16.4s, v17.4s
+    uzp1 v21.4s, v18.4s, v19.4s
+    scvtf v20.4s, v20.4s
+    scvtf v21.4s, v21.4s
+    // using float scale dequant for precision
+    ld1 {v4.s}[0], [x23]  // scales
+    fmul v20.4s, v20.4s, v4.s[0]
+    fmul v21.4s, v21.4s, v4.s[0]
+Tile1Dequant:
+    ld1 {v0.4s, v1.4s}, [x19], #32  // alpha
+    ld1 {v2.4s, v3.4s}, [x20], #32  // zero
+    ld1 {v10.4s, v11.4s}, [x21], #32  // bias
+    ld1 {v8.s}[0], [x22]  // sums
+    // alpha * sum + (zero * sumx) + bias
+    fmla v10.4s, v20.4s, v0.4s
+    fmla v11.4s, v21.4s, v1.4s
+    fmla v10.4s, v2.4s, v8.s[0]
+    fmla v11.4s, v3.4s, v8.s[0]
+    st1 {v10.4s, v11.4s}, [x17], x14
+    cmp x16, #1
+    bge LoopDz_TILE_1
+Tile1End:
+    sub x6, x6, #1      // batch -= 1
+    add x0, x0, #32     // dst += 1 * 8 * sizeof(float32_t)
+    add x1, x1, #8      // src += 1 * 8 * sizeof(int8_t)
+    add x11, x11, #4   // sum += 1 * sizeof(float32_t)
+    add x12, x12, #4   // scale += 1 * sizeof(float32_t)
+    b TILE_1
+
+End:
+ldp x27, x28, [sp, #(16 * 8)]
+ldp x25, x26, [sp, #(16 * 7)]
+ldp x23, x24, [sp, #(16 * 6)]
+ldp x19, x20, [sp, #(16 * 5)]
+ldp x21, x22, [sp, #(16 * 4)]
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 9)
+ret
+
+#endif
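
Each smmla above accumulates a 2x2 int32 block: the two 8-byte rows of the first source (two batches) against the two 8-byte rows of the second source (two output channels), which is why the trn1/trn2 on .2d lanes are needed afterwards to regroup results per batch. A rough emulation of the instruction semantics (Smmla is illustrative only):

#include <cstdint>

// Sketch: emulate smmla vd.4s, vn.16b, vm.16b.
// vn and vm each hold a 2x8 int8 matrix (row-major); vd accumulates the
// 2x2 int32 product vn * vm^T, stored row-major in its four lanes.
static void Smmla(int32_t vd[4], const int8_t vn[16], const int8_t vm[16]) {
    for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
            int32_t acc = 0;
            for (int k = 0; k < 8; ++k) {
                acc += (int32_t)vn[8 * i + k] * (int32_t)vm[8 * j + k];
            }
            vd[2 * i + j] += acc;
        }
    }
}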

+ 102 - 0
source/backend/cpu/arm/arm64/low_memory/MNNQuantScaleFP32.S

@@ -0,0 +1,102 @@
+//
+//  MNNQuantScaleFP32.S
+//  MNN
+//
+//  Created by MNN on 2023/11/01.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+.text
+.align 5
+
+// void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch)
+asm_function MNNQuantScaleFP32
+
+// x0:absmax, x1:quant_scale, x2:dequant_scale, x3:thread, x4:batch
+stp d14, d15, [sp, #(-16 * 4)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+
+Start:
+movi v31.4s, #127
+scvtf v31.4s, v31.4s
+lsl x9, x4, #2 // src_step = batch * sizeof(float32_t)
+
+TILE_4:
+cmp x4, #4
+blt TILE_1
+mov x7, x0  // max_ptr
+mov x8, x3  // thread
+
+// absmax: v1
+ld1 {v1.4s}, [x7], x9
+subs x8, x8, #1
+beq Tile4End
+
+LoopSz_4:
+ld1 {v3.4s}, [x7], x9
+
+// absmax = fmax(absmax, absmax[i])
+fmax v1.4s, v1.4s, v3.4s
+
+subs x8, x8, #1
+bne LoopSz_4
+
+Tile4End:
+sub x4, x4, #4
+add x0, x0, #16
+// quant_scale = 127 / absmax
+// dequant_scale = absmax / 127
+fdiv v2.4s, v31.4s, v1.4s
+fdiv v3.4s, v1.4s, v31.4s
+st1 {v2.4s}, [x1], #16
+st1 {v3.4s}, [x2], #16
+b TILE_4
+
+
+TILE_1:
+cmp x4, #1
+blt End
+mov x7, x0  // max_ptr
+mov x8, x3  // thread
+
+// absmax: v1
+ld1 {v1.s}[0], [x7], x9
+subs x8, x8, #1
+beq Tile1End
+
+LoopSz_1:
+ld1 {v3.s}[0], [x7], x9
+
+// absmax = fmax(absmax, absmax[i])
+fmax s1, s1, s3
+
+subs x8, x8, #1
+bne LoopSz_1
+
+Tile1End:
+sub x4, x4, #1
+add x0, x0, #4
+// quant_scale = 127 / absmax
+// dequant_scale = absmax / 127
+fdiv s2, s31, s1
+fdiv s3, s1, s31
+st1 {v2.s}[0], [x1], #4
+st1 {v3.s}[0], [x2], #4
+b TILE_1
+
+
+End:
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 4)
+ret
+
+#endif
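
The two scales produced here form the usual symmetric int8 mapping. A quick worked example, assuming a column whose absolute maximum is 6.35:

// Worked example for MNNQuantScaleFP32 outputs (assuming absmax == 6.35f):
float absmax        = 6.35f;
float quant_scale   = 127.0f / absmax;  // 20.0f : q = round(x * 20) stays within [-127, 127]
float dequant_scale = absmax / 127.0f;  // 0.05f : x is recovered approximately as q * 0.05f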
+

+ 100 - 0
source/backend/cpu/arm/arm64/low_memory/MNNQuantSumFP32.S

@@ -0,0 +1,100 @@
+//
+//  MNNQuantSumFP32.S
+//  MNN
+//
+//  Created by MNN on 2023/11/30.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __aarch64__
+
+#include "MNNAsmGlobal.h"
+.text
+.align 5
+
+//void MNNQuantSumFP32(float* sum, const float* dequant_scale, size_t thread, size_t batch)
+asm_function MNNQuantSumFP32
+
+// x0: sum, x1:dequant_scale, x2:thread, x3:batch
+stp d14, d15, [sp, #(-16 * 4)]!
+stp d12, d13, [sp, #(16 * 1)]
+stp d10, d11, [sp, #(16 * 2)]
+stp d8,  d9,  [sp, #(16 * 3)]
+
+Start:
+lsl x9, x3, #2 // src_step = batch * sizeof(int32_t)
+
+TILE_4:
+cmp x3, #4
+blt TILE_1
+// add x6, x0, x10  // sum_ptr
+mov x6, x0
+mov x7, x2  // thread
+
+// sum: v0
+ld1 {v0.4s}, [x6], x9
+subs x7, x7, #1
+beq Tile4End
+
+LoopSz_4:
+ld1 {v1.4s}, [x6], x9
+
+// sum += sum[i]
+add v0.4s, v0.4s, v1.4s
+
+subs x7, x7, #1
+bne LoopSz_4
+
+Tile4End:
+sub x3, x3, #4
+// load dequant_scale
+ld1 {v2.4s}, [x1], #16
+// sum_float = (float)sum_int * dequant_scale
+scvtf v3.4s, v0.4s
+fmul v4.4s, v3.4s, v2.4s
+st1 {v4.4s}, [x0], #16
+b TILE_4
+
+// x0: sum, x1:dequant_scale, x2:thread, x3:batch
+TILE_1:
+cmp x3, #1
+blt End
+mov x6, x0
+mov x7, x2  // thread
+
+// sum: v0
+ld1 {v0.s}[0], [x6], x9
+subs x7, x7, #1
+beq Tile1End
+
+LoopSz_1:
+ld1 {v1.s}[0], [x6], x9
+
+// sum += sum[i]
+// add s0, s0, s1
+add v0.4s, v0.4s, v1.4s
+
+subs x7, x7, #1
+bne LoopSz_1
+
+Tile1End:
+sub x3, x3, #1
+// load dequant_scale
+ld1 {v2.s}[0], [x1], #4
+
+// sum_float = (float)sum_int * dequant_scale
+scvtf v3.4s, v0.4s
+fmul s4, s3, s2
+st1 {v4.s}[0], [x0], #4 
+b TILE_1
+
+
+End:
+ldp d8,  d9,  [sp, #(16 * 3)]
+ldp d10, d11, [sp, #(16 * 2)]
+ldp d12, d13, [sp, #(16 * 1)]
+ldp d14, d15, [sp], #(16 * 4)
+ret
+
+#endif
+

+ 6 - 6
source/backend/cpu/bf16/BF16Binary.cpp

@@ -203,37 +203,37 @@ void BF16Binary(void *dstRaw, const void *src0Raw, const void *src1Raw, const in
 }
 
 
-struct VecBinaryAdd : std::binary_function<Vec4Half, Vec4Half, Vec4Half> {
+struct VecBinaryAdd {
     Vec4Half operator()(const Vec4Half& x, const Vec4Half& y) const {
         return x + y;
     }
 };
 
-struct VecBinarySub : std::binary_function<Vec4Half, Vec4Half, Vec4Half> {
+struct VecBinarySub {
     Vec4Half operator()(const Vec4Half& x, const Vec4Half& y) const {
         return x - y;
     }
 };
 
-struct VecBinaryMul : std::binary_function<Vec4Half, Vec4Half, Vec4Half> {
+struct VecBinaryMul {
     Vec4Half operator()(const Vec4Half& x, const Vec4Half& y) const {
         return x * y;
     }
 };
 
-struct VecBinaryMin : std::binary_function<Vec4Half, Vec4Half, Vec4Half> {
+struct VecBinaryMin {
     Vec4Half operator()(const Vec4Half& x, const Vec4Half& y) const {
         return Vec4Half::min(x, y);
     }
 };
 
-struct VecBinaryMax : std::binary_function<Vec4Half, Vec4Half, Vec4Half> {
+struct VecBinaryMax {
     Vec4Half operator()(const Vec4Half& x, const Vec4Half& y) const {
         return Vec4Half::max(x, y);
     }
 };
 
-struct VecBinarySqd : std::binary_function<Vec4Half, Vec4Half, Vec4Half> {
+struct VecBinarySqd {
     Vec4Half operator()(const Vec4Half& x, const Vec4Half& y) const {
         return (x-y)*(x-y);
     }

+ 212 - 74
source/backend/cpu/compute/CommonOptFunction.cpp

@@ -23,6 +23,11 @@
 #include "../CPUBinary.hpp"
 #include "../CPUUnary.hpp"
 #include "../CPUPool.hpp"
+#define PACK 4
+#define FLOAT float
+using Vec = MNN::Math::Vec<float, 4>;
+#include "../GridSampler.hpp"
+
 #ifndef MNN_USE_SSE
 void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count) {
     // Should not be called
@@ -142,7 +147,7 @@ void MNNPackC2Common(T* dst, const T* src, size_t area, size_t depth, int* areaO
 }
 
 template<typename T>
-void MNNUnpackC2Common(T* dst, const T* src, size_t area, size_t depth, int* areaOffset) {
+void MNNUnpackC2Common(T* dst, const T* src, size_t area, size_t depth, int* areaOffset, int pack = 1) {
     int depthC2     = depth / 2;
     int depthRemain = depthC2 * 2;
     int remain      = depth - depthRemain;
@@ -151,24 +156,28 @@ void MNNUnpackC2Common(T* dst, const T* src, size_t area, size_t depth, int* are
     const T* srcOffset = src;
     for(z = 0; z < depthC2; ++z) {
         for(y = 0; y < 2; ++y) {
-            auto dstZ = dst + (z * 2 + y) * areaOffset[1];
-            srcChannel[y] = srcOffset + y;
+            auto dstZ = dst + (z * 2 + y) * areaOffset[1] * pack;
+            srcChannel[y] = srcOffset + y * pack;
             for(x = 0; x < area; ++x) {
-                dstZ[x] = srcChannel[y][0];
-                srcChannel[y] += 2;
+                for (int p = 0; p < pack; ++p) {
+                    dstZ[x * pack + p] = srcChannel[y][p];
+                }
+                srcChannel[y] += (2 * pack);
             }
         }
-        srcOffset += areaOffset[0] * 2;
+        srcOffset += areaOffset[0] * 2 * pack;
     }
     if(remain > 0){
-        auto dstZ = dst + depthC2 * areaOffset[1] * 2;
+        auto dstZ = dst + depthC2 * areaOffset[1] * 2 * pack;
         for(y = 0; y < remain; ++y) {
-            srcChannel[y] = srcOffset + y;
+            srcChannel[y] = srcOffset + y * pack;
             for(x = 0; x < area; ++x) {
-                dstZ[x] = srcChannel[y][0];
-                srcChannel[y] += 2;
+                for (int p = 0; p < pack; ++p) {
+                    dstZ[x * pack + p] = srcChannel[y][p];
+                }
+                srcChannel[y] += 2 * pack;
             }
-            dstZ += areaOffset[1];
+            dstZ += areaOffset[1] * pack;
         }
     }
 }
@@ -430,7 +439,6 @@ void MNNAccumulateSequenceNumber (float* dst, const float* src, int size) {
 }
 
 #ifndef MNN_USE_NEON
-
 void MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
     *eP = 16;
     *lP = 1;
@@ -694,6 +702,161 @@ void MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t
     auto aStride = parameter[0] / sizeof(float);
     _MNNPackedMatMulRemain_int8(C, A, B, eSize, parameter, postParameters, bias, aStride, k, b);
 }
+void MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) {
+    // source: (ic/4, N, 4)
+    auto srcStep = pack * realSize;
+    for (int i = 0; i < realSize; ++i) {
+        float absmaxVal = 0.f; // absmaxVal>=0
+        for (int c = 0; c < src_depth_quad; ++c) {
+            auto src = source + c * srcStep + i * pack;
+            for (int k = 0; k < pack; ++k) {
+                absmaxVal = std::max(absmaxVal, std::abs(src[k]));
+            }
+        }
+        absmax[i] = absmaxVal;
+    }
+}
+void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch) {
+    for (int i = 0; i < batch; ++i) {
+        auto absmaxPtr = absmax + i;
+        float absVal = 0.f;
+        for (int t = 0; t < thread; ++t) {
+            absVal = std::max(absVal, absmaxPtr[t * batch]);
+        }
+        quant_scale[i] = 127.0f / absVal;
+        dequant_scale[i] = absVal / 127.0f;
+    }
+}
+void MNNQuantSumFP32(float* sum, const float* dequant_scale, size_t thread, size_t batch) {
+    for (int i = 0; i < batch; ++i) {
+        auto sumPtr = reinterpret_cast<int*>(sum) + i;
+        int sumVal = 0;
+        for (int t = 0; t < thread; ++t) {
+            sumVal += sumPtr[t * batch];
+        }
+        sum[i] = sumVal * dequant_scale[i];
+    }
+}
+void MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize, int pack) {
+#ifdef MNN_USE_SSE
+    uint8_t* dstPtr = reinterpret_cast<uint8_t*>(dst);
+#else
+    int8_t* dstPtr = dst;
+#endif
+    for (int i = 0; i < realSize; ++i) {
+        auto scaleVal = scale[i];
+        int acc = 0;
+        for (int c = 0; c < src_depth_quad; ++c) {
+            auto srcZ = src + c * pack * realSize + i * pack;
+            auto dstZ = dstPtr + c * pack * realSize + i * pack;
+            for (int k = 0; k < pack; ++k) {
+                int val = (int)roundf(srcZ[k] * scaleVal);
+                acc += val;
+                dstZ[k] = val;
+            }
+        }
+        ((int32_t*)sum)[i] = acc;
+    }
+}
+void MNNGemmHybridInt8FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param) {
+    // C:(oc/4,N,4) A:(ic/4,N,4) B:(oc/4,ic/4,4,4)
+    int pack = 4;
+    size_t weight_step = src_depth_quad * pack * pack;
+    const float* alpha_ptr = param[0];
+    const float* zero_ptr = param[1];
+    const float* bias_ptr = param[2];
+    const float* sums_ptr = param[3];
+    const float* scale_ptr = param[4];
+    for (int ci = 0; ci < dst_depth_quad; ++ci) {
+        float* dstZ = C + ci * pack * realSize;
+        const int8_t*    weight = B + ci * weight_step;
+        auto alpha = alpha_ptr + ci * pack;
+        auto zero  = zero_ptr + ci * pack;
+        auto bias  = bias_ptr + ci * pack;
+        //const float* sums = param[2];
+        for (int j = 0; j < realSize; ++j) {
+            const float* sums = sums_ptr + j;
+            const float* scale = scale_ptr + j;
+            float* dstX = dstZ + j * pack;
+            std::vector<int> tmp(pack);
+            // int8_t* weightPtr = B + weight_step;
+            const int8_t* srcBatch = A + j * pack;
+            for (int k = 0; k < src_depth_quad; ++k) {
+                const int8_t* srcZ = srcBatch + k * pack * realSize;
+                const int8_t* weightZ = weight + k * pack * pack;
+                for (int cn = 0; cn < pack; ++cn) { // pack for oc
+                    const auto weightj = weightZ + cn * pack;
+                    for (int ck = 0; ck < pack; ++ck) { // pack for ic
+                        tmp[cn] += (int32_t)srcZ[ck] * (int32_t)weightj[ck];
+                    }
+                }
+            }
+            
+            // int32->float
+            for (int cn = 0; cn < pack; ++cn) {
+                float val = (float)tmp[cn] * scale[0];
+                val = bias[cn] + val * alpha[cn] + zero[cn] * sums[0];
+                dstX[cn] = val;
+            }
+        }
+    }
+}
+void MNNGemmHybridInt4FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param) {
+    // C:(oc/4,N,4) A:(ic/4,N,4) B:(oc/4,ic/4,4,4)
+    int pack = 4;
+    size_t weight_step = src_depth_quad * pack * pack * 0.5;
+    size_t weight_stride = pack * pack / 2;
+    const float* alpha_ptr = param[0];
+    const float* zero_ptr = param[1];
+    const float* bias_ptr = param[2];
+    const float* sums_ptr = param[3];
+    const float* scale_ptr = param[4];
+    for (int ci = 0; ci < dst_depth_quad; ++ci) {
+        float* dstZ = C + ci * pack * realSize;
+        const int8_t*    weight = B + ci * weight_step;
+        auto alpha = alpha_ptr + ci * pack;
+        auto zero  = zero_ptr + ci * pack;
+        auto bias  = bias_ptr + ci * pack;
+        //const float* sums = param[2];
+        for (int j = 0; j < realSize; ++j) {
+            const float* sums = sums_ptr + j;
+            const float* scale = scale_ptr + j;
+            float* dstX = dstZ + j * pack;
+            int tmp[4] = {0, 0, 0, 0};
+            // int8_t* weightPtr = B + weight_step;
+            const int8_t* srcBatch = A + j * pack;
+            for (int k = 0; k < src_depth_quad; ++k) {
+                const int8_t* srcZ = srcBatch + k * pack * realSize;
+                const uint8_t* weightZ = (uint8_t*)weight + k * weight_stride;
+                int32_t tmpw[16];
+                uint32_t c = 0xf;
+                for (int kk = 0; kk < 8; ++kk) {
+                    tmpw[2 * kk] = (weightZ[kk]>>4) - 8;
+                    tmpw[2 * kk + 1] = (weightZ[kk] & c) - 8;
+                }
+                for (int cn = 0; cn < pack; ++cn) { // pack for oc
+                    const auto weightj = tmpw + cn * pack;
+                    for (int ck = 0; ck < pack; ++ck) { // pack for ic
+                        tmp[cn] += (int32_t)srcZ[ck] * (int32_t)weightj[ck];
+                    }
+                }
+            }
+            
+            // int32->float
+            for (int cn = 0; cn < pack; ++cn) {
+                float val = (float)tmp[cn] * scale[0];
+                val = bias[cn] + val * alpha[cn] + zero[cn] * sums[0];
+                dstX[cn] = val;
+            }
+        }
+    }
+}
+void MNNGemmHybridInt8FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param) {
+    MNNGemmHybridInt8FP32_smmla(C, A, B, src_depth_quad, dst_step, dst_depth_quad, realSize, param);
+}
+void MNNGemmHybridInt4FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param) {
+    MNNGemmHybridInt4FP32_smmla(C, A, B, src_depth_quad, dst_step, dst_depth_quad, realSize, param);
+}
 #endif
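
Taken together, the fallbacks above cover the whole low-memory dynamic-quantization pipeline that the new arm64 kernels accelerate. The actual call sites live in the CPU convolution code and are not part of this hunk, so the following is only a rough single-threaded sketch of the call order and of the param layout the hybrid kernels load (param[0..4] = alpha, zero, bias, sums, scales); buffer names are placeholders.

// Rough sketch of the flow (shapes follow the comments above; thread count 1):
//   MNNAbsMaxFP32(src, absmax, ic/4, N, 4);                     // per-token |x| maxima
//   MNNQuantScaleFP32(absmax, qscale, dscale, 1, N);            // 127/absmax and absmax/127
//   MNNDynamicQuantFP32(src, srcInt8, qscale, sum, ic/4, N, 4); // int8 src + raw int32 sums
//   MNNQuantSumFP32(sum, dscale, 1, N);                         // int32 sums -> float sums
//   const float* param[5] = {alpha, zero, bias, sum, dscale};
//   MNNGemmHybridInt4FP32_smmla(dst, srcInt8, weightInt4, ic/4, dstStep, oc/4, N, param);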
 
 void MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) {
@@ -1710,6 +1873,21 @@ void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim) {
         }
     }
 }
+void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int32_t* dim) {
+    int w = dim[0];
+    int h = dim[1];
+    int srcStride = dim[2];
+    int dstStride = dim[3];
+    for (int i=0; i<h; ++i) {
+        auto si = srcO + i;
+        auto di = dstO + i * dstStride;
+        for (int j=0; j<w; ++j) {
+            auto sj = si + j * srcStride;
+            auto dj = di + j;
+            *dj = *sj;
+        }
+    }
+}
 #endif
 void MNNFunctionInit() {
     // Do nothing
@@ -1967,22 +2145,7 @@ void MNNNorm(float *dst, const float *src, const float *gamma, const float *beta
 }
 #endif
 
-size_t MNNGridSampleComputeOffset(int h, int w, int height, int width, bool padMode) {
-    if (padMode == true) { //padMode == BorderMode_ZEROS
-        if (h < 0 || h >= height || w < 0 || w >= width) {
-            return -1;
-        }
-    } else {
-        // Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER
-        // For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1),
-        // the leftover reflections degrade to GridSamplePaddingMode_BORDER
-        h = h < 0 ? 0 : ( h > (height - 1) ? (height - 1) : h);
-        w = w < 0 ? 0 : ( w > (width - 1) ? (width - 1) : w);
-    }
-    return h * width * 4 + w * 4;
-}
-
-size_t MNNGridSampleComputeOffset3D(int d, int h, int w, int depth, int height, int width, bool padMode) {
+int MNNGridSampleComputeOffset3D(int d, int h, int w, int depth, int height, int width, bool padMode) {
     if (padMode == true) { //padMode == BorderMode_ZEROS
         if (h < 0 || h >= height || w < 0 || w >= width || d < 0 || d >= depth) {
             return -1;
@@ -1998,52 +2161,6 @@ size_t MNNGridSampleComputeOffset3D(int d, int h, int w, int depth, int height,
     return ((d * height + h) * width + w) * 4;
 }
 
-void MNNGridSampleInterp(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode) {
-    for (auto ow = 0; ow < outW; ++ow) {
-        auto w = cordPtr[2 * ow + 0];
-        auto h = cordPtr[2 * ow + 1];
-        Vec4 interp;
-
-        if (sampleMode == true) { //sampleMode == SampleMode_NEAREST
-            int nh = ::floor(h + 0.5f);
-            int nw = ::floor(w + 0.5f);
-            size_t ns = MNNGridSampleComputeOffset(nh, nw, inH, inW, padMode);
-            for (int k = 0; k < channelCUnit; ++k) {
-                interp = ns == -1 ? Vec4(0.f) : Vec4::load(inputPtr + k * inOffset + ns);
-                Vec4::save(outputPtr + k * outOffset + 4 * ow, interp);
-            }
-        } else { //sampleMode == GridSampleMode_BILINEAR
-            int w0_h = ::floor(h);
-            int w0_w = ::floor(w);
-            int w1_h = ::ceil(h);
-            int w1_w = ::ceil(w);
-            auto oneV = Vec4(1.0f);
-
-            auto f0 = Vec4((float)w1_w - w);
-            auto f1 = oneV - f0;
-            auto h0 = Vec4((float)w1_h - h);
-            auto h1 = oneV - h0;
-
-            size_t s00 = MNNGridSampleComputeOffset(w0_h, w0_w, inH, inW, padMode);
-            size_t s01 = MNNGridSampleComputeOffset(w0_h, w1_w, inH, inW, padMode);
-            size_t s10 = MNNGridSampleComputeOffset(w1_h, w0_w, inH, inW, padMode);
-            size_t s11 = MNNGridSampleComputeOffset(w1_h, w1_w, inH, inW, padMode);
-
-            for (int k = 0; k < channelCUnit; ++k) {
-                Vec4 i00 = s00 == -1 ? Vec4(0.f) : Vec4::load(inputPtr + k * inOffset + s00);
-                Vec4 i01 = s01 == -1 ? Vec4(0.f) : Vec4::load(inputPtr + k * inOffset + s01);
-                Vec4 i10 = s10 == -1 ? Vec4(0.f) : Vec4::load(inputPtr + k * inOffset + s10);
-                Vec4 i11 = s11 == -1 ? Vec4(0.f) : Vec4::load(inputPtr + k * inOffset + s11);
-
-                Vec4 i0 = i00 * f0 + i01 * f1;
-                Vec4 i1 = i10 * f0 + i11 * f1;
-
-                interp = i0 * h0 + i1 * h1;
-                Vec4::save(outputPtr + k * outOffset + 4 * ow, interp);
-            }
-        }
-    }
-}
 
 void MNNRoiPoolingMax(float* dst, const float* src, int hLen, int wLen, int iw) {
     Vec4 max = Vec4(-FLT_MAX);
@@ -3187,6 +3304,10 @@ void MNNCoreFunctionInit() {
     gCoreFunction->MNNPackedMatMulRemain_int4 = MNNPackedMatMulRemain_int4;
     gCoreFunction->MNNPackedMatMul_int8 = MNNPackedMatMul_int8;
     gCoreFunction->MNNPackedMatMulRemain_int8 = MNNPackedMatMulRemain_int8;
+    gCoreFunction->MNNAbsMax = MNNAbsMaxFP32;
+    gCoreFunction->MNNDynamicQuant = MNNDynamicQuantFP32;
+    gCoreFunction->MNNQuantScale = MNNQuantScaleFP32;
+    gCoreFunction->MNNQuantSum = MNNQuantSumFP32;
 #endif
 
     gCoreFunction->MNNGetSparseMatMulPackMode = MNNGetSparseMatMulPackMode;
@@ -3230,6 +3351,7 @@ void MNNCoreFunctionInit() {
     gCoreFunction->MNNScaleAndAddBias = MNNScaleAndAddBias;
     gCoreFunction->MNNGridSampleComputeCord = MNNGridSampleComputeCord;
     gCoreFunction->MNNGridSampleInterp = MNNGridSampleInterp;
+    gCoreFunction->MNNGridSampleInterpGrad = MNNGridSampleInterpGrad;
     gCoreFunction->MNNGridSampleComputeCord3D = MNNGridSampleComputeCord3D;
     gCoreFunction->MNNGridSampleInterp3D = MNNGridSampleInterp3D;
     gCoreFunction->MNNRoiPoolingMax = MNNRoiPoolingMax;
@@ -3278,6 +3400,18 @@ void MNNCoreFunctionInit() {
     gCoreFunction->supportFp16arith = gCPUInfo.fp16arith;
     gCoreFunction->supportSDot = gCPUInfo.dot;
     gCoreFunction->supportI8mm = gCPUInfo.i8mm;
+#ifdef MNN_LOW_MEMORY
+    gCoreFunction->MNNGemmHybridInt8 = MNNGemmHybridInt8FP32_sdot;
+    gCoreFunction->MNNGemmHybridInt4 = MNNGemmHybridInt4FP32_sdot;
+    if (gCoreFunction->supportSDot) {
+        gCoreFunction->MNNGemmHybridInt8 = MNNGemmHybridInt8FP32_sdot;
+        gCoreFunction->MNNGemmHybridInt4 = MNNGemmHybridInt4FP32_sdot;
+    }
+    if (gCoreFunction->supportI8mm) {
+        gCoreFunction->MNNGemmHybridInt8 = MNNGemmHybridInt8FP32_smmla;
+        gCoreFunction->MNNGemmHybridInt4 = MNNGemmHybridInt4FP32_smmla;
+    }
+#endif
     MNNCoreInt8FunctionInit();
     MNNFunctionInit();
 }
@@ -3309,6 +3443,10 @@ void MNNUnpackC2(double* dst, const double* src, size_t area, size_t depth, int*
     MNNUnpackC2Common<double>(dst, src, area, depth, areaOffset);
 }
 
+void MNNUnpackC2Float(float* dst, const float* src, size_t area, size_t depth, int* areaOffset, int pack) {
+    MNNUnpackC2Common<float>(dst, src, area, depth, areaOffset, pack);
+}
+
 void MNNPackInt8C2(float* dst, const float* src, size_t area, size_t depth, int* areaOffset) {
     MNNPackC2Common<float>(dst, src, area, depth, areaOffset);
 }

+ 23 - 11
source/backend/cpu/compute/CommonOptFunction.h

@@ -47,6 +47,7 @@ void MNNUnpackC4Origin(float* dst, const float* src, size_t area, size_t depth,
 
 void MNNUnpackC2(double* dst, const double* src, size_t area, size_t depth, int* areaOffset);
 void MNNUnpackC2Origin(double* dst, const double* src, size_t area, size_t depth, int areaOffset);
+void MNNUnpackC2Float(float* dst, const float* src, size_t area, size_t depth, int* areaOffset, int pack = 1);
 
 void MNNUnpackInt8C2(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
 void MNNUnpackInt8C2Origin(float* dst, const float* src, size_t area, size_t depth, int areaOffset);
@@ -121,6 +122,11 @@ void MNNPackedMatMul_int4(float* C, const float* A, const float* B, const size_t
 void MNNPackedMatMulRemain_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
 void MNNPackedMatMul_int8(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
 void MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
+void MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
+void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch);
+void MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize, int pack);
+void MNNQuantSumFP32(float* sum, const float* dequant_scale, size_t thread, size_t batch);
+
 
 void MNNPackForSparseMatMul_B(float* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const float* source, size_t h, size_t l, const int eP, bool transpose);
 struct SparseMatMulParas
@@ -142,6 +148,7 @@ void MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t
 
 // dim: 4-element, sizeDW, sizeDH, strideSW, strideDH
 void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim); // not C4
+void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int32_t* dim); // not C4
 
 void MNNVectorTop1Float(float* input, float* maxValue, int32_t* maxIndex, size_t inputCountUnit);
 void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, size_t inputCountUnit);
@@ -160,19 +167,16 @@ void MNNSourceTransformCommonF23(const float *source, float *dest, int unit, int
 void MNNConvDwF23MulTransUnit(float **cacheLine, const float *weigth, float *dest, size_t ow, const float* bias, const float* postParameter);
 void MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weigth, float *dest, int cacheLineSize, int ow);
 void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count);
+void MNNGemmHybridInt4FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param);
+void MNNGemmHybridInt8FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param);
+void MNNGemmHybridInt4FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param);
+void MNNGemmHybridInt8FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param);
+void MNNGemmHybridInt4FP16_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param);
+void MNNGemmHybridInt8FP16_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param);
+void MNNGemmHybridInt4FP16_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param);
+void MNNGemmHybridInt8FP16_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param);
 }
 
-
-void MNNGridSampleComputeCord(float* dst, const float* src, size_t inH, size_t inW, size_t outH, size_t outW, size_t stride, bool alignCorners);
-void MNNGridSampleInterp(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW,
-                            size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode);
-void MNNGridSampleComputeCord3D(float* dst, const float* src, size_t inD, size_t inH, size_t inW, size_t outD, size_t outH, size_t outW, size_t stride, bool alignCorners);
-void MNNGridSampleInterp3D(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inD, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode);
-void MNNRoiPoolingMax(float* dst, const float* src, int hLen, int wLen, int iw);
-void MNNRoiAlignMax(float* dst, const float* src, const std::vector<std::vector<int>> &vecPos, const std::vector<std::vector<float>> &vecArea, int samplingRatioArea, int pooledHeight, int pooledWidth);
-void MNNRoiAlignAvg(float* dst, const float* src, const std::vector<std::vector<int>> &vecPos, const std::vector<std::vector<float>> &vecArea, int samplingRatioArea, int pooledHeight, int pooledWidth);
-
-
 typedef void(*MNNBinaryExecute)(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, int broadcastIndex);
 typedef void(*MNNUnaryExecute)(void* outputRaw, const void* inputRaw, int elementSize);
 typedef void(*MNNUnaryExecuteInt8)(void* outputRaw, const void* inputRaw, int elementSize, QuanPrePostParameters* params);
@@ -197,6 +201,12 @@ struct CoreFunctions {
     void(*MNNPackedMatMulRemain)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
     void(*MNNPackedMatMul_int4)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
     void(*MNNPackedMatMulRemain_int4)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
+    void(*MNNAbsMax)(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
+    void(*MNNQuantScale)(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch);
+    void(*MNNDynamicQuant)(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize, int pack);
+    void(*MNNQuantSum)(float* sum, const float* dequant_scale, size_t thread, size_t batch);
+    void(*MNNGemmHybridInt4)(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param);
+    void(*MNNGemmHybridInt8)(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param);
     void(*MNNPackedMatMul_int8)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
     void(*MNNPackedMatMulRemain_int8)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
     void(*MNNComputeMatMulForH_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId);
@@ -260,6 +270,8 @@ struct CoreFunctions {
     void(*MNNScaleAndAddBias)(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, size_t biasNumber);
     void(*MNNGridSampleComputeCord)(float* dst, const float* src, size_t inH, size_t inW, size_t outH, size_t outW, size_t stride, bool alignCorners);
     void(*MNNGridSampleInterp)(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode);
+    void (*MNNGridSampleInterpGrad)(float* outputPtr, float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode);
+
     void(*MNNGridSampleComputeCord3D)(float* dst, const float* src, size_t inD, size_t inH, size_t inW, size_t outD, size_t outH, size_t outW, size_t stride1, size_t stride2, bool alignCorners);
     void(*MNNGridSampleInterp3D)(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inD, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode);
     void(*MNNRoiPoolingMax)(float* dst, const float* src, int hLen, int wLen, int iw);
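The new CoreFunctions hooks (MNNAbsMax, MNNQuantScale, MNNDynamicQuant, MNNQuantSum) together with the MNNGemmHybridInt4/Int8 pointers outline a dynamic activation-quantization path for the low-memory hybrid GEMM. A rough scalar sketch of what that flow amounts to, assuming from the names alone that the scales map the per-row absmax onto the int8 range; the actual FP32/FP16 kernels are vectorized, packed, and split across threads and batches as the signatures above indicate:

    // Assumed semantics only: per-row dynamic int8 quantization in scalar form.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    void dynamicQuantRow(const float* src, int size, int8_t* dst, float& dequantScale) {
        float absmax = 0.f;
        for (int i = 0; i < size; ++i) {                         // MNNAbsMax step
            absmax = std::max(absmax, std::fabs(src[i]));
        }
        float quantScale = absmax > 0.f ? 127.f / absmax : 0.f;  // MNNQuantScale step
        dequantScale     = absmax / 127.f;                       // applied after the int8 GEMM
        for (int i = 0; i < size; ++i) {                         // MNNDynamicQuant step
            int v  = static_cast<int>(std::roundf(src[i] * quantScale));
            dst[i] = static_cast<int8_t>(std::min(127, std::max(-127, v)));
        }
    }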

+ 10 - 0
source/backend/cpu/compute/ConvolutionFloatFactory.cpp

@@ -15,6 +15,7 @@
 
 #include "backend/cpu/compute/ConvolutionWinogradBridge.hpp"
 #include "backend/cpu/compute/DenseConvolutionTiledExecutor.hpp"
+#include "backend/cpu/compute/ConvolutionHybrid.hpp"
 #ifdef MNN_USE_SPARSE_COMPUTE
 #include "backend/cpu/compute/SparseConvolutionTiledExecutor.hpp"
 #endif
@@ -48,6 +49,15 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend
     bool fastWay = common->kernelY() == 1 && common->kernelX() == 1
         && output->width() == input->width() && output->height() == input->height()
         && common->strideX() == 1 && common->strideY() == 1;
+
+    if (lowMemory) {
+        if (fastWay) {
+            // return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
+            return new ConvolutionHybrid(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
+        } else {
+            return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
+        }
+    }
     if (fastWay) {
         return new Convolution1x1Strassen(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
     }

+ 0 - 0
source/backend/cpu/compute/ConvolutionHybrid.cpp


Some files were not shown because the diff is too large