
MNN:Sync: Sync Internal 3.0.2

xiaying · 7 months ago
commit da4023c222
100 changed files with 3473 additions and 2757 deletions
  1. .gitignore (+0 −7)
  2. 3rd_party/OpenCLHeaders/CL/cl2.hpp (+18 −0)
  3. 3rd_party/OpenCLHeaders/CL/cl_ext.h (+22 −0)
  4. CMakeLists.txt (+27 −5)
  5. backupcode/cpubackend/compute/DeconvolutionWithStride.cpp (+0 −0)
  6. backupcode/cpubackend/compute/DeconvolutionWithStride.hpp (+0 −0)
  7. backupcode/cpubackend/compute/GemmInt8Executor.cpp (+0 −0)
  8. backupcode/cpubackend/compute/GemmInt8Executor.hpp (+0 −0)
  9. docs/compile/cmake.md (+3 −1)
  10. docs/compile/other.md (+13 −0)
  11. docs/transformers/llm.md (+16 −35)
  12. docs/transformers/models.md (+2 −1)
  13. express/MathOp.cpp (+19 −4)
  14. express/NeuralNetWorkOp.cpp (+121 −121)
  15. include/MNN/MNNDefine.h (+1 −1)
  16. include/MNN/MNNForwardType.h (+5 −3)
  17. include/MNN/expr/MathOp.hpp (+8 −5)
  18. project/android/build_32.sh (+1 −1)
  19. project/android/build_64.sh (+1 −1)
  20. project/ios/MNN.xcodeproj/project.pbxproj (+3 −35)
  21. pymnn/CMakeLists.txt (+6 −1)
  22. pymnn/pip_package/MNN/__init__.py (+1 −0)
  23. pymnn/pip_package/MNN/audio/__init__.py (+96 −0)
  24. pymnn/pip_package/MNN/llm/__init__.py (+20 −2)
  25. pymnn/pip_package/build_deps.py (+3 −3)
  26. pymnn/pip_package/setup.py (+4 −1)
  27. pymnn/src/MNN.cc (+17 −1)
  28. pymnn/src/audio.h (+105 −0)
  29. pymnn/src/llm.h (+39 −2)
  30. schema/current/MNN_generated.h (+192 −12)
  31. schema/default/MNN.fbs (+9 −1)
  32. source/backend/cpu/CPUBackend.hpp (+3 −0)
  33. source/backend/cpu/CPUBinaryInt8.cpp (+8 −8)
  34. source/backend/cpu/CPUDeconvolution.cpp (+120 −254)
  35. source/backend/cpu/CPUDeconvolution.hpp (+5 −6)
  36. source/backend/cpu/CPUInstanceNorm.cpp (+7 −4)
  37. source/backend/cpu/CPUMoments.cpp (+7 −4)
  38. source/backend/cpu/CPUOPRegister.cpp (+6 −0)
  39. source/backend/cpu/CPURelu.cpp (+72 −28)
  40. source/backend/cpu/CPUStft.cpp (+75 −0)
  41. source/backend/cpu/CPUStft.hpp (+31 −0)
  42. source/backend/cpu/arm/arm32/MNNReluWithSlopeChannelInt8.S (+47 −15)
  43. source/backend/cpu/arm/arm32/MNNWinogradMatrixProductLeft.S (+0 −225)
  44. source/backend/cpu/arm/arm32/MNNWinogradMatrixProductRight.S (+0 −223)
  45. source/backend/cpu/arm/arm64/MNNReluWithSlopeChannelInt8.S (+46 −15)
  46. source/backend/cpu/arm/arm64/MNNWinogradMatrixProductLeft.S (+0 −171)
  47. source/backend/cpu/arm/arm64/MNNWinogradMatrixProductRight.S (+0 −164)
  48. source/backend/cpu/compute/CommonOptFunction.cpp (+20 −2)
  49. source/backend/cpu/compute/CommonOptFunction.h (+2 −1)
  50. source/backend/cpu/compute/Int8FunctionsOpt.cpp (+11 −13)
  51. source/backend/cpu/compute/Int8FunctionsOpt.h (+1 −1)
  52. source/backend/cpu/compute/WinogradOptFunction.cpp (+0 −67)
  53. source/backend/cpu/compute/WinogradOptFunction.hpp (+0 −3)
  54. source/backend/cpu/x86_x64/FunctionDispatcher.cpp (+0 −1)
  55. source/backend/cpu/x86_x64/sse/FunctionSummary.hpp (+1 −1)
  56. source/backend/cpu/x86_x64/sse/MathFunctions.cpp (+0 −43)
  57. source/backend/metal/AllShader.cpp (+32 −34)
  58. source/backend/metal/MetalAttention.mm (+164 −668)
  59. source/backend/metal/MetalAttentionShader.hpp (+636 −0)
  60. source/backend/metal/MetalConvolution1x1.hpp (+1 −1)
  61. source/backend/metal/MetalConvolution1x1.mm (+21 −6)
  62. source/backend/metal/MetalConvolutionCommon.hpp (+16 −2)
  63. source/backend/metal/MetalConvolutionCommon.mm (+37 −11)
  64. source/backend/metal/MetalDeconvolution.hpp (+0 −9)
  65. source/backend/metal/MetalDeconvolution.mm (+64 −44)
  66. source/backend/metal/shader/MetalConvolution1x1.metal (+29 −28)
  67. source/backend/metal/shader/MetalDeconvolution.metal (+3 −6)
  68. source/backend/opencl/core/BufferConvertor.cpp (+74 −0)
  69. source/backend/opencl/core/BufferConvertor.hpp (+2 −0)
  70. source/backend/opencl/core/BufferPool.cpp (+0 −1)
  71. source/backend/opencl/core/OpenCLBackend.cpp (+67 −85)
  72. source/backend/opencl/core/OpenCLBackend.hpp (+21 −3)
  73. source/backend/opencl/core/runtime/OpenCLRuntime.cpp (+32 −51)
  74. source/backend/opencl/core/runtime/OpenCLRuntime.hpp (+4 −4)
  75. source/backend/opencl/core/runtime/OpenCLWrapper.cpp (+36 −60)
  76. source/backend/opencl/core/runtime/OpenCLWrapper.hpp (+12 −10)
  77. source/backend/opencl/execution/buffer/ConvBufExecution.cpp (+5 −6)
  78. source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp (+4 −4)
  79. source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp (+10 −2)
  80. source/backend/opencl/execution/cl/buffer_convert_buf.cl (+1 −1)
  81. source/backend/opencl/execution/cl/conv_2d.cl (+28 −0)
  82. source/backend/opencl/execution/cl/conv_2d_buf.cl (+107 −60)
  83. source/backend/opencl/execution/cl/conv_2d_int_buf.cl (+82 −43)
  84. source/backend/opencl/execution/cl/depthwise_conv2d_buf.cl (+14 −6)
  85. source/backend/opencl/execution/cl/glmem_convert.cl (+211 −0)
  86. source/backend/opencl/execution/cl/opencl_program.cc (+432 −108)
  87. source/backend/opencl/execution/cl/opencl_source_map.hpp (+2 −0)
  88. source/backend/opencl/execution/image/ConvExecution.cpp (+22 −6)
  89. source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp (+12 −0)
  90. source/backend/vulkan/buffer/execution/VulkanPRelu.cpp (+1 −1)
  91. source/core/Interpreter.cpp (+3 −0)
  92. source/core/Pipeline.cpp (+1 −1)
  93. source/core/TensorUtils.cpp (+11 −1)
  94. source/core/TensorUtils.hpp (+6 −0)
  95. source/shape/ShapeConcat.cpp (+1 −1)
  96. source/shape/ShapeRegister.cpp (+6 −0)
  97. source/shape/ShapeStft.cpp (+38 −0)
  98. source/shape/SizeComputer.hpp (+9 −0)
  99. test.sh (+2 −2)
  100. test/CMakeLists.txt (+0 −0)

+ 0 - 7
.gitignore

@@ -361,10 +361,3 @@ pymnn_build/
 
 # mnncompress generated
 MNN_compression_pb2.py
-
-# model path
-model/
-
-# datasets
-datasets/*
-!datasets/*.sh

+ 18 - 0
3rd_party/OpenCLHeaders/CL/cl2.hpp

@@ -3810,6 +3810,24 @@ public:
         }
     }
 
+    Buffer(
+        const Context& context,
+        cl_mem_flags flags,
+        const cl_import_properties_arm *properties,
+        void *memory,
+        size_type size,
+        cl_int* err = NULL)
+    {
+        cl_int error;
+        object_ = ::clImportMemoryARM(context(), flags, properties, memory, size, &error);
+
+        detail::errHandler(error, __CREATE_BUFFER_ERR);
+        if (err != NULL) {
+            *err = error;
+        }
+    }
+
+
     /*!
      * \brief Construct a Buffer from a host container via iterators using a specified context.
      * IteratorType must be random access.

+ 22 - 0
3rd_party/OpenCLHeaders/CL/cl_ext.h

@@ -430,6 +430,23 @@ typedef struct _cl_mem_android_native_buffer_host_ptr
 } cl_mem_android_native_buffer_host_ptr;
 
 
+/*********************************
+* cl_qcom_ahardwarebuffer_host_ptr extension
+*********************************/
+
+#define CL_MEM_ANDROID_AHARDWAREBUFFER_HOST_PTR_QCOM                0x4119
+
+typedef struct _cl_mem_ahardwarebuffer_host_ptr
+{
+    /* Type of external memory allocation. */
+    /* Must be CL_MEM_ANDROID_AHARDWAREBUFFER_HOST_PTR_QCOM for Android Hardware buffers. */
+    cl_mem_ext_host_ptr  ext_host_ptr;
+
+    /* Virtual pointer to the android hardware buffer */
+    void*                ahb_ptr;
+
+} cl_mem_ahardwarebuffer_host_ptr;
+
 /******************************************
  * cl_img_yuv_image extension *
  ******************************************/
@@ -583,6 +600,11 @@ typedef intptr_t cl_import_properties_arm;
 
 /* Protected DMA BUF memory type value for CL_IMPORT_TYPE_ARM property */
 #define CL_IMPORT_TYPE_PROTECTED_ARM              0x40B5
+#define CL_IMPORT_TYPE_ANDROID_HARDWARE_BUFFER_ARM          0x41E2
+#define CL_IMPORT_DMA_BUF_DATA_CONSISTENCY_WITH_HOST_ARM    0x41E3
+#define CL_IMPORT_MEMORY_WHOLE_ALLOCATION_ARM               SIZE_MAX
+#define CL_IMPORT_ANDROID_HARDWARE_BUFFER_PLANE_INDEX_ARM   0x41EF
+#define CL_IMPORT_ANDROID_HARDWARE_BUFFER_LAYER_INDEX_ARM   0x41F0
 
 /* This extension adds a new function that allows for direct memory import into
  * OpenCL via the clImportMemoryARM function.
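For reference, a minimal sketch (not part of this commit) of how the new property keys above pair with the `Buffer` import constructor added to cl2.hpp: importing an Android AHardwareBuffer into OpenCL through `clImportMemoryARM`. It assumes an Android NDK build and a driver that actually reports the `cl_arm_import_memory_android_hardware_buffer` extension.

```cpp
// Hedged sketch: import an AHardwareBuffer as a cl::Buffer via the ARM
// import-memory extension and the new Buffer(context, flags, props, ptr, size)
// constructor added to cl2.hpp in this diff.
#include <android/hardware_buffer.h>
#include <CL/cl2.hpp>

cl::Buffer importAHardwareBuffer(const cl::Context& context, AHardwareBuffer* ahb) {
    const cl_import_properties_arm props[] = {
        CL_IMPORT_TYPE_ARM, CL_IMPORT_TYPE_ANDROID_HARDWARE_BUFFER_ARM,
        0, // the property list is zero-terminated
    };
    cl_int err = CL_SUCCESS;
    // CL_IMPORT_MEMORY_WHOLE_ALLOCATION_ARM (SIZE_MAX) imports the whole buffer.
    cl::Buffer imported(context, CL_MEM_READ_WRITE, props, ahb,
                        CL_IMPORT_MEMORY_WHOLE_ALLOCATION_ARM, &err);
    if (err != CL_SUCCESS) {
        // Import failed: callers should fall back to a regular buffer plus a copy.
    }
    return imported;
}
```

Zero-copy imports like this avoid a host round-trip when sharing camera or GPU-composited frames with OpenCL, which is the apparent motivation for the new `glmem_convert.cl` and BufferConvertor changes further down.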

+ 27 - 5
CMakeLists.txt

@@ -20,9 +20,7 @@ endif()
 project(MNN VERSION ${MNN_VERSION} LANGUAGES C CXX ASM)
 # compiler options
 set(CMAKE_C_STANDARD 99)
-IF (NOT (CMAKE_CXX_STANDARD EQUAL 17))
-  set(CMAKE_CXX_STANDARD 11)
-ENDIF()
+set(CMAKE_CXX_STANDARD 11)
 set(CMAKE_MODULE_PATH
   ${CMAKE_MODULE_PATH}
   "${CMAKE_CURRENT_LIST_DIR}/cmake"
@@ -49,7 +47,7 @@ option(MNN_BUILD_TOOLS "Build tools/cpp or not" ON)
 option(MNN_BUILD_QUANTOOLS "Build Quantized Tools or not" OFF)
 option(MNN_EVALUATION "Build Evaluation Tools or not" OFF)
 option(MNN_BUILD_CONVERTER "Build Converter" OFF)
-option(MNN_SUPPORT_DEPRECATED_OP "Enable MNN's tflite quantized op" ON)
+option(MNN_SUPPORT_DEPRECATED_OP "Enable MNN's tflite quantized op" OFF)
 option(MNN_DEBUG_MEMORY "MNN Debug Memory Access" OFF)
 option(MNN_DEBUG_TENSOR_SIZE "Enable Tensor Size" OFF)
 option(MNN_GPU_TRACE "Enable MNN Gpu Debug" OFF)
@@ -74,6 +72,7 @@ option(MNN_JNI "Build MNN Jni for java to use" OFF)
 option(MNN_SUPPORT_BF16 "Enable MNN's bf16 op" OFF)
 option(MNN_LOW_MEMORY "Build MNN support low memory for weight quant model." OFF)
 option(MNN_CPU_WEIGHT_DEQUANT_GEMM "Build MNN CPU weight dequant related gemm kernels." OFF)
+option(MNN_BUILD_AUDIO "Build audio api in MNN." OFF)
 
 IF (OHOS AND MNN_INTERNAL)
   include($ENV{NODE_PATH}/@ali/tcpkg/tcpkg.cmake)
@@ -192,6 +191,9 @@ endif()
 if(MNN_SUPPORT_TRANSFORMER_FUSE)
     add_definitions(-DMNN_SUPPORT_TRANSFORMER_FUSE)
 endif()
+if(MNN_BUILD_AUDIO)
+    add_definitions(-DMNN_BUILD_AUDIO)
+endif()
 # debug options
 if(MNN_DEBUG_MEMORY)
     add_definitions(-DMNN_DEBUG_MEMORY)
@@ -287,7 +289,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "^Android")
 endif()
 option(MNN_USE_CPP11 "Enable MNN use c++11" ON)
 if (NOT MSVC)
-    if((MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE) OR (CMAKE_CXX_STANDARD EQUAL 17))
+    if(MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE)
         set(CMAKE_CXX_STANDARD 17)
         set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
@@ -463,6 +465,10 @@ IF(MNN_BUILD_OPENCV)
   list(APPEND MNN_EXTRA_HEADERS ${MNN_CV_HDRS})
   list(APPEND MNN_EXTRA_HEADERS ${MNN_CV_IMGHDRS})
 ENDIF()
+IF(MNN_BUILD_AUDIO)
+  file(GLOB MNN_AUDIO_HDRS ${CMAKE_CURRENT_SOURCE_DIR}/tools/audio/include/audio/*.hpp)
+  list(APPEND MNN_EXTRA_HEADERS ${MNN_AUDIO_HDRS})
+ENDIF()
 IF(MNN_BUILD_LLM)
   file(GLOB MNN_LLM_HDRS ${CMAKE_CURRENT_SOURCE_DIR}/transformers/llm/engine/include/llm/*)
   list(APPEND MNN_EXTRA_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/transformers/llm/engine/include/llm/llm.hpp)
@@ -775,6 +781,14 @@ IF(MNN_BUILD_OPENCV AND NOT MNN_SEP_BUILD)
   ENDIF()
   target_sources(MNN PRIVATE $<TARGET_OBJECTS:MNNOpenCV>)
 ENDIF()
+add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tools/audio)
+IF(MNN_BUILD_AUDIO AND NOT MNN_SEP_BUILD)
+  IF(MSVC)
+    target_compile_definitions(MNNAudio PRIVATE "-DBUILDING_MNN_DLL" INTERFACE "-DUSING_MNN_DLL")
+  ENDIF()
+  message(STATUS "### build MNNAudio into MNN")
+  target_sources(MNN PRIVATE $<TARGET_OBJECTS:MNNAudio>)
+ENDIF()
 
 
 if(CMAKE_SYSTEM_NAME MATCHES "^Linux")
@@ -884,6 +898,14 @@ ELSE()
       SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/cv/imgproc )
     ENDFOREACH()
   ENDIF()
+  IF(MNN_BUILD_AUDIO)
+    if (NOT MNN_AAPL_FMWK)
+      INSTALL(FILES ${MNN_AUDIO_HDRS} DESTINATION include/MNN/audio)
+    endif()
+    FOREACH(HDR ${MNN_AUDIO_HDRS})
+      SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/audio/ )
+    ENDFOREACH()
+  ENDIF()
   IF(MNN_BUILD_LLM)
     if (NOT MNN_AAPL_FMWK)
         INSTALL(FILES ${MNN_LLM_HDRS} DESTINATION include/MNN/llm)

source/backend/cpu/compute/DeconvolutionWithStride.cpp → backupcode/cpubackend/compute/DeconvolutionWithStride.cpp


source/backend/cpu/compute/DeconvolutionWithStride.hpp → backupcode/cpubackend/compute/DeconvolutionWithStride.hpp


source/backend/cpu/compute/GemmInt8Executor.cpp → backupcode/cpubackend/compute/GemmInt8Executor.cpp


source/backend/cpu/compute/GemmInt8Executor.hpp → backupcode/cpubackend/compute/GemmInt8Executor.hpp


+ 3 - 1
docs/compile/cmake.md

@@ -16,7 +16,7 @@ MNN builds with CMake; its CMake macro options are listed below:
 | MNN_BUILD_QUANTOOLS  | Whether to build MNN's quantization tools, default `OFF` |
 | MNN_EVALUATION       | Whether to build MNN's evaluation tools, default `OFF` |
 | MNN_BUILD_CONVERTER  | Whether to build MNN's model converter, default `OFF` |
-| MNN_SUPPORT_DEPRECATED_OP | Whether to support TFLite quantized ops, default `ON` |
+| MNN_SUPPORT_DEPRECATED_OP | Whether to support deprecated ops such as the TFLite quantized ops, for compatibility with legacy models (pre-1.1.0), default `OFF` |
 | MNN_DEBUG_MEMORY     | Whether to enable MNN memory debugging, default `OFF` |
 | MNN_DEBUG_TENSOR_SIZE | Whether to enable MNN tensor-size debugging, default `OFF` |
 | MNN_GPU_TRACE        | Whether to enable MNN GPU debugging, default `OFF` |
@@ -32,6 +32,7 @@ MNN builds with CMake; its CMake macro options are listed below:
 | MNN_ENABLE_COVERAGE  | Whether to enable code coverage for MNN, default `OFF` |
 | MNN_BUILD_PROTOBUFFER | Whether to use the `protobuffer` bundled with MNN, default `ON` |
 | MNN_BUILD_OPENCV     | Whether to build MNN's OpenCV functionality, default `OFF` |
+| MNN_BUILD_AUDIO      | Whether to build MNN's audio functionality, default `OFF` |
 | MNN_INTERNAL         | Whether to build some MNN-internal features, such as logging; default `OFF` |
 | MNN_JNI              | Whether to build MNN's JNI support, default `OFF` |
 | MNN_METAL            | Whether to build the `Metal` backend, default `OFF` |
@@ -79,6 +80,7 @@ MNN builds with CMake; its CMake macro options are listed below:
 | MNN_CVCORE           | Whether the MNN OpenCV build enables the `core` module, default `ON` |
 | MNN_OPENCV_TEST      | Whether the MNN OpenCV build enables unit tests, default `OFF` |
 | MNN_OPENCV_BENCH     | Whether the MNN OpenCV build enables performance benchmarks, default `OFF` |
+| MNN_AUDIO_TEST       | Whether the MNN audio build enables unit tests, default `OFF` |
 | MNN_VULKAN_IMAGE     | Whether the Vulkan backend uses the Image memory mode, to support FP16 and GPU acceleration on some mobile devices, default `ON` |
 | MNN_LOW_MEMORY       | Whether to support low-memory mode; when enabled, a weight-quantized model run with `low_memory` set will dequantize at compute time, default `OFF` |
 | MNN_CPU_WEIGHT_DEQUANT_GEMM       | Whether to build the CPU weight-dequantization GEMM kernels; if this macro is enabled and MNN::BackendConfig::MemoryMode=Memory_Normal is set for CPU inference, weight-quantized models are run through the weight-dequantization operators, default `OFF` |
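For reference, a minimal sketch (not part of this commit) of selecting the low-memory behaviour that the last two options describe. It assumes a build with `-DMNN_LOW_MEMORY=ON` and a hypothetical weight-quantized model path:

```cpp
// Hedged sketch: the memory policy is chosen at runtime via BackendConfig.
#include <MNN/Interpreter.hpp>
#include <memory>

int main() {
    std::shared_ptr<MNN::Interpreter> net(
        MNN::Interpreter::createFromFile("quantized_model.mnn"), // illustrative path
        MNN::Interpreter::destroy);
    MNN::BackendConfig backendConfig;
    backendConfig.memory = MNN::BackendConfig::Memory_Low; // dequantize during compute
    MNN::ScheduleConfig config;
    config.type = MNN_FORWARD_CPU;
    config.backendConfig = &backendConfig;
    auto session = net->createSession(config);
    // ... fill inputs, net->runSession(session), read outputs ...
    return 0;
}
```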

+ 13 - 0
docs/compile/other.md

@@ -133,6 +133,19 @@
   - `libMNNOpenCV.so` the MNN OpenCV library
   - `opencv_test` MNN OpenCV unit tests
   - `opencv_bench` MNN OpenCV benchmarks
+## MNN Audio library
+- Related build options
+  - `MNN_BUILD_AUDIO` whether to build the audio API
+  - `MNN_AUDIO_TEST` whether to build the audio unit tests
+- Build commands
+    ```bash
+    mkdir build && cd build
+    cmake .. -DMNN_BUILD_AUDIO=ON -DMNN_AUDIO_TEST=ON
+    make -j4
+    ```
+- Build artifacts
+  - `libMNNAudio.so` the MNN audio library
+  - `audio_test` MNN audio unit tests
 
 ## Example projects
 - Related build options

+ 16 - 35
docs/transformers/llm.md

@@ -49,7 +49,7 @@ python llmexport.py \
 
 ### Features
 - Convert directly to an MNN model with `--export mnn`. Note that you must either have pymnn installed or use `--mnnconvert` to point at the MNNConvert tool; at least one of the two must hold. If pymnn is not installed and `--mnnconvert` is not given, the llmexport.py script looks for MNNConvert under "../../../build/", so an MNNConvert binary must exist there. This route currently supports exporting 4-bit and 8-bit models.
-- If direct conversion to MNN runs into problems, or you need another quantization bit width (such as 5-bit/6-bit), first convert the model to ONNX with `--export onnx`, then convert the ONNX model to MNN with the ./MNNConvert tool: 
+- If direct conversion to MNN runs into problems, or you need another quantization bit width (such as 5-bit/6-bit), first convert the model to ONNX with `--export onnx`, then convert the ONNX model to MNN with the ./MNNConvert tool:
 
 ```
 ./MNNConvert --modelFile ../transformers/llm/export/model/onnx/llm.onnx --MNNModel llm.mnn --keepInputFormat --weightQuantBits=4 --weightQuantBlock=128 -f ONNX --transformerFuse=1 --allowCustomOp --saveExternalData
@@ -98,13 +98,17 @@ options:
 [Build from source](../compile/other.html#id4)
 Just add the required build macros on top of the normal build:
 ```
--DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true 
+-DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true
 ```
 
 - To enable vision support, add the related build macros
 ```
 -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true
 ```
+- To enable audio support, add the related build macro
+```
+-DLLM_SUPPORT_AUDIO=true
+```
 
 #### mac / linux / windows
 
@@ -137,7 +141,7 @@ sh package_scripts/ios/buildiOS.sh "-DMNN_ARM82=true -DMNN_LOW_MEMORY=true -DMNN
 ```
 
 #### Web
-For environment setup, see https://mnn-docs.readthedocs.io/en/latest/compile/engine.html#web 
+For environment setup, see https://mnn-docs.readthedocs.io/en/latest/compile/engine.html#web
 
 - Build the libraries, producing `libMNN.a`, `libMNN_Express.a`, and `libllm.a`
 
@@ -189,7 +193,7 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
   - visual_model: when using a VL model, the actual path of visual_model is `base_dir + visual_model`; default `base_dir + 'visual.mnn'`
 - Inference configuration
   - max_new_tokens: maximum number of tokens to generate, default `512`
-  - reuse_kv: whether multi-turn dialogue reuses the previous turns' `kv cache`, default `false`; currently only the CPU backend supports setting this to `true`.
+  - reuse_kv: whether multi-turn dialogue reuses the previous turns' `kv cache`, default `false`
   - quant_qkv: whether `query, key, value` in the CPU attention op are quantized; options are `0, 1, 2, 3, 4`, default `0`, with the following meanings:
     - 0: neither key nor value is quantized
     - 1: key stored with asymmetric 8-bit quantization
@@ -205,19 +209,6 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
   - thread_num: number of hardware threads for CPU inference, default `4`; `68` is used for OpenCL inference
   - precision: precision policy for inference, default `"low"`, preferring `fp16`
   - memory: memory policy for inference, default `"low"`, enabling runtime quantization
-- Sampler configuration
-  - sampler_type: which sampler to use; currently supports the eight basic samplers `greedy`, `temperature`, `topK`, `topP`, `minP`, `tfs`, `typical`, `penalty`, plus `mixed` (a combined sampler). With `mixed`, the samplers in mixed_samplers run in order. Default `mixed`.
-  - mixed_samplers: effective when `sampler_type` is `mixed`, default `["topK", "tfs", "typical", "topP", "min_p", "temperature"]`
-  - temperature: the temperature value used by `temperature`, `topP`, `minP`, `tfsZ`, `typical`, default 1.0
-  - topK: number of top-K candidates in `topK`, default 40
-  - topP: the top-P value in `topP`, default 0.9
-  - minP: the min-P value in `minP`, default 0.1
-  - tfsZ: the Z value in `tfs`, default 1.0, i.e. tfs disabled
-  - typical: the p value in `typical`, default 1.0, i.e. typical disabled
-  - penalty: the penalty applied to logits in `penalty`, default 0.0, i.e. no penalty
-  - n_gram: maximum n-gram size stored by `penalty`, default 8
-  - ngram_factor: extra penalty for repeated n-grams in `penalty`, default 1.0, i.e. no extra penalty
-  - penalty_sampler: the sampling strategy used in the final step of `penalty`, "greedy" or "temperature", default greedy.
 
 ##### Example configuration files
 - `config.json`
@@ -229,15 +220,7 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
       "backend_type": "cpu",
       "thread_num": 4,
       "precision": "low",
-      "memory": "low",
-      "sampler_type": "mixed",
-      "mixed_samplers": ["topK", "tfs", "typical", "topP", "min_p", "temperature"],
-      "temperature": 1.0,
-      "topK": 40,
-      "topP": 0.9,
-      "tfsZ": 1.0,
-      "minP": 0.1,
-      "reuse_kv": true
+      "memory": "low"
   }
   ```
 - `llm_config.json`
@@ -261,8 +244,7 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
 
 #### Inference usage
 `llm_demo` is used as follows:
-Run inference directly on PC
-```bash
+```
 # use config.json
 ## interactive chat
 ./llm_demo model_dir/config.json
@@ -276,16 +258,15 @@ Run inference directly on PC
 ./llm_demo model_dir/llm.mnn prompt.txt
 ```
 
-adb inference usage on Android phones:
-```bash
-# use adb push to push the shared libraries to the phone
-adb shell mkdir /data/local/tmp/llm
-adb push llm_demo ppl_demo libllm.so libMNN_CL.so libMNN_Express.so libMNN.so tools/cv/libMNNOpenCV.so /data/local/tmp/llm
-```
-
 - For vision LLMs, embed an image in the prompt
 ```
 <img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>Describe the content of this image
+# specify the image size
+<img><hw>280, 420</hw>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>Describe the content of this image
+```
+- For audio LLMs, embed an audio clip in the prompt
+```
+<audio>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav</audio>Describe the content of this audio clip
 ```
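For reference, a minimal sketch (not part of this commit) of driving the engine behind `llm_demo` from C++, assuming the `MNN::Transformer::Llm` interface in transformers/llm/engine/include/llm/llm.hpp and an illustrative model directory:

```cpp
// Hedged sketch: the config.json described above is consumed through the Llm
// engine API that llm_demo wraps. "model_dir/config.json" is illustrative.
#include <llm/llm.hpp>
#include <memory>

int main() {
    using MNN::Transformer::Llm;
    std::unique_ptr<Llm> llm(Llm::createLLM("model_dir/config.json"));
    llm->load();            // loads llm.mnn, tokenizer and weights per the config
    llm->response("Hello"); // generates a reply, streaming to stdout by default
    return 0;
}
```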
 
 #### Loading GPTQ weights

+ 2 - 1
docs/transformers/models.md

@@ -47,4 +47,5 @@
 | [reader-lm-0.5b](https://huggingface.co/jinaai/reader-lm-0.5b) | [Q4_1](https://modelscope.cn/models/MNN/reader-lm-0.5b-MNN) | [Q4_1](https://huggingface.co/taobao-mnn/reader-lm-0.5b-MNN) |
 | [reader-lm-1.5b](https://huggingface.co/jinaai/reader-lm-1.5b) | [Q4_1](https://modelscope.cn/models/MNN/reader-lm-1.5b-MNN) | [Q4_1](https://huggingface.co/taobao-mnn/reader-lm-1.5b-MNN) |
 | [TinyLlama-1.1B-Chat-v1.0](https://modelscope.cn/models/AI-ModelScope/TinyLlama-1.1B-Chat-v1.0/summary) | [Q4_1](https://modelscope.cn/models/MNN/TinyLlama-1.1B-Chat-MNN) | [Q4_1](https://huggingface.co/taobao-mnn/TinyLlama-1.1B-Chat-MNN) |
-| [Yi-6B-Chat](https://modelscope.cn/models/01ai/Yi-6B-Chat/summary) | [Q4_1](https://modelscope.cn/models/MNN/Yi-6B-Chat-MNN) | [Q4_1](https://huggingface.co/taobao-mnn/Yi-6B-Chat-MNN) |
+| [Yi-6B-Chat](https://modelscope.cn/models/01ai/Yi-6B-Chat/summary) | [Q4_1](https://modelscope.cn/models/MNN/Yi-6B-Chat-MNN) | [Q4_1](https://huggingface.co/taobao-mnn/Yi-6B-Chat-MNN) |
+| [QwQ-32B-Preview](https://modelscope.cn/models/Qwen/QwQ-32B-Preview/summary) | [Q4_1](https://modelscope.cn/models/MNN/QwQ-32B-Preview-MNN) | [Q4_1](https://huggingface.co/taobao-mnn/QwQ-32B-Preview-MNN) |

+ 19 - 4
express/MathOp.cpp

@@ -1208,7 +1208,7 @@ VARP _LinSpace(VARP start, VARP stop, VARP num) {
     return (Variable::create(Expr::create(std::move(op), {start, stop, num})));
 }
 
-VARP _EltwiseProdInt8(VARP x, VARP y, 
+VARP _EltwiseProdInt8(VARP x, VARP y,
                     std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
                     std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
                     std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale)
@@ -1219,7 +1219,7 @@ VARP _EltwiseProdInt8(VARP x, VARP y,
                         output_weight, output_bias, output_scale, output_tensorScale);
 }
 
-VARP _EltwiseSumInt8(VARP x, VARP y, 
+VARP _EltwiseSumInt8(VARP x, VARP y,
                     std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
                     std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
                     std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale)
@@ -1230,7 +1230,7 @@ VARP _EltwiseSumInt8(VARP x, VARP y,
                         output_weight, output_bias, output_scale, output_tensorScale);
 }
 
-VARP _EltwiseSubInt8(VARP x, VARP y, 
+VARP _EltwiseSubInt8(VARP x, VARP y,
                     std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
                     std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
                     std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale)
@@ -1241,7 +1241,7 @@ VARP _EltwiseSubInt8(VARP x, VARP y,
                         output_weight, output_bias, output_scale, output_tensorScale);
 }
 
-VARP _EltwiseMaxInt8(VARP x, VARP y, 
+VARP _EltwiseMaxInt8(VARP x, VARP y,
                     std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
                     std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
                     std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale)
@@ -1320,5 +1320,20 @@ VARP _Histogram(VARP x, int bin, int min, int max, int channel) {
     return (Variable::create(Expr::create(std::move(op), {x})));
 }
 
+#ifdef MNN_BUILD_AUDIO
+VARP _Stft(VARP sample, VARP window, int n_fft, int hop_length, bool abs) {
+    std::unique_ptr<OpT> op(new OpT);
+    op->type      = OpType_Stft;
+    op->main.type = OpParameter_StftParam;
+    auto param = new StftParamT;
+    param->n_fft = n_fft;
+    param->hop_length = hop_length;
+    param->abs = abs;
+    op->main.value = param;
+    EXPRP expr = Expr::create(std::move(op), {sample, window});
+    return Variable::create(expr);
+}
+#endif
+
 } // namespace Express
 } // namespace MNN
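For reference, a minimal sketch (not part of this commit) of exercising the new `_Stft` op through MNN Express. It assumes a build with `-DMNN_BUILD_AUDIO=ON`; the 1-D sample layout and window of length `n_fft` follow the convention suggested by the op's parameters, with the exact shape rules defined in source/shape/ShapeStft.cpp (not shown in this diff).

```cpp
// Hedged usage sketch for the new STFT expression op.
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/MathOp.hpp>
#include <MNN/expr/NeuralNetWorkOp.hpp>
#include <vector>

using namespace MNN::Express;

int main() {
    const int nFft = 400, hopLength = 160, numSamples = 16000;
    std::vector<float> signal(numSamples, 0.0f); // fill with real audio samples
    std::vector<float> window(nFft, 1.0f);       // e.g. a Hann window
    VARP sampleVar = _Const(signal.data(), {numSamples}, NCHW, halide_type_of<float>());
    VARP windowVar = _Const(window.data(), {nFft}, NCHW, halide_type_of<float>());
    // abs=true returns magnitudes rather than complex components.
    VARP spec = _Stft(sampleVar, windowVar, nFft, hopLength, /*abs=*/true);
    auto dims = spec->getInfo()->dim;            // frames x frequency bins
    return 0;
}
```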

File diff suppressed because it is too large
+ 121 - 121
express/NeuralNetWorkOp.cpp


+ 1 - 1
include/MNN/MNNDefine.h

@@ -75,6 +75,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
 #define STR(x) STR_IMP(x)
 #define MNN_VERSION_MAJOR 3
 #define MNN_VERSION_MINOR 0
-#define MNN_VERSION_PATCH 1
+#define MNN_VERSION_PATCH 2
 #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
 #endif /* MNNDefine_h */

+ 5 - 3
include/MNN/MNNForwardType.h

@@ -40,14 +40,16 @@ typedef enum {
     MNN_FORWARD_USER_2 = 10,
     MNN_FORWARD_USER_3 = 11,
 
-    MNN_FORWARD_ALL,
+    MNN_FORWARD_ALL = 12,
 
     /* Apply arm extension instruction set to accelerate some Ops, this forward type
        is only used in MNN internal, and will be active automatically when user set forward type
        to be MNN_FORWARD_CPU and extension instruction set is valid on hardware.
     */
-    MNN_FORWARD_CPU_EXTENSION
-
+    MNN_FORWARD_CPU_EXTENSION = 13,
+
+    // used for shared memory on Android devices
+    MNN_MEMORY_AHARDWAREBUFFER = 14
 } MNNForwardType;
 
 typedef enum {

+ 8 - 5
include/MNN/expr/MathOp.hpp

@@ -13,7 +13,7 @@ namespace MNN {
 namespace Express {
 //BinaryOPs
 MNN_PUBLIC VARP _Add(VARP x, VARP y);
-MNN_PUBLIC VARP _Subtract(VARP x, VARP y);    
+MNN_PUBLIC VARP _Subtract(VARP x, VARP y);
 MNN_PUBLIC VARP _Multiply(VARP x, VARP y);
 MNN_PUBLIC VARP _Divide(VARP x, VARP y);
 MNN_PUBLIC VARP _Pow(VARP x, VARP y);
@@ -92,19 +92,19 @@ MNN_PUBLIC VARP _Prod(VARP a, VARP b, std::vector<float> coeff);
 MNN_PUBLIC VARP _Sum(VARP a, VARP b, std::vector<float> coeff);
 MNN_PUBLIC VARP _Max(VARP a, VARP b, std::vector<float> coeff);
 MNN_PUBLIC VARP _Sub(VARP a, VARP b, std::vector<float> coeff);
-MNN_PUBLIC VARP _EltwiseProdInt8(VARP x, VARP y, 
+MNN_PUBLIC VARP _EltwiseProdInt8(VARP x, VARP y,
                     std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
                     std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
                     std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale);
-MNN_PUBLIC VARP _EltwiseSumInt8(VARP x, VARP y, 
+MNN_PUBLIC VARP _EltwiseSumInt8(VARP x, VARP y,
                      std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
                     std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
                     std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale);
-MNN_PUBLIC VARP _EltwiseSubInt8(VARP x, VARP y, 
+MNN_PUBLIC VARP _EltwiseSubInt8(VARP x, VARP y,
                      std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
                     std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
                     std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale);
-MNN_PUBLIC VARP _EltwiseMaxInt8(VARP x, VARP y, 
+MNN_PUBLIC VARP _EltwiseMaxInt8(VARP x, VARP y,
                       std::vector<int8_t> x_weight, std::vector<int32_t> x_bias, std::vector<float> x_scale, std::vector<float> x_tensorScale,
                     std::vector<int8_t> y_weight, std::vector<int32_t> y_bias, std::vector<float> y_scale, std::vector<float> y_tensorScale,
                     std::vector<int8_t> output_weight, std::vector<int32_t> output_bias, std::vector<float> output_scale, std::vector<float> output_tensorScale);
@@ -138,6 +138,9 @@ MNN_PUBLIC VARP _CumSum(VARP x, int axis, bool exclusive = false, bool reverse =
 MNN_PUBLIC VARP _CumProd(VARP x, int axis);
 MNN_PUBLIC VARPS _Svd(VARP x);
 MNN_PUBLIC VARP _Histogram(VARP x, int bin, int min, int max, int channel = -1);
+#ifdef MNN_BUILD_AUDIO
+MNN_PUBLIC VARP _Stft(VARP sample, VARP window, int n_fft, int hop_length, bool abs = true);
+#endif
 }; // namespace Express
 }; // namespace MNN
 

+ 1 - 1
project/android/build_32.sh

@@ -4,7 +4,7 @@ cmake ../../../ \
 -DCMAKE_BUILD_TYPE=Release \
 -DANDROID_ABI="armeabi-v7a" \
 -DANDROID_STL=c++_static \
--DANDROID_NATIVE_API_LEVEL=android-14  \
+-DANDROID_NATIVE_API_LEVEL=android-26  \
 -DANDROID_TOOLCHAIN=clang \
 -DMNN_USE_LOGCAT=false \
 -DMNN_USE_SSE=OFF \

+ 1 - 1
project/android/build_64.sh

@@ -8,7 +8,7 @@ cmake ../../../ \
 -DMNN_BUILD_BENCHMARK=ON \
 -DMNN_USE_SSE=OFF \
 -DMNN_BUILD_TEST=ON \
--DANDROID_NATIVE_API_LEVEL=android-21  \
+-DANDROID_NATIVE_API_LEVEL=android-26  \
 -DMNN_BUILD_FOR_ANDROID_COMMAND=true \
 -DNATIVE_LIBRARY_OUTPUT=. -DNATIVE_INCLUDE_OUTPUT=. $1 $2 $3 $4 $5 $6 $7
 

+ 3 - 35
project/ios/MNN.xcodeproj/project.pbxproj

@@ -486,11 +486,9 @@
 		92FF02E223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */; };
 		92FF02E323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */; };
 		92FF02E523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */; };
-		92FF02E623AA0B5A00AC97F6 /* MNNWinogradMatrixProductLeft.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016623AA0B4E00AC97F6 /* MNNWinogradMatrixProductLeft.S */; };
 		92FF02E723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */; };
 		92FF02E823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */; };
 		92FF02EA23AA0B5A00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016A23AA0B4E00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S */; };
-		92FF02EC23AA0B5A00AC97F6 /* MNNWinogradMatrixProductRight.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016C23AA0B4E00AC97F6 /* MNNWinogradMatrixProductRight.S */; };
 		92FF02EE23AA0B5A00AC97F6 /* MNNReluWithSlopeChannel.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016E23AA0B4E00AC97F6 /* MNNReluWithSlopeChannel.S */; };
 		92FF02F223AA0B5A00AC97F6 /* MNNBlitC3ToFloatRGBA.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF017223AA0B4E00AC97F6 /* MNNBlitC3ToFloatRGBA.S */; };
 		92FF02F423AA0B5A00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF017423AA0B4E00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S */; };
@@ -530,11 +528,9 @@
 		92FF032223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */; };
 		92FF032323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */; };
 		92FF032523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */; };
-		92FF032623AA0B5A00AC97F6 /* MNNWinogradMatrixProductLeft.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A723AA0B4E00AC97F6 /* MNNWinogradMatrixProductLeft.S */; };
 		92FF032723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */; };
 		92FF032823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */; };
 		92FF032A23AA0B5A00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01AB23AA0B4E00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S */; };
-		92FF032C23AA0B5A00AC97F6 /* MNNWinogradMatrixProductRight.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01AD23AA0B4E00AC97F6 /* MNNWinogradMatrixProductRight.S */; };
 		92FF032E23AA0B5A00AC97F6 /* MNNReluWithSlopeChannel.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01AF23AA0B4E00AC97F6 /* MNNReluWithSlopeChannel.S */; };
 		92FF033223AA0B5A00AC97F6 /* MNNBlitC3ToFloatRGBA.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01B323AA0B4E00AC97F6 /* MNNBlitC3ToFloatRGBA.S */; };
 		92FF033423AA0B5A00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01B523AA0B4E00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S */; };
@@ -592,7 +588,6 @@
 		92FF03A123AA0B5A00AC97F6 /* Int8FunctionsOpt.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF022323AA0B5600AC97F6 /* Int8FunctionsOpt.cpp */; };
 		92FF03A323AA0B5A00AC97F6 /* ConvOpt.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF022523AA0B5600AC97F6 /* ConvOpt.cpp */; };
 		92FF03A423AA0B5A00AC97F6 /* OptimizedComputer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF022623AA0B5600AC97F6 /* OptimizedComputer.cpp */; };
-		92FF03A523AA0B5A00AC97F6 /* DeconvolutionWithStride.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 92FF022723AA0B5600AC97F6 /* DeconvolutionWithStride.hpp */; };
 		92FF03A623AA0B5A00AC97F6 /* ConvolutionTiledExecutor.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 92FF022823AA0B5600AC97F6 /* ConvolutionTiledExecutor.hpp */; };
 		92FF03A723AA0B5A00AC97F6 /* ConvolutionIntFactory.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF022923AA0B5600AC97F6 /* ConvolutionIntFactory.cpp */; };
 		92FF03A823AA0B5A00AC97F6 /* WinogradOptFunction.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF022A23AA0B5600AC97F6 /* WinogradOptFunction.cpp */; };
@@ -609,7 +604,6 @@
 		92FF03B923AA0B5A00AC97F6 /* ConvOpt.h in Headers */ = {isa = PBXBuildFile; fileRef = 92FF023B23AA0B5600AC97F6 /* ConvOpt.h */; };
 		92FF03BC23AA0B5A00AC97F6 /* OptimizedComputer.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 92FF023E23AA0B5600AC97F6 /* OptimizedComputer.hpp */; };
 		92FF03BD23AA0B5A00AC97F6 /* Int8FunctionsOpt.h in Headers */ = {isa = PBXBuildFile; fileRef = 92FF023F23AA0B5600AC97F6 /* Int8FunctionsOpt.h */; };
-		92FF03BE23AA0B5A00AC97F6 /* DeconvolutionWithStride.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF024023AA0B5600AC97F6 /* DeconvolutionWithStride.cpp */; };
 		92FF03BF23AA0B5A00AC97F6 /* ConvolutionTiledExecutor.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF024123AA0B5600AC97F6 /* ConvolutionTiledExecutor.cpp */; };
 		92FF03C323AA0B5A00AC97F6 /* CPUEltwise.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF024523AA0B5700AC97F6 /* CPUEltwise.cpp */; };
 		92FF03C423AA0B5A00AC97F6 /* CPUInterp.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF024623AA0B5700AC97F6 /* CPUInterp.cpp */; };
@@ -740,8 +734,6 @@
 		95772DCF2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S in Sources */ = {isa = PBXBuildFile; fileRef = 95772DCD2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S */; };
 		95772DD02C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM86.S in Sources */ = {isa = PBXBuildFile; fileRef = 95772DCE2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM86.S */; };
 		958375352A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S in Sources */ = {isa = PBXBuildFile; fileRef = 958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */; };
-		958B046429D2C89D00FC3AEF /* GemmInt8Executor.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 958B046329D2C89D00FC3AEF /* GemmInt8Executor.cpp */; };
-		958B046629D2C8AF00FC3AEF /* GemmInt8Executor.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 958B046529D2C8AF00FC3AEF /* GemmInt8Executor.hpp */; };
 		95CE1DFF2AC57F6200EFB51E /* MNNReluWithSlopeChannelInt8.S in Sources */ = {isa = PBXBuildFile; fileRef = 95CE1DFE2AC57F6200EFB51E /* MNNReluWithSlopeChannelInt8.S */; };
 		95CE1E012AC57F7600EFB51E /* MNNReluWithSlopeChannelInt8.S in Sources */ = {isa = PBXBuildFile; fileRef = 95CE1E002AC57F7600EFB51E /* MNNReluWithSlopeChannelInt8.S */; };
 		C43C81FA251894A600A0FF84 /* CommonOptFunctionNeon.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C43C81F8251894A500A0FF84 /* CommonOptFunctionNeon.cpp */; };
@@ -1342,11 +1334,9 @@
 		92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixAdd.S; sourceTree = "<group>"; };
 		92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNExpC8.S; sourceTree = "<group>"; };
 		92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNConvDwF23SourceTransUnit.S; sourceTree = "<group>"; };
-		92FF016623AA0B4E00AC97F6 /* MNNWinogradMatrixProductLeft.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNWinogradMatrixProductLeft.S; sourceTree = "<group>"; };
 		92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDeconvRunForUnitDepthWise.S; sourceTree = "<group>"; };
 		92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSamplerC1BilinearOpt.S; sourceTree = "<group>"; };
 		92FF016A23AA0B4E00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGemmInt8AddBiasScale_16x4_Unit.S; sourceTree = "<group>"; };
-		92FF016C23AA0B4E00AC97F6 /* MNNWinogradMatrixProductRight.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNWinogradMatrixProductRight.S; sourceTree = "<group>"; };
 		92FF016E23AA0B4E00AC97F6 /* MNNReluWithSlopeChannel.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNReluWithSlopeChannel.S; sourceTree = "<group>"; };
 		92FF017223AA0B4E00AC97F6 /* MNNBlitC3ToFloatRGBA.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBlitC3ToFloatRGBA.S; sourceTree = "<group>"; };
 		92FF017423AA0B4E00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNUInt8ToInt16WithOffsetC4Common.S; sourceTree = "<group>"; };
@@ -1386,11 +1376,9 @@
 		92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixAdd.S; sourceTree = "<group>"; };
 		92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNExpC8.S; sourceTree = "<group>"; };
 		92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNConvDwF23SourceTransUnit.S; sourceTree = "<group>"; };
-		92FF01A723AA0B4E00AC97F6 /* MNNWinogradMatrixProductLeft.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNWinogradMatrixProductLeft.S; sourceTree = "<group>"; };
 		92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDeconvRunForUnitDepthWise.S; sourceTree = "<group>"; };
 		92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSamplerC1BilinearOpt.S; sourceTree = "<group>"; };
 		92FF01AB23AA0B4E00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGemmInt8AddBiasScale_16x4_Unit.S; sourceTree = "<group>"; };
-		92FF01AD23AA0B4E00AC97F6 /* MNNWinogradMatrixProductRight.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNWinogradMatrixProductRight.S; sourceTree = "<group>"; };
 		92FF01AF23AA0B4E00AC97F6 /* MNNReluWithSlopeChannel.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNReluWithSlopeChannel.S; sourceTree = "<group>"; };
 		92FF01B323AA0B4E00AC97F6 /* MNNBlitC3ToFloatRGBA.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBlitC3ToFloatRGBA.S; sourceTree = "<group>"; };
 		92FF01B523AA0B4E00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNUInt8ToInt16WithOffsetC4Common.S; sourceTree = "<group>"; };
@@ -1448,7 +1436,6 @@
 		92FF022323AA0B5600AC97F6 /* Int8FunctionsOpt.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = Int8FunctionsOpt.cpp; sourceTree = "<group>"; };
 		92FF022523AA0B5600AC97F6 /* ConvOpt.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ConvOpt.cpp; sourceTree = "<group>"; };
 		92FF022623AA0B5600AC97F6 /* OptimizedComputer.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = OptimizedComputer.cpp; sourceTree = "<group>"; };
-		92FF022723AA0B5600AC97F6 /* DeconvolutionWithStride.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = DeconvolutionWithStride.hpp; sourceTree = "<group>"; };
 		92FF022823AA0B5600AC97F6 /* ConvolutionTiledExecutor.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = ConvolutionTiledExecutor.hpp; sourceTree = "<group>"; };
 		92FF022923AA0B5600AC97F6 /* ConvolutionIntFactory.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ConvolutionIntFactory.cpp; sourceTree = "<group>"; };
 		92FF022A23AA0B5600AC97F6 /* WinogradOptFunction.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = WinogradOptFunction.cpp; sourceTree = "<group>"; };
@@ -1465,7 +1452,6 @@
 		92FF023B23AA0B5600AC97F6 /* ConvOpt.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ConvOpt.h; sourceTree = "<group>"; };
 		92FF023E23AA0B5600AC97F6 /* OptimizedComputer.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = OptimizedComputer.hpp; sourceTree = "<group>"; };
 		92FF023F23AA0B5600AC97F6 /* Int8FunctionsOpt.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = Int8FunctionsOpt.h; sourceTree = "<group>"; };
-		92FF024023AA0B5600AC97F6 /* DeconvolutionWithStride.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = DeconvolutionWithStride.cpp; sourceTree = "<group>"; };
 		92FF024123AA0B5600AC97F6 /* ConvolutionTiledExecutor.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ConvolutionTiledExecutor.cpp; sourceTree = "<group>"; };
 		92FF024523AA0B5700AC97F6 /* CPUEltwise.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUEltwise.cpp; sourceTree = "<group>"; };
 		92FF024623AA0B5700AC97F6 /* CPUInterp.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUInterp.cpp; sourceTree = "<group>"; };
@@ -1597,8 +1583,6 @@
 		95772DCD2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackC4Int8ForMatMulA_ARM82.S; sourceTree = "<group>"; };
 		95772DCE2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM86.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackC4Int8ForMatMulA_ARM86.S; sourceTree = "<group>"; };
 		958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S; path = arm/arm64/MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S; sourceTree = "<group>"; };
-		958B046329D2C89D00FC3AEF /* GemmInt8Executor.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GemmInt8Executor.cpp; sourceTree = "<group>"; };
-		958B046529D2C8AF00FC3AEF /* GemmInt8Executor.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = GemmInt8Executor.hpp; sourceTree = "<group>"; };
 		95CE1DFE2AC57F6200EFB51E /* MNNReluWithSlopeChannelInt8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNReluWithSlopeChannelInt8.S; sourceTree = "<group>"; };
 		95CE1E002AC57F7600EFB51E /* MNNReluWithSlopeChannelInt8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNReluWithSlopeChannelInt8.S; sourceTree = "<group>"; };
 		C43C81F8251894A500A0FF84 /* CommonOptFunctionNeon.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CommonOptFunctionNeon.cpp; sourceTree = "<group>"; };
@@ -2643,11 +2627,9 @@
 				92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */,
 				92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */,
 				92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */,
-				92FF016623AA0B4E00AC97F6 /* MNNWinogradMatrixProductLeft.S */,
 				92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */,
 				92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */,
 				92FF016A23AA0B4E00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S */,
-				92FF016C23AA0B4E00AC97F6 /* MNNWinogradMatrixProductRight.S */,
 				92FF016E23AA0B4E00AC97F6 /* MNNReluWithSlopeChannel.S */,
 				92FF017223AA0B4E00AC97F6 /* MNNBlitC3ToFloatRGBA.S */,
 				92FF017423AA0B4E00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S */,
@@ -2737,11 +2719,9 @@
 				92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */,
 				92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */,
 				92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */,
-				92FF01A723AA0B4E00AC97F6 /* MNNWinogradMatrixProductLeft.S */,
 				92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */,
 				92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */,
 				92FF01AB23AA0B4E00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S */,
-				92FF01AD23AA0B4E00AC97F6 /* MNNWinogradMatrixProductRight.S */,
 				92FF01AF23AA0B4E00AC97F6 /* MNNReluWithSlopeChannel.S */,
 				92FF01B323AA0B4E00AC97F6 /* MNNBlitC3ToFloatRGBA.S */,
 				92FF01B523AA0B4E00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S */,
@@ -2761,8 +2741,6 @@
 			children = (
 				CEA82BD92A15F8AD002CBC95 /* IdstConvolutionInt8.cpp */,
 				CEA82BDA2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp */,
-				958B046529D2C8AF00FC3AEF /* GemmInt8Executor.hpp */,
-				958B046329D2C89D00FC3AEF /* GemmInt8Executor.cpp */,
 				C48CAE2528900C4A00271A6D /* ConvInt8Winograd.cpp */,
 				C48CAE2428900C4A00271A6D /* ConvInt8Winograd.hpp */,
 				4A224A1227D0C56E000A9260 /* ConvolutionWinogradBridge.cpp */,
@@ -2790,7 +2768,6 @@
 				92FF022323AA0B5600AC97F6 /* Int8FunctionsOpt.cpp */,
 				92FF022523AA0B5600AC97F6 /* ConvOpt.cpp */,
 				92FF022623AA0B5600AC97F6 /* OptimizedComputer.cpp */,
-				92FF022723AA0B5600AC97F6 /* DeconvolutionWithStride.hpp */,
 				92FF022823AA0B5600AC97F6 /* ConvolutionTiledExecutor.hpp */,
 				92FF022923AA0B5600AC97F6 /* ConvolutionIntFactory.cpp */,
 				92FF022A23AA0B5600AC97F6 /* WinogradOptFunction.cpp */,
@@ -2807,7 +2784,6 @@
 				92FF023B23AA0B5600AC97F6 /* ConvOpt.h */,
 				92FF023E23AA0B5600AC97F6 /* OptimizedComputer.hpp */,
 				92FF023F23AA0B5600AC97F6 /* Int8FunctionsOpt.h */,
-				92FF024023AA0B5600AC97F6 /* DeconvolutionWithStride.cpp */,
 				92FF024123AA0B5600AC97F6 /* ConvolutionTiledExecutor.cpp */,
 			);
 			path = compute;
@@ -2939,7 +2915,6 @@
 			buildActionMask = 2147483647;
 			files = (
 				48C84B89250F711700EE7666 /* StaticModule.hpp in Headers */,
-				958B046629D2C8AF00FC3AEF /* GemmInt8Executor.hpp in Headers */,
 				1F501F812397BA5B004E8721 /* AutoTime.hpp in Headers */,
 				92FF04A523AA0BFB00AC97F6 /* AutoStorage.h in Headers */,
 				EBECA3A124643D4E0062C7A3 /* MNNAsmGlobal.h in Headers */,
@@ -3105,7 +3080,6 @@
 				92FF03C923AA0B5A00AC97F6 /* CPUMatMul.hpp in Headers */,
 				EBECA39924643D320062C7A3 /* Arm82Relu.hpp in Headers */,
 				4838EA7C2611BFE20027232C /* CPUGridSample.hpp in Headers */,
-				92FF03A523AA0B5A00AC97F6 /* DeconvolutionWithStride.hpp in Headers */,
 				92FF03D123AA0B5A00AC97F6 /* CPUTopKV2.hpp in Headers */,
 				92FF033F23AA0B5A00AC97F6 /* CPUArgMax.hpp in Headers */,
 				92FF034C23AA0B5A00AC97F6 /* CPUSetDiff1D.hpp in Headers */,
@@ -3335,7 +3309,6 @@
 				92FF038623AA0B5A00AC97F6 /* CPULinSpace.cpp in Sources */,
 				4819FB2D24C1396A0050BD09 /* GeometryConv2D.cpp in Sources */,
 				48747D63245D9E33000B9709 /* GeometryPermute.cpp in Sources */,
-				92FF032C23AA0B5A00AC97F6 /* MNNWinogradMatrixProductRight.S in Sources */,
 				48BB6EF625220AA80056E195 /* MNNTranspose32Bit4x4.S in Sources */,
 				CE072A1C2C91AEE700F190FD /* MNNRGBAToBGRFast.S in Sources */,
 				CEE9B95C2A3AA4D4006438F2 /* MNNBilinearSampleC8.S in Sources */,
@@ -3597,7 +3570,6 @@
 				48FD12BF2466A88D009E9102 /* GeometryConv2DBackPropFilter.cpp in Sources */,
 				92FF02F923AA0B5A00AC97F6 /* MNNGemmint8to32_8x4_Unit.S in Sources */,
 				95772DCF2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S in Sources */,
-				92FF02E623AA0B5A00AC97F6 /* MNNWinogradMatrixProductLeft.S in Sources */,
 				48747D64245D9E33000B9709 /* GeometryTile.cpp in Sources */,
 				92FF043723AA0B7100AC97F6 /* ShapeDetectionOutput.cpp in Sources */,
 				92FF042623AA0B7100AC97F6 /* ShapeCosineSimilarity.cpp in Sources */,
@@ -3633,7 +3605,6 @@
 				92FF043023AA0B7100AC97F6 /* ShapeQuantizedAvgPool.cpp in Sources */,
 				92FF030623AA0B5A00AC97F6 /* MNNStrassenMergeCFunction.S in Sources */,
 				92FF033223AA0B5A00AC97F6 /* MNNBlitC3ToFloatRGBA.S in Sources */,
-				92FF03BE23AA0B5A00AC97F6 /* DeconvolutionWithStride.cpp in Sources */,
 				92FF044923AA0B7100AC97F6 /* ShapeGatherND.cpp in Sources */,
 				489D7AB32550FDC900AD896A /* MetalPReLU.mm in Sources */,
 				19D0FE7028534C4500B74B1A /* MetalSoftmax.mm in Sources */,
@@ -3787,13 +3758,11 @@
 				92FF02C723AA0B5A00AC97F6 /* MNNCopyC4WithStride.S in Sources */,
 				92FF030923AA0B5A00AC97F6 /* MNNNV21ToBGRUnit.S in Sources */,
 				CECF8C79299CAD9400D3875B /* hmac-sha.cpp in Sources */,
-				92FF032623AA0B5A00AC97F6 /* MNNWinogradMatrixProductLeft.S in Sources */,
 				92FF04C023AA0BFB00AC97F6 /* Tensor.cpp in Sources */,
 				CEE9B95B2A3AA4D4006438F2 /* MNNBilinearLineC8.S in Sources */,
 				92FF045D23AA0B7100AC97F6 /* ShapeCast.cpp in Sources */,
 				92FF032223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */,
 				92FF02D723AA0B5A00AC97F6 /* MNNConvRunForUnitDepthWiseUint8.S in Sources */,
-				958B046429D2C89D00FC3AEF /* GemmInt8Executor.cpp in Sources */,
 				92FF026123AA0B5A00AC97F6 /* CPUCropAndResize.cpp in Sources */,
 				48FA474923AA127B00172C3B /* MathOp.cpp in Sources */,
 				4819FB3C24C69E680050BD09 /* GeometryBatchMatMul.cpp in Sources */,
@@ -3826,7 +3795,6 @@
 				92FF032823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */,
 				4896D37F25FE2A6B00717702 /* MNNConvRunForLineDepthwiseFP16.S in Sources */,
 				92FF044323AA0B7100AC97F6 /* ShapeTopKV2.cpp in Sources */,
-				92FF02EC23AA0B5A00AC97F6 /* MNNWinogradMatrixProductRight.S in Sources */,
 				48C84BA1250F725600EE7666 /* InitNet.cpp in Sources */,
 				4894C6E927016F7200D8BE79 /* CPUResizeCache.cpp in Sources */,
 				4DD1791B2684815A00B0098F /* ShapeSetDiff1D.cpp in Sources */,
@@ -4164,7 +4132,7 @@
 				METAL_LIBRARY_FILE_BASE = mnn;
 				ONLY_ACTIVE_ARCH = YES;
 				OTHER_CFLAGS = "";
-				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde3vjk;
+				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde;
 				PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
 				PROVISIONING_PROFILE_SPECIFIER = "";
 				"PROVISIONING_PROFILE_SPECIFIER[sdk=macosx*]" = "";
@@ -4260,7 +4228,7 @@
 				IPHONEOS_DEPLOYMENT_TARGET = 9.0;
 				LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
 				OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
-				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde3vjk;
+				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde;
 				PRODUCT_NAME = "$(TARGET_NAME)";
 				TARGETED_DEVICE_FAMILY = "1,2";
 			};
@@ -4287,7 +4255,7 @@
 				IPHONEOS_DEPLOYMENT_TARGET = 9.0;
 				LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
 				OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
-				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde3vjk;
+				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde;
 				PRODUCT_NAME = "$(TARGET_NAME)";
 				TARGETED_DEVICE_FAMILY = "1,2";
 			};

+ 6 - 1
pymnn/CMakeLists.txt

@@ -16,6 +16,7 @@ option(PYMNN_TRAIN_API "MNN train API be exposed" OFF)
 option(PYMNN_INTERNAL_SERVING "Internal use only." OFF)
 option(PYMNN_OPENCV_API "MNN OpenCV API be exposed" ON)
 option(PYMNN_IMGCODECS "MNN IMGCODECS API be exposed" OFF)
+option(PYMNN_AUDIO_API "MNN Audio API be exposed" ON)
 option(PYMNN_OHOS_INTERNAL "compile for harmony internal." OFF)
 
 if (PYMNN_OHOS_INTERNAL)
@@ -91,6 +92,10 @@ if(PYMNN_CVCORE)
     target_compile_definitions(mnnpybridge PRIVATE PYMNN_CVCORE)
 endif()
 
+if(PYMNN_AUDIO_API)
+    target_compile_definitions(mnnpybridge PRIVATE PYMNN_AUDIO_API)
+endif()
+
 if(PYMNN_INTERNAL_SERVING)
     message(STATUS "mnnpybridge define PYMNN_INTERNAL_SERVING")
     target_compile_definitions(mnnpybridge PRIVATE PYMNN_INTERNAL_SERVING)
@@ -197,7 +202,7 @@ else()
         endif()
         export_headers(DIR ${CMAKE_SOURCE_DIR}/pip_package/MNN)
     else()
-        target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV)
+        target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV MNNAudio)
         if(PYMNN_USE_ALINNPYTHON)
             target_link_libraries(mnnpybridge PRIVATE AliNNPython)
         endif()

+ 1 - 0
pymnn/pip_package/MNN/__init__.py

@@ -9,3 +9,4 @@ from . import nn
 from . import optim
 from . import numpy
 from . import cv
+from . import audio

+ 96 - 0
pymnn/pip_package/MNN/audio/__init__.py

@@ -0,0 +1,96 @@
+from _mnncengine.audio import *
+import _mnncengine.audio as _F
+import MNN.expr as _expr
+import MNN.numpy as _np
+import MNN
+
+# Enum Types
+# enum WINDOW_TYPE
+HAMMING = 0
+HANNING = 1
+POVEY = 2
+RECTANGULAR = 3
+BLACKMAN = 4
+# enum PadValueMode
+CONSTANT = 0
+REFLECT = 1
+SYMMETRIC = 2
+EDGE = 3
+
+"""
+Loads a portion of an audio file.
+
+Parameters:
+    filename (str): The path to the audio file.
+    sr (int): Sample rate to load at (0 appears to keep the file's native rate). Default is 0.
+    frame_offset (int): The offset in frames from which to start loading the audio data. Default is 0.
+    num_frames (int): The number of frames to load. If set to -1, the entire audio file will be loaded. Default is -1.
+
+Returns:
+    The loaded audio var and its sample rate.
+"""
+def load(filename, sr = 0, frame_offset = 0, num_frames = -1):
+    return _F.load(filename, sr, frame_offset, num_frames)
+
+"""
+Saves an audio var to a file.
+Parameters:
+    filename (str): The path to the audio file.
+    audio (Var): The audio var to save.
+    sample_rate (int): The sample rate of the audio var.
+Returns:
+    None
+"""
+def save(filename, audio, sample_rate):
+    return _F.save(filename, audio, sample_rate)
+
+"""
+Generates a Hamming window.
+Parameters:
+    window_size (int): The size of the window.
+    periodic (bool): Whether the window is periodic. Default is False.
+    alpha (float): The alpha parameter of the Hamming window. Default is 0.54.
+    beta (float): The beta parameter of the Hamming window. Default is 0.46.
+Returns:
+    The Hamming window.
+"""
+def hamming_window(window_size, periodic = False, alpha = 0.54, beta = 0.46):
+    return _F.hamming_window(window_size, periodic, alpha, beta)
+
+"""
+Generates a Hann window.
+Parameters:
+    window_size (int): The size of the window.
+    periodic (bool): Whether the window is periodic. Default is False.
+Returns:
+    The Hann window.
+"""
+def hanning_window(window_size, periodic = False):
+    return _F.hanning_window(window_size, periodic)
+
+def melscale_fbanks(n_mels, n_fft, sample_rate = 16000, htk = True, norm = False,
+                    f_min = 0.0, f_max = 0.0):
+    return _F.melscale_fbanks(n_mels, n_fft, sample_rate, htk, norm, f_min, f_max)
+
+def spectrogram(waveform, n_fft = 400, hop_length = 0, win_length = 0, window_type = HANNING,
+                pad_left = 0, pad_right = 0, center = False, normalized = False, pad_mode = REFLECT,
+                power = 2.0):
+    return _F.spectrogram(waveform, n_fft, hop_length, win_length, window_type, pad_left,
+                          pad_right, center, normalized, pad_mode, power)
+
+
+def mel_spectrogram(waveform, n_mels, n_fft, sample_rate = 16000, htk = True, norm = False,
+                    f_min = 0.0, f_max = 0.0, hop_length = 0, win_length = 0, window_type = HANNING,
+                    pad_left = 0, pad_right = 0, center = False, normalized = False, pad_mode = REFLECT,
+                    power = 2.0):
+    return _F.mel_spectrogram(waveform, n_mels, n_fft, sample_rate, htk, norm, f_min, f_max,
+                              hop_length, win_length, window_type, pad_left, pad_right, center,
+                              normalized, pad_mode, power)
+
+def fbank(waveform, sample_rate = 16000, n_mels = 80, n_fft = 400, hop_length = 160,
+          dither = 0.0, preemphasis = 0.97):
+    return _F.fbank(waveform, sample_rate, n_mels, n_fft, hop_length, dither, preemphasis)
+
+
+def whisper_fbank(waveform, sample_rate = 16000, n_mels = 128, n_fft = 400,
+                  hop_length = 160, chunk_len = 30):
+    return _F.whisper_fbank(waveform, sample_rate, n_mels, n_fft, hop_length, chunk_len)
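For orientation, a minimal usage sketch of the Python surface added above. The file name is illustrative, and the assumption that `load` returns an (audio, sample rate) pair follows from the binding's return conversion:

```python
import MNN.audio as audio

# Load a waveform; sr=0 is assumed to keep the file's native sample rate.
waveform, sample_rate = audio.load('speech.wav')

# 80-bin filterbank features with the defaults declared above.
feats = audio.fbank(waveform, sample_rate=16000, n_mels=80, n_fft=400, hop_length=160)

# Whisper-style 128-bin mel features over 30-second chunks.
mels = audio.whisper_fbank(waveform, n_mels=128, chunk_len=30)
```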

+ 20 - 2
pymnn/pip_package/MNN/llm/__init__.py

@@ -57,7 +57,25 @@ class LLM(_F.LLM):
         '''
         return super().response(prompt, stream)
 
-def create(config_path):
+    def txt_embedding(self, prompt):
+        '''
+        get prompt's embedding
+
+        Parameters
+        ----------
+        prompt : input prompt
+
+        Returns
+        -------
+        res : embedding var
+
+        Example:
+        -------
+        >>> res = qwen.txt_embedding('Hello')
+        '''
+        return super().txt_embedding(prompt)
+
+def create(config_path, embedding_model = False):
     '''
     create LLM instance by `config.json`
 
@@ -73,4 +91,4 @@ def create(config_path):
     -------
     >>> qwen = llm.create('./qwen-1.8b-int4/config.json')
     '''
-    return _F.create(config_path)
+    return _F.create(config_path, embedding_model)
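A short sketch of the extended `create` entry point; the config paths are illustrative:

```python
import MNN.llm as llm

# Chat model, unchanged behavior.
qwen = llm.create('./qwen-1.8b-int4/config.json')

# Embedding model: embedding_model=True makes the C side call
# Embedding::createEmbedding instead of Llm::createLLM.
embed = llm.create('./embedding-model/config.json', embedding_model=True)
embed.load()
vec = embed.txt_embedding('Hello')   # returns an embedding Var
```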

+ 3 - 3
pymnn/pip_package/build_deps.py

@@ -99,7 +99,7 @@ def build_deps():
     if IS_WINDOWS:
         os.system('cmake -G "Ninja" ' + extra_opts +' -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TORCH=OFF\
             -DMNN_BUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON\
-            -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF .. && ninja MNN MNNConvertDeps')
+            -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_BUILD_AUDIO=ON -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF .. && ninja MNN MNNConvertDeps')
     elif IS_LINUX:
         extra_opts += '-DMNN_TENSORRT=ON \
         -DCMAKE_LIBRARY_PATH=/usr/local/cuda/lib64/stubs/ ' if USE_TRT else ' '
@@ -113,14 +113,14 @@ def build_deps():
         os.system('cmake ' + extra_opts +
             '-DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TRAIN=ON -DCMAKE_BUILD_TYPE=Release \
             -DMNN_BUILD_SHARED_LIBS=OFF -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON \
-             .. && make MNN MNNTrain MNNConvertDeps -j32')
+            -DMNN_BUILD_AUDIO=ON  .. && make MNN MNNTrain MNNConvertDeps -j32')
     else:
         extra_opts += ' -DMNN_INTERNAL=ON ' if USE_INTERNAL else ' '
         extra_opts += ' -DMNN_BUILD_TORCH=ON ' if USE_TORCH else ' '
         print(extra_opts)
         os.system('cmake ' + extra_opts + '-DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TRAIN=ON -DCMAKE_BUILD_TYPE=Release \
             -DMNN_BUILD_SHARED_LIBS=ON -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF\
-            -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON \
+            -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_BUILD_AUDIO=ON\
             .. && make MNN MNNConvertDeps -j64')
 ################################################################################
 # Building dependent libraries

+ 4 - 1
pymnn/pip_package/setup.py

@@ -166,7 +166,7 @@ def configure_extension_build():
         ]
         if check_env_flag('WERROR'):
             extra_compile_args.append('-Werror')
-    extra_compile_args += ['-DPYMNN_EXPR_API', '-DPYMNN_NUMPY_USABLE', '-DPYMNN_OPENCV_API']
+    extra_compile_args += ['-DPYMNN_EXPR_API', '-DPYMNN_NUMPY_USABLE', '-DPYMNN_OPENCV_API', '-DPYMNN_AUDIO_API']
     if IS_LINUX and USE_INTERNAL:
         extra_compile_args += ['-DPYMNN_INTERNAL_SERVING']
         if args.env == 'daily':
@@ -177,6 +177,7 @@ def configure_extension_build():
     engine_library_dirs = [os.path.join(root_dir, BUILD_DIR)]
     engine_library_dirs += [os.path.join(root_dir, BUILD_DIR, "tools", "train")]
     engine_library_dirs += [os.path.join(root_dir, BUILD_DIR, "tools", "cv")]
+    engine_library_dirs += [os.path.join(root_dir, BUILD_DIR, "tools", "audio")]
     engine_library_dirs += [os.path.join(root_dir, BUILD_DIR, "source", "backend", "tensorrt")]
     engine_library_dirs += [os.path.join(root_dir, BUILD_DIR, "source", "backend", "cuda")]
     if USE_TRT or USE_CUDA:
@@ -214,6 +215,8 @@ def configure_extension_build():
         engine_include_dirs += [os.path.join(root_dir, "3rd_party", "rapidjson")]
     # cv include
     engine_include_dirs += [os.path.join(root_dir, "tools", "cv", "include")]
+    # audio include
+    engine_include_dirs += [os.path.join(root_dir, "tools", "audio", "include")]
     # llm include
     engine_include_dirs += [os.path.join(root_dir, "transformers", "llm", "engine", "include")]
     engine_include_dirs += [os.path.join(root_dir, "3rd_party")]

+ 17 - 1
pymnn/src/MNN.cc

@@ -22,6 +22,9 @@ using namespace MNN::Express;
 #ifdef PYMNN_OPENCV_API
 #include "cv/cv.hpp"
 #endif
+#ifdef PYMNN_AUDIO_API
+#include "audio/audio.hpp"
+#endif
 #endif // PYMNN_EXPR_API
 
 #ifdef BUILD_OPTYPE
@@ -64,6 +67,9 @@ using RegularizationMethod = ParameterOptimizer::RegularizationMethod;
 #ifdef PYMNN_OPENCV_API
 #include "cv.h"
 #endif
+#ifdef PYMNN_AUDIO_API
+#include "audio.h"
+#endif
 #endif
 
 #ifdef PYMNN_LLM_API
@@ -1587,7 +1593,8 @@ static PyObject* PyMNNTensor_repr(PyObject *self) {
 #ifdef PYMNN_NUMPY_USABLE
     auto content = PyMNNTensor_getNumpyData(((PyMNNTensor*)self), NULL);
 #else
-    auto content = PyMNNVar_read_as_tuple((PyMNNVar*)self, NULL);
+    // print shape of tensor
+    auto content = PyMNNTensor_getShape((PyMNNTensor*)self, NULL);
 #endif
     auto reprfunc = PyObject_GetAttrString(content, "__repr__");
     auto str = PyEval_CallObject(reprfunc, NULL);
@@ -2713,6 +2720,15 @@ PyMODINIT_FUNC MOD_INIT_FUNC(void) {
         def_method(cv_module, &PyMNNCV_methods[i]);
     }
 #endif
+#ifdef PYMNN_AUDIO_API
+    // audio submodule
+    auto audio_module = def_submodule(m, "audio");
+    // add methods of audio
+    constexpr int audio_method_num = sizeof(PyMNNAUDIO_methods) / sizeof(PyMethodDef);
+    for (int i = 0; i < audio_method_num; i++) {
+        def_method(audio_module, &PyMNNAUDIO_methods[i]);
+    }
+#endif
 #endif
 #ifdef PYMNN_LLM_API
     // llm submodule

+ 105 - 0
pymnn/src/audio.h

@@ -0,0 +1,105 @@
+// MNN AUDIO
+static PyObject *PyMNNAUDIO_load(PyObject *self, PyObject *args) {
+    const char *filename = NULL;
+    int sr = 0, frame_offset = 0, num_frames = -1;
+    if (PyArg_ParseTuple(args, "s|iii", &filename, &sr, &frame_offset, &num_frames) && filename) {
+        return toPyObj<VARP, toPyObj, int, toPyObj>(AUDIO::load(filename, sr, frame_offset, num_frames));
+    }
+    PyMNN_ERROR("load require args: (string, int, int, int)");
+}
+static PyObject *PyMNNAUDIO_save(PyObject *self, PyObject *args) {
+    const char *filename = NULL;
+    PyObject *audio      = nullptr;
+    int sample_rate      = 0;
+    if (PyArg_ParseTuple(args, "sOi", &filename, &audio, &sample_rate) && filename && isVar(audio)) {
+        return toPyObj(AUDIO::save(filename, toVar(audio), sample_rate));
+    }
+    PyMNN_ERROR("save require args: (string, Var, int)");
+}
+static PyObject *PyMNNAUDIO_hamming_window(PyObject *self, PyObject *args) {
+    int window_size = 0, periodic = 0;
+    float alpha = 0.54, beta = 0.46;
+    if (PyArg_ParseTuple(args, "i|iff", &window_size, &periodic, &alpha, &beta)) {
+        return toPyObj(AUDIO::hamming_window(window_size, periodic, alpha, beta));
+    }
+    PyMNN_ERROR("hamming_window require args: (int, |bool, float, float)");
+}
+static PyObject *PyMNNAUDIO_hann_window(PyObject *self, PyObject *args) {
+    int window_size = 0, periodic = 0;
+    if (PyArg_ParseTuple(args, "i|i", &window_size, &periodic)) {
+        return toPyObj(AUDIO::hann_window(window_size, periodic));
+    }
+    PyMNN_ERROR("hann_window require args: (int, |bool)");
+}
+static PyObject *PyMNNAUDIO_melscale_fbanks(PyObject *self, PyObject *args) {
+    AUDIO::MelscaleParams mel;
+    if (PyArg_ParseTuple(args, "ii|ifff", &mel.n_mels, &mel.n_fft, &mel.sample_rate, &mel.htk, &mel.norm, &mel.f_min, &mel.f_max)) {
+        return toPyObj(AUDIO::melscale_fbanks(&mel));
+    }
+    PyMNN_ERROR("melscale_fbanks require args: (int, int, |int, bool, bool, float, float)");
+}
+static PyObject *PyMNNAUDIO_spectrogram(PyObject *self, PyObject *args) {
+    PyObject *waveform = nullptr;
+    AUDIO::SpectrogramParams spec;
+    if (PyArg_ParseTuple(args, "O|iiiiiiiiiif", &waveform, &spec.n_fft, &spec.hop_length, &spec.win_length,
+                         &spec.window_type, &spec.pad_left, &spec.pad_right, &spec.center, &spec.normalized,
+                         &spec.pad_mode, &spec.power) &&
+        isVar(waveform)) {
+        return toPyObj(AUDIO::spectrogram(toVar(waveform), &spec));
+    }
+    PyMNN_ERROR("spectrogram require args: (Var, |int, int, int, int, int, int, bool, bool, PadValueMode, float)");
+}
+static PyObject *PyMNNAUDIO_mel_spectrogram(PyObject *self, PyObject *args) {
+    PyObject *waveform = nullptr;
+    AUDIO::MelscaleParams mel;
+    AUDIO::SpectrogramParams spec;
+    if (PyArg_ParseTuple(args, "O|iiifiiifiiiii", &waveform, &mel.n_mels, &mel.n_fft, &mel.sample_rate, &mel.htk,
+                         &mel.norm, &mel.f_min, &mel.f_max, &spec.hop_length, &spec.win_length, &spec.window_type,
+                         &spec.pad_left, &spec.pad_right, &spec.center, &spec.normalized, &spec.pad_mode,
+                         &spec.power) &&
+        isVar(waveform)) {
+        spec.n_fft = mel.n_fft;
+        return toPyObj(AUDIO::mel_spectrogram(toVar(waveform), &mel, &spec));
+    }
+    PyMNN_ERROR(
+        "mel_spectrogram require args: (Var, |int, int, int, bool, bool, float, float, "
+        "int, int, int, int, int, bool, bool, PadValueMode, float)");
+}
+static PyObject *PyMNNAUDIO_fbank(PyObject *self, PyObject *args) {
+    PyObject *waveform = nullptr;
+    int sample_rate = 16000, n_mels = 80, n_fft = 400, hop_length = 160;
+    float dither = 0.0, preemphasis = 0.97;
+    if (PyArg_ParseTuple(args, "O|iiiiff", &waveform, &sample_rate, &n_mels, &n_fft, &hop_length, &dither,
+                         &preemphasis) &&
+        isVar(waveform)) {
+        return toPyObj(
+            AUDIO::fbank(toVar(waveform), sample_rate, n_mels, n_fft, hop_length, dither, preemphasis));
+    }
+    PyMNN_ERROR("fbank require args: (Var, |int, int, int, int, float, float)");
+}
+
+static PyObject *PyMNNAUDIO_whisper_fbank(PyObject *self, PyObject *args) {
+    PyObject *waveform = nullptr;
+    int sample_rate = 16000, n_mels = 128, n_fft = 400, hop_length = 160, chunk_len = 30;
+    if (PyArg_ParseTuple(args, "O|iiiii", &waveform, &sample_rate, &n_mels, &n_fft, &hop_length, &chunk_len) &&
+        isVar(waveform)) {
+        return toPyObj(AUDIO::whisper_fbank(toVar(waveform), sample_rate, n_mels, n_fft, hop_length, chunk_len));
+    }
+    PyMNN_ERROR("whisper_fbank require args: (Var, |int, int, int, int, int)");
+}
+
+static PyMethodDef PyMNNAUDIO_methods[] = {
+    register_methods(AUDIO,
+        load, "load",
+        save, "save",
+        hamming_window, "hamming_window",
+        hann_window, "hann_window",
+        melscale_fbanks, "melscale_fbanks",
+        spectrogram, "spectrogram",
+        mel_spectrogram, "mel_spectrogram",
+        fbank, "fbank",
+        whisper_fbank, "whisper_fbank"
+    )
+};
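Several of these bindings take `n_fft`/`hop_length` pairs, so a small sanity-check helper for the expected frame count may help when wiring them up. The centered-padding convention below is an assumption borrowed from the librosa/torch convention, not verified against the MNN implementation:

```python
def num_stft_frames(num_samples, n_fft=400, hop_length=160, center=False):
    """Expected STFT frame count: one frame per fully covered window."""
    if center:
        # Assumed convention: pad n_fft // 2 samples on each side.
        num_samples += 2 * (n_fft // 2)
    return 1 + (num_samples - n_fft) // hop_length

# One second at 16 kHz with the defaults above -> 98 frames.
print(num_stft_frames(16000))  # 98
```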

+ 39 - 2
pymnn/src/llm.h

@@ -4,6 +4,7 @@
 typedef struct {
     PyObject_HEAD
     MNN::Transformer::Llm* llm;
+    bool is_embedding = false;
 } LLM;
 
 static PyObject* PyMNNLLM_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
@@ -25,6 +26,9 @@ static PyObject* PyMNNLLM_load(LLM *self, PyObject *args) {
 }
 
 static PyObject* PyMNNLLM_forward(LLM *self, PyObject *args) {
+    if (self->is_embedding) {
+        Py_RETURN_NONE;
+    }
     PyObject *input_ids = nullptr;
     if (!PyArg_ParseTuple(args, "O", &input_ids) && isInts(input_ids)) {
         Py_RETURN_NONE;
@@ -37,6 +41,9 @@ static PyObject* PyMNNLLM_forward(LLM *self, PyObject *args) {
 }
 
 static PyObject* PyMNNLLM_generate(LLM *self, PyObject *args) {
+    if (self->is_embedding) {
+        Py_RETURN_NONE;
+    }
     PyObject *input_ids = nullptr;
     if (!PyArg_ParseTuple(args, "O", &input_ids) && isInts(input_ids)) {
         Py_RETURN_NONE;
@@ -46,6 +53,9 @@ static PyObject* PyMNNLLM_generate(LLM *self, PyObject *args) {
 }
 
 static PyObject* PyMNNLLM_response(LLM *self, PyObject *args) {
+    if (self->is_embedding) {
+        Py_RETURN_NONE;
+    }
     const char* query = NULL;
     int stream = 0;
     if (!PyArg_ParseTuple(args, "s|p", &query, &stream)) {
@@ -57,6 +67,9 @@ static PyObject* PyMNNLLM_response(LLM *self, PyObject *args) {
 }
 
 static PyObject* PyMNNLLM_tokenizer_encode(LLM *self, PyObject *args) {
+    if (self->is_embedding) {
+        Py_RETURN_NONE;
+    }
     const char* prompt = NULL;
     int use_template = 0;
     if (!PyArg_ParseTuple(args, "s|p", &prompt, &use_template)) {
@@ -67,6 +80,9 @@ static PyObject* PyMNNLLM_tokenizer_encode(LLM *self, PyObject *args) {
 }
 
 static PyObject* PyMNNLLM_tokenizer_decode(LLM *self, PyObject *args) {
+    if (self->is_embedding) {
+        Py_RETURN_NONE;
+    }
     PyObject *id = nullptr;
     if (!PyArg_ParseTuple(args, "O", &id) && isInt(id)) {
         Py_RETURN_NONE;
@@ -75,6 +91,19 @@ static PyObject* PyMNNLLM_tokenizer_decode(LLM *self, PyObject *args) {
     return string2Object(query);
 }
 
+static PyObject* PyMNNLLM_txt_embedding(LLM *self, PyObject *args) {
+    if (!self->is_embedding) {
+        Py_RETURN_NONE;
+    }
+    const char* query = NULL;
+    if (!PyArg_ParseTuple(args, "s", &query)) {
+        Py_RETURN_NONE;
+    }
+    auto embeds = getVar();
+    *(embeds->var) = ((MNN::Transformer::Embedding*)self->llm)->txt_embedding(query);
+    return (PyObject *)embeds;
+}
+
 static PyMethodDef PyMNNLLM_methods[] = {
     {"load", (PyCFunction)PyMNNLLM_load, METH_VARARGS, "load model."},
     {"forward", (PyCFunction)PyMNNLLM_forward, METH_VARARGS, "forward `logits` by `input_ids`."},
@@ -82,6 +111,7 @@ static PyMethodDef PyMNNLLM_methods[] = {
     {"response", (PyCFunction)PyMNNLLM_response, METH_VARARGS, "response `query` without hsitory."},
     {"tokenizer_encode", (PyCFunction)PyMNNLLM_tokenizer_encode, METH_VARARGS, "tokenizer encode."},
     {"tokenizer_decode", (PyCFunction)PyMNNLLM_tokenizer_decode, METH_VARARGS, "tokenizer decode."},
+    {"txt_embedding", (PyCFunction)PyMNNLLM_txt_embedding, METH_VARARGS, "txt embedding."},
     {NULL}  /* Sentinel */
 };
 
@@ -131,14 +161,21 @@ static PyObject* PyMNNLLM_create(PyObject *self, PyObject *args) {
         return NULL;
     }
     const char* path = NULL;
-    if (!PyArg_ParseTuple(args, "s", &path)) {
+    int embedding_model = 0;
+    if (!PyArg_ParseTuple(args, "s|p", &path, &embedding_model)) {
         return NULL;
     }
     LLM *llm = (LLM *)PyObject_Call((PyObject*)&PyMNNLLM, PyTuple_New(0), NULL);
     if (!llm) {
         return NULL;
     }
-    llm->llm = MNN::Transformer::Llm::createLLM(path);
+    if (embedding_model) {
+        llm->llm = MNN::Transformer::Embedding::createEmbedding(path);
+        llm->is_embedding = true;
+    } else {
+        llm->llm = MNN::Transformer::Llm::createLLM(path);
+    }
+
     return (PyObject*)llm;
 }
 

File diff suppressed because it is too large
+ 192 - 12
schema/current/MNN_generated.h


+ 9 - 1
schema/default/MNN.fbs

@@ -168,6 +168,7 @@ enum OpType : int {
     Svd = 153,
     Histogram = 154,
     DynamicQuant = 155,
+    Stft = 156,
 
     Plugin = 256, //The Type load from plugin
     //Training Op Start from 257
@@ -239,6 +240,12 @@ table FmhcaParam {
     heads: int;
 }
 
+table StftParam {
+    n_fft: int;
+    hop_length: int;
+    abs: bool = true;
+}
+
 table WhileParam {
     // The name of condition subgraph.
     cond_graph: string;
@@ -414,7 +421,8 @@ union OpParameter {
     GroupNorm,
     FmhaV2Param,
     FmhcaParam,
-    AttentionParam
+    AttentionParam,
+    StftParam
 }
 
 table Op {

+ 3 - 0
source/backend/cpu/CPUBackend.hpp

@@ -237,6 +237,9 @@ private:
         CPUBackend::addCreator(opType, &_temp); \
     }
 
+#define REGISTER_CPU_OP_CREATOR_AUDIO(name, opType) \
+    REGISTER_CPU_OP_CREATOR(name, opType)
+
 } // namespace MNN
 
 #endif /* CPUBackend_hpp */

+ 8 - 8
source/backend/cpu/CPUBinaryInt8.cpp

@@ -80,16 +80,16 @@ ErrorCode CPUBinaryInt8::onExecute(const std::vector<Tensor*>& inputs, const std
 
     int inpBytes = 1;
     int outBytes = 1;
+    QuanPrePostParameters params;
+    
+    params.inputScale = mInputScales.data();
+    params.outputScale = mOutputScales.data();
+    params.outputZeroPoint = mOutputZeros.data();
+    params.inputZeroPoint = mInputZeros.data();
+    params.minValue = (ssize_t)mMinValue;
+    params.maxValue = (ssize_t)TensorUtils::getDescribe(outputs[0])->quantAttr->max;
 
     MNN_CONCURRENCY_BEGIN(tId, schedule.second) {
-        QuanPrePostParameters params;
-        
-        params.inputScale = mInputScales.data();
-        params.outputScale = mOutputScales.data();
-        params.outputZeroPoint = mOutputZeros.data();
-        params.inputZeroPoint = mInputZeros.data();
-        params.minValue = (ssize_t)mMinValue;
-        params.maxValue = (ssize_t)TensorUtils::getDescribe(outputs[0])->quantAttr->max;
 
         int start = schedule.first * (int)tId;
         int realSize = schedule.first;

+ 120 - 254
source/backend/cpu/CPUDeconvolution.cpp

@@ -18,7 +18,6 @@
 #include "core/ConvolutionCommon.hpp"
 #include "compute/CommonOptFunction.h"
 #include "compute/ConvOpt.h"
-#include "compute/DeconvolutionWithStride.hpp"
 //#define MNN_OPEN_TIME_TRACE
 #include <MNN/AutoTime.hpp>
 
@@ -83,63 +82,13 @@ static void _transformWeight(const uint8_t* tempWeight, uint8_t* dest, int outpu
     //printf("%d - %d - %d - %d\n", outputCount, srcCount, fh, fw);
     core->MNNPackForMatMul_B((float*)dest, (const float*)cache, outputC4 * fw * fh * core->pack, srcCount, false);
 }
-// Int8 Weight.
-static void _reorderWeightInt8(Backend* bn, const Convolution2DCommon* common, const int8_t* srcPtr,
-                               std::shared_ptr<Tensor>& weight) {
-    auto core = static_cast<CPUBackend*>(bn)->int8Functions();
-    auto gcore =  static_cast<CPUBackend*>(bn)->functions();
-    int UNIT, SRC_UNIT, DST_XUNIT;
-    core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
-    UNIT = gcore->pack;
 
-    int oc = common->outputCount(), ic = common->inputCount(), kernelCount = common->kernelX() * common->kernelY();
-    std::vector<int> shape = {UP_DIV(oc, UNIT), UP_DIV(ic, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT};
-
-    weight.reset(Tensor::createDevice<int8_t>(shape));
-    bool succ = bn->onAcquireBuffer(weight.get(), Backend::STATIC);
-    if (!succ) {
-        MNN_ERROR("Memory not enough");
-        return;
-    }
-    auto dstPtr = weight->host<int8_t>();
-    ::memset(dstPtr, 0, weight->size());
-
-    int icDiv = UP_DIV(ic, SRC_UNIT);
-     for (int k = 0; k < kernelCount; ++k) {
-        auto srcK = srcPtr + k;
-        auto dstK = dstPtr + k * SRC_UNIT * UNIT * icDiv;
-        for (int x = 0; x < oc; ++x) {
-            int xout = x / UNIT;
-            int xin = x % UNIT;
-            auto srcY = srcK + x * kernelCount;
-            auto dstY = dstK + xout * SRC_UNIT * UNIT * icDiv * kernelCount + xin * SRC_UNIT;
-            for (int y = 0; y < ic; ++y) {
-                int yout = y / SRC_UNIT;
-                int yin = y % SRC_UNIT;
-
-                const int dstIndex = yout * SRC_UNIT * UNIT + yin;
-                const int srcIndex = y * oc * kernelCount;
-                dstY[dstIndex] = srcY[srcIndex];
-            }
-        }
-    }
-}
 CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backend* backend, bool dynamicWeight)
     : MNN::CPUDeconvolutionCommon(input, convOp, backend, dynamicWeight) {
     auto core               = static_cast<CPUBackend*>(backend)->functions();
     auto coreInt8           = static_cast<CPUBackend*>(backend)->int8Functions();
     int eP, lP, hP;
     core->MNNGetMatMulPackMode(&eP, &lP, &hP);
-    int UNIT, SRC_UNIT, DST_XUNIT;
-    coreInt8->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
-    bool ModeInt8        =  false;
-
-    if (CPUBackend::getDataType(input) == DataType_DT_INT8 || input->getType().bytes() == 1) {
-        eP = DST_XUNIT;
-        lP = SRC_UNIT;
-        hP = UNIT;
-        ModeInt8 = true;
-    }
     auto conv2d                  = convOp->main_as_Convolution2D();
     auto layer                   = conv2d->common();
     int outputCount              = layer->outputCount();
@@ -155,30 +104,17 @@ CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backen
     mWeight.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP}));
     std::shared_ptr<Tensor> cache(Tensor::createDevice<float>({outputAlign * srcCount}));
     if (dynamicWeight) {
-        mOrigin.reset(new CPUDeconvolutionOrigin(input, mWeight.get(), convOp, backend, ModeInt8));
+        mOrigin.reset(new CPUDeconvolutionOrigin(input, mWeight.get(), convOp, backend, false));
         mWeightTransformCache = cache;
         return;
     }
 
     const float* tempWeight      = nullptr;
-    const int8_t* quanWeightInt8 = nullptr;
 
     int tempWeightSize   = 0;
-    std::unique_ptr<Tensor> externalWeightTensor;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
 
-    std::vector<int32_t> _bias(outputChannleUp4, 0);
-    std::vector<float> _scale(outputChannleUp4, 0);
-    std::vector<int32_t> _beta(outputChannleUp4, 0);
-    auto biasPtr = _bias.data();
-    auto scalePtr = _scale.data();
-    auto betaPtr = _beta.data();
-
-    if (ModeInt8) {
-        ConvolutionCommon::getConvInt8Parameters(convOp, quanCommon, backend, quanWeightInt8, tempWeightSize, scalePtr, biasPtr, betaPtr);
-    } else {
-        ConvolutionCommon::getConvParameters(&quanCommon, backend, convOp, &tempWeight, &tempWeightSize);
-    }
+    ConvolutionCommon::getConvParameters(&quanCommon, backend, convOp, &tempWeight, &tempWeightSize);
 
     bool success = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC) &&
                    backend->onAcquireBuffer(cache.get(), Backend::STATIC);
@@ -196,26 +132,16 @@ CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backen
         core->MNNFp32ToLowp(tempWeight, (int16_t*)lowpWeight.get(), outputCount * srcCount * fh * fw);
         tempWeight = (float*)lowpWeight.get();
     }
-    if (!ModeInt8) {
-        mWeight.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP}));
-        success = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC);
-        if (!success) {
-            mValid = false;
-            return;
-        }
-        auto dest = mWeight->host<uint8_t>();
-        _transformWeight((uint8_t*)tempWeight, dest, outputCount, srcCount, fh, fw, cache->host<uint8_t>(), core);
-    } else {
-        mWeight.reset(Tensor::createDevice<int8_t>(std::vector<int>{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP}));
-        success = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC);
-        if (!success) {
-            mValid = false;
-            return;
-        }
-        _reorderWeightInt8(backend, layer, quanWeightInt8, mWeight);
+    mWeight.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP}));
+    success = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC);
+    if (!success) {
+        mValid = false;
+        return;
     }
+    auto dest = mWeight->host<uint8_t>();
+    _transformWeight((uint8_t*)tempWeight, dest, outputCount, srcCount, fh, fw, cache->host<uint8_t>(), core);
     backend->onReleaseBuffer(cache.get(), Backend::STATIC);
-    mOrigin.reset(new CPUDeconvolutionOrigin(input, mWeight.get(), convOp, backend, ModeInt8));
+    mOrigin.reset(new CPUDeconvolutionOrigin(input, mWeight.get(), convOp, backend, false));
 }
 
 CPUDeconvolution::~CPUDeconvolution() {
@@ -261,68 +187,21 @@ ErrorCode CPUDeconvolution::onResize(const std::vector<Tensor *> &inputs, const
 }
 
 CPUDeconvolutionOrigin::CPUDeconvolutionOrigin(const Tensor *input, Tensor *weight, const Op *convOp, Backend *b, bool ModeInt8) : CPUDeconvolutionBasic(input, convOp, b) {
-    if (ModeInt8) {
-        const auto weightDataPtr = weight->host<int8_t>();
-        auto conv2d = convOp->main_as_Convolution2D();
-        auto common = conv2d->common();
-        auto pack = static_cast<CPUBackend*>(b)->functions()->pack;
-        mResource = CPUConvolution::makeResourceInt8(backend(), convOp, pack);
-        CPUConvolution::MutableResourceInt8 mutableResource(mResource, b);
-        auto core = static_cast<CPUBackend*>(b)->int8Functions();
-        auto gemmKernel = core->Int8GemmKernel;
-        int UNIT, SRC_UNIT, DST_XUNIT;
-        core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
-        const auto kEleCnt = mCommon->kernelX() * mCommon->kernelY();
-        const int ocDiv4 = UP_DIV(common->outputCount(), pack) * kEleCnt;
-        const int icDiv4 = UP_DIV(common->inputCount(), SRC_UNIT);
-        const int ocDivUnit = UP_DIV(common->outputCount(), UNIT);
-        const int oc4 = ocDiv4 / kEleCnt;
-        const int bias_elesize = ocDiv4 * pack;
-        // set offset if use SSE.
-        auto inputQuant = TensorUtils::getQuantInfo(input);
-        auto inputZeroPoint = inputQuant[1];
-        std::vector<int32_t> _bias(bias_elesize, inputZeroPoint);
-#ifdef MNN_USE_SSE
-        int actBits = conv2d->symmetricQuan()->nbits();
-        if (actBits <= 7) {
-            gemmKernel = core->Int8GemmKernelFast;
-        }
-        for (int a = 0; a < kEleCnt; ++a){
-            for (int oz = 0; oz < ocDivUnit * UNIT; ++oz) {
-                int offset = inputZeroPoint, oz4 = oz / UNIT, ozRemain = oz % UNIT;
-                for (int sz = 0; sz < icDiv4 * SRC_UNIT; ++sz) {
-                    int sz4 = sz / SRC_UNIT, szRemain = sz % SRC_UNIT;
-                    int index = (((a * oc4 + oz4) * icDiv4 + sz4) * UNIT + ozRemain) * SRC_UNIT + szRemain;
-                    auto weightInt8Data = weightDataPtr[index];
-                    offset += weightInt8Data * (-128);
-                }
-                if (oz < oc4 * pack) {
-                    _bias[a * oc4 * pack + oz] = offset;
-                }
-            }
-        }
-#else
-        if(conv2d->symmetricQuan() && conv2d->symmetricQuan()->method() == QuantizeAlgo_OVERFLOW_AWARE){
-            gemmKernel = core->Int8GemmKernelFast;
-        }
-#endif
-        mDeconvInt8Exe.reset(new GemmInt8Executor(b, mResource, convOp, gemmKernel, _bias));
-    }
+    // Do nothing
 }
 
 ErrorCode CPUDeconvolutionOrigin::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
     CPUDeconvolutionBasic::onResize(inputs, outputs);
     auto core = static_cast<CPUBackend*>(backend())->functions();
-    auto gcore = static_cast<CPUBackend*>(backend())->int8Functions();
     int bytes = core->bytes;
     auto input  = inputs[0];
     auto output = outputs[0];
     auto oc     = output->channel();
-    int UNIT, SRC_UNIT, DST_XUNIT;
-    gcore->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
     if (UP_DIV(oc, core->pack) * core->pack != inputs[2]->length(0)) {
         return INPUT_DATA_ERROR;
     }
+    int eP, lP, hP;
+    core->MNNGetMatMulPackMode(&eP, &lP, &hP);
 
     auto ocC4       = UP_DIV(output->channel(), core->pack);
     auto icC4       = UP_DIV(input->channel(), core->pack);
@@ -339,136 +218,132 @@ ErrorCode CPUDeconvolutionOrigin::onResize(const std::vector<Tensor*>& inputs, c
     auto src_height = output->height();
     auto src_width  = output->width();
     auto batch      = output->batch();
+    auto weightTensor = inputs[1];
+    auto biasTensor = inputs[2];
 
     auto kernelCount = ocC4 * mCommon->kernelX() * mCommon->kernelY();
-    mPostFunctions.clear();
-    auto plane         = width * height * batch;
-    const int maxDepth = 5;
+    auto plane = width * height * batch;
     auto allocator = static_cast<CPUBackend*>(backend())->getBufferAllocator();
-    //int zeroPoint = 0;
-
-    auto biasTensor = inputs[2];
-
-    // prepare for float2int8 if necessary.
-    auto outputQuant = TensorUtils::getQuantInfo(outputs[0]);
-    float scale = outputQuant[0];
-    scale = (scale == 0.f ? 0.f : 1.f / scale);
-    auto maxValue = outputQuant[3];
-    auto minValue = outputQuant[2];
-    auto zeroPoint = outputQuant[1];
-
-    AutoRelease<Tensor> tempInput(Tensor::createDevice<float>({icC4, plane, core->pack}));
-    bool needReleaseTempInput = true;
-    int outi8 = 0;
-    if (CPUBackend::getDataType(output) == DataType_DT_INT8 || output->getType().bytes() == 1) {
-        outi8 = 1;
+    auto threadNumber = static_cast<CPUBackend*>(backend())->threadNumber();
+    auto tileCount = UP_DIV(plane, eP);
+    threadNumber = ALIMIN(tileCount, threadNumber);
+    auto im2colOutputStride = input->channel() * eP * core->bytes;
+    mGemmInput = allocator->alloc(threadNumber * im2colOutputStride);
+    auto gemmOutputStride = kernelCount * core->pack * eP * core->bytes;
+    mGemmOutput = allocator->alloc(threadNumber * gemmOutputStride);
+    auto outputSize = batch*src_width*src_height*ocC4*core->pack*core->bytes;
+    if (threadNumber > 1) {
+        mExtraOutput = allocator->alloc((threadNumber-1)*outputSize);
     }
-    if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1) {
-        mTempOutput.reset(Tensor::createDevice<float>({batch, height, width, ocC4 * kw * kh * core->pack}));
-        auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC);
-        if (!res) {
-            return OUT_OF_MEMORY;
-        }
-        mDeconvInt8Exe->onResize({input}, {mTempOutput.get()});
-        if (mResource->mRelu) {
-            minValue = outputQuant[1];
-        }
+    allocator->free(mGemmInput);
+    allocator->free(mGemmOutput);
+    if (threadNumber > 1) {
+        allocator->free(mExtraOutput);
     }
-    else {
-        mTempOutput.reset(Tensor::createDevice<float>({kernelCount, plane, core->pack}));
-        auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC);
-        if (!res) {
-            return OUT_OF_MEMORY;
-        }
-        mMatMul.reset(new StrassenMatrixComputor(backend(), true, maxDepth));
-        // tempInput->buffer().host = (uint8_t*)inputPtr;
-
-        needReleaseTempInput = false;
-        TensorUtils::getDescribeOrigin(tempInput.get())->mem = new CPUMemObj(nullptr, TensorUtils::getDescribeOrigin(input)->mem->chunk(), 0);
-        mMatMul->onEncode({tempInput.get(), inputs[1]}, {mTempOutput.get()});
-    }
-    auto threadNumber = ((CPUBackend*)backend())->threadNumber();
-    std::vector<float> scales(core->pack * src_height * src_width * batch, scale);
-    MemChunk outputFp32Ptr;
-    if (outi8) {
-        outputFp32Ptr = allocator->alloc(batch * src_height * src_width * ocC4 * core->pack * bytes);
-        if (outputFp32Ptr.invalid()) {
-            return OUT_OF_MEMORY;
-        }
-    }
-
-    mPostFunctions.emplace_back(std::make_pair([ocC4, width, height, kh, kw, padY, padX, dilateY, dilateX, strideY,
-                       strideX, threadNumber, src_width, src_height, plane, input, biasTensor, this, core, gcore, batch, outi8, scale,
-                       minValue, maxValue, zeroPoint, outputFp32Ptr](uint8_t* outputPtr, int tId) {
-        auto colBufferPtr = mTempOutput->host<uint8_t>();
-        auto biasPtr      = biasTensor->host<float>();
-        auto inputPtr  = input->host<float>();
+    auto first = std::make_pair([=](uint8_t* outputPtr, int tId) {
+        auto gemmInputBufferPtr = mGemmInput.ptr() + tId * im2colOutputStride;
+        auto colBufferPtr = mGemmOutput.ptr() + tId * gemmOutputStride;
+        auto inputPtr  = input->host<uint8_t>();
         auto unitBytes = core->pack * core->bytes;
         auto tempOutPtr = outputPtr;
-        auto float2Int8_step = src_height * src_width * batch;
-        if (outi8) {
-            tempOutPtr = outputFp32Ptr.ptr();
+        if (tId > 0) {
+            tempOutPtr = mExtraOutput.ptr() + (tId-1) * outputSize;
         }
-        for (int z = (tId); z < ocC4; z += threadNumber) {
-            auto dstZ = tempOutPtr + z * src_height * src_width * batch * unitBytes;
-            auto srcZ = colBufferPtr + kw * kh * plane * z * unitBytes;
-            ::memset(dstZ, 0, src_width * src_height * batch * unitBytes);
-            for (int b = 0; b < batch; ++b) {
-                auto dstB = dstZ + b * src_width  * src_height * unitBytes;
-                auto srcB = srcZ + b * width * height * unitBytes;
-                for (int oy = 0; oy < height; ++oy) {
-                    for (int ox = 0; ox < width; ++ox) {
-                        int srcStartX = ox * strideX - padX;
-                        int srcStartY = oy * strideY - padY;
-
-                        int sfy = ALIMAX(0, (UP_DIV(-srcStartY, dilateY)));
-                        int efy = ALIMIN(kh, UP_DIV(src_height - srcStartY, dilateY));
-
-                        int sfx = ALIMAX(0, (UP_DIV(-srcStartX, dilateX)));
-                        int efx = ALIMIN(kw, UP_DIV(src_width - srcStartX, dilateX));
-
-                        auto dstStart = dstB + srcStartX * unitBytes + srcStartY * src_width * unitBytes;
-                        auto srcStart = srcB + unitBytes * (ox + oy * width);
-                        if (sfy >= efy || sfx >= efx) {
-                            continue;
-                        }
-
-                        for (int fy = sfy; fy < efy; ++fy) {
-                            auto dstY = dstStart + fy * unitBytes * dilateY * src_width;
-                            auto srcY = srcStart + fy * kw * plane * unitBytes;
-                            core->MNNAddC4WithStride((const float*)(srcY + sfx * plane * unitBytes), (float*)(dstY + sfx * dilateX * unitBytes), plane * core->pack, dilateX * core->pack, efx - sfx);
-                        }
+        ::memset(tempOutPtr, 0, outputSize);
+
+        int l = mSrcCount;
+        int h = kernelCount * core->pack;
+        auto weightPtr = weightTensor->host<uint8_t>();
+        for (int index=tId; index < tileCount; index+=threadNumber) {
+            int xStart = index * eP;
+            int xEnd = ALIMIN(xStart + eP, plane);
+            int xCount = xEnd-xStart;
+            if (xCount <= 0) {
+                continue;
+            }
+            size_t parameters[7];
+            parameters[0] = xCount * core->bytes;
+            parameters[1] = l;
+            parameters[2] = h;
+            parameters[3] = xCount * core->bytes * core->pack;
+            parameters[4] = 0;
+            parameters[5] = 0;
+            parameters[6] = 0;
+            const float* postParametersPtr = nullptr;
+            int32_t info[4];
+            int32_t stride[4];
+            stride[0] = xCount;
+            stride[1] = (int32_t)parameters[1];
+            stride[2] = 0;
+            stride[3] = 0;
+            info[0] = 1;
+            info[1] = plane;
+            info[2] = xCount;
+            info[3] = 1;
+            auto aStart = inputPtr + xStart * unitBytes;
+            core->MNNPackC4ForMatMul_A((float*)(gemmInputBufferPtr), (const float**)(&aStart), info, stride);
+            if (xCount == eP) {
+                core->MNNPackedMatMul((float*)(colBufferPtr), (float*)gemmInputBufferPtr, (float*)weightPtr, parameters, postParametersPtr, nullptr, nullptr, nullptr);
+            } else {
+                core->MNNPackedMatMulRemain((float*)(colBufferPtr), (float*)gemmInputBufferPtr, (float*)weightPtr, xCount, parameters, postParametersPtr, nullptr, nullptr, nullptr);
+            }
+            // Col2Im
+            for (int z = 0; z < ocC4; ++z) {
+                auto dstZ = tempOutPtr + z * src_height * src_width * batch * unitBytes;
+                auto srcZ = colBufferPtr + kw * kh * xCount * z * unitBytes;
+                for (int x=0; x<xCount; ++x) {
+                    auto index = xStart + x;
+                    int b = index / (width * height);
+                    index = index % (width * height);
+                    int oy = index / width;
+                    int ox = index % width;
+                    int srcStartX = ox * strideX - padX;
+                    int srcStartY = oy * strideY - padY;
+                    
+                    int sfy = ALIMAX(0, (UP_DIV(-srcStartY, dilateY)));
+                    int efy = ALIMIN(kh, UP_DIV(src_height - srcStartY, dilateY));
+                    
+                    int sfx = ALIMAX(0, (UP_DIV(-srcStartX, dilateX)));
+                    int efx = ALIMIN(kw, UP_DIV(src_width - srcStartX, dilateX));
+                    
+                    auto dstStart = dstZ + b * src_width * src_height * unitBytes + srcStartX * unitBytes + srcStartY * src_width * unitBytes;
+                    auto srcStart = srcZ + x * unitBytes;
+                    if (sfy >= efy || sfx >= efx) {
+                        continue;
+                    }
+                    
+                    for (int fy = sfy; fy < efy; ++fy) {
+                        auto dstY = dstStart + fy * unitBytes * dilateY * src_width;
+                        auto srcY = srcStart + fy * kw * xCount * unitBytes;
+                        core->MNNAddC4WithStride((const float*)(srcY + sfx * xCount * unitBytes), (float*)(dstY + sfx * dilateX * unitBytes), xCount * core->pack, dilateX * core->pack, efx - sfx);
                     }
                 }
             }
-            core->MNNAxByClampBroadcastUnit((float*)dstZ, (float*)dstZ, (const float*)((uint8_t*)biasPtr +  unitBytes * z), src_height * src_width * batch, 0, 0, 1, mPostParameters.data());
-            if (outi8) {
-                float scaleOne = scale;
-                float zeroOne  = zeroPoint;
-                gcore->MNNFloat2Int8((float*)dstZ, (int8_t*)(outputPtr + z * float2Int8_step * core->pack), float2Int8_step, &scaleOne, minValue, maxValue, &zeroOne, 0);
+        }
+    }, threadNumber);
+    auto second = std::make_pair([ocC4, src_height, src_width, threadNumber, batch, biasTensor, this, outputSize, core](uint8_t* outputPtr, int tId) {
+        auto unitBytes = core->pack * core->bytes;
+        auto biasPtr = biasTensor->host<uint8_t>();
+        for (int z = tId; z < ocC4; z+=threadNumber) {
+            auto dstZ = outputPtr + z * src_height * src_width * batch * unitBytes;
+            if (threadNumber > 1) {
+                for (int index=0; index<threadNumber-1; ++index) {
+                    auto src = mExtraOutput.ptr() + index * outputSize + z * src_height * src_width * batch * unitBytes;
+                    core->MNNMatrixAdd((float*)(dstZ), (float*)(src), (float*)(dstZ), src_height * src_width * batch, 0, 0, 0, 1);
+                }
             }
+            core->MNNAxByClampBroadcastUnit((float*)dstZ, (float*)dstZ, (const float*)((uint8_t*)biasPtr +  unitBytes * z), src_height * src_width * batch, 0, 0, 1, mPostParameters.data());
         }
-    }, threadNumber));
-    if (outi8) {
-        allocator->free(outputFp32Ptr);
-    }
-    if (needReleaseTempInput) {
-        backend()->onReleaseBuffer(tempInput.get(), Backend::DYNAMIC);
-    }
-    backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC);
+
+    }, threadNumber);
+    mExecuteFunction = {first, second};
     return NO_ERROR;
 }
 
 ErrorCode CPUDeconvolutionOrigin::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
     auto inputPtr = inputs[0]->host<uint8_t>();
     auto outputPtr = outputs[0]->host<uint8_t>();
-    if (mDeconvInt8Exe.get() != nullptr) {
-        mDeconvInt8Exe->onExecute({inputs[0], inputs[1]}, {mTempOutput.get()});
-    }
-    else {
-        mMatMul->onExecute();
-    }
-    for (auto& unit : mPostFunctions) {
+    for (auto& unit : mExecuteFuntion) {
         MNN_CONCURRENCY_BEGIN(tId, unit.second) {
             unit.first(outputPtr, (int)tId);
         }
@@ -482,15 +357,6 @@ public:
                                 const MNN::Op* op, Backend* backend) const {
         auto convOp = op->main_as_Convolution2D();
         auto common = convOp->common();
-        if (backend->type() == MNN_FORWARD_CPU && inputs.size() == 1) {
-            if (common->strideY() > 1 || common->strideX() > 1) {
-                if (common->dilateX() == 1 && common->dilateY() == 1) {
-                    if (common->kernelX() / common->strideX() > 2 || common->kernelY() / common->strideY() > 2) {
-                        return new DeconvolutionWithStride(inputs[0], op, backend);
-                    }
-                }
-            }
-        }
         return new CPUDeconvolution(inputs[0], op, backend, inputs.size() > 1);
     }
 };
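The rewritten float path tiles the input plane into `eP`-wide columns, runs a packed GEMM against the transformed weights, then scatter-adds each tile into the output (Col2Im). A single-threaded NumPy sketch of the same math, with the pack/eP layout details omitted:

```python
import numpy as np

def deconv2d_gemm_col2im(x, w, stride=1, pad=0, dilate=1):
    """x: (ic, ih, iw) input; w: (ic, oc, kh, kw) deconvolution weights."""
    ic, ih, iw = x.shape
    _, oc, kh, kw = w.shape
    oh = (ih - 1) * stride - 2 * pad + (kh - 1) * dilate + 1
    ow = (iw - 1) * stride - 2 * pad + (kw - 1) * dilate + 1
    # GEMM: every input position yields oc*kh*kw partial sums.
    cols = w.reshape(ic, -1).T @ x.reshape(ic, -1)      # (oc*kh*kw, ih*iw)
    out = np.zeros((oc, oh, ow), dtype=x.dtype)
    # Col2Im: scatter-add partial sums, mirroring srcStartX/Y in onResize.
    for idx in range(ih * iw):
        oy0, ox0 = idx // iw * stride - pad, idx % iw * stride - pad
        for c in range(oc):
            for fy in range(kh):
                for fx in range(kw):
                    oy, ox = oy0 + fy * dilate, ox0 + fx * dilate
                    if 0 <= oy < oh and 0 <= ox < ow:
                        out[c, oy, ox] += cols[(c * kh + fy) * kw + fx, idx]
    return out
```

Because neighboring tiles can scatter into the same output pixels, each extra thread accumulates into its own `mExtraOutput` slice; the second stage sums those slices per channel (`MNNMatrixAdd`) before applying bias and clamping.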

+ 5 - 6
source/backend/cpu/CPUDeconvolution.hpp

@@ -12,7 +12,6 @@
 #include "CPUConvolution.hpp"
 #include "compute/CommonOptFunction.h"
 #include "compute/StrassenMatmulComputor.hpp"
-#include "compute/GemmInt8Executor.hpp"
 #include "core/TensorUtils.hpp"
 namespace MNN {
 class CPUDeconvolutionBasic : public CPUConvolution {
@@ -44,11 +43,11 @@ public:
     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
 
 private:
-    std::shared_ptr<StrassenMatrixComputor> mMatMul;
-    std::shared_ptr<GemmInt8Executor> mDeconvInt8Exe;
-    std::vector<std::pair<std::function<void(uint8_t*, int)>, int>> mPostFunctions;
-    std::shared_ptr<Tensor> mTempOutput;
-    std::shared_ptr<CPUConvolution::ResourceInt8> mResource;
+    MemChunk mGemmOutput;
+    MemChunk mGemmInput;
+    MemChunk mExtraOutput;
+
+    std::vector<std::pair<std::function<void(uint8_t*, int)>, int>> mExecuteFuntion;
 };
 
 class CPUDeconvolution : public CPUDeconvolutionCommon {

+ 7 - 4
source/backend/cpu/CPUInstanceNorm.cpp

@@ -6,9 +6,10 @@
 //  Copyright © 2018, Alibaba Group Holding Limited
 //
 
+#include "backend/cpu/CPUBackend.hpp"
+#ifdef MNN_SUPPORT_DEPRECATED_OP
 #include "backend/cpu/CPUInstanceNorm.hpp"
 #include <math.h>
-#include "backend/cpu/CPUBackend.hpp"
 #include "core/Concurrency.h"
 #include <MNN/MNNDefine.h>
 #include "core/Macro.h"
@@ -106,7 +107,9 @@ public:
         return new CPUInstanceNorm(backend, op);
     }
 };
-
-REGISTER_CPU_OP_CREATOR(CPUInstanceNormCreator, OpType_InstanceNorm);
-
 } // namespace MNN
+#endif
+namespace MNN {
+REGISTER_CPU_OP_CREATOR_OLD(CPUInstanceNormCreator, OpType_InstanceNorm);
+};
+

+ 7 - 4
source/backend/cpu/CPUMoments.cpp

@@ -6,9 +6,10 @@
 //  Copyright © 2018, Alibaba Group Holding Limited
 //
 
+#include "backend/cpu/CPUBackend.hpp"
+#ifdef MNN_SUPPORT_DEPRECATED_OP
 #include "backend/cpu/CPUMoments.hpp"
 #include <math.h>
-#include "backend/cpu/CPUBackend.hpp"
 #include "core/Concurrency.h"
 #include <MNN/MNNDefine.h>
 #include "core/Macro.h"
@@ -129,7 +130,9 @@ public:
         return new CPUMoments(backend, op);
     }
 };
-
-REGISTER_CPU_OP_CREATOR(CPUMomentsCreator, OpType_Moments);
-
 } // namespace MNN
+#endif
+namespace MNN {
+REGISTER_CPU_OP_CREATOR_OLD(CPUMomentsCreator, OpType_Moments);
+};
+

+ 6 - 0
source/backend/cpu/CPUOPRegister.cpp

@@ -78,6 +78,9 @@ extern void ___CPUTextureCreator__OpType_Texture__();
 #ifdef MNN_SUPPORT_TRANSFORMER_FUSE
 extern void ___CPUAttentionCreator__OpType_Attention__();
 #endif
+#ifdef MNN_BUILD_AUDIO
+extern void ___CPUStftCreator__OpType_Stft__();
+#endif
 void registerCPUOps() {
 ___CPUCropAndResizeCreator__OpType_CropAndResize__();
 ___CPUArgMaxCreator__OpType_ArgMax__();
@@ -156,5 +159,8 @@ ___CPUTextureCreator__OpType_Texture__();
 #ifdef MNN_SUPPORT_TRANSFORMER_FUSE
 ___CPUAttentionCreator__OpType_Attention__();
 #endif
+#ifdef MNN_BUILD_AUDIO
+___CPUStftCreator__OpType_Stft__();
+#endif
 }
 }

+ 72 - 28
source/backend/cpu/CPURelu.cpp

@@ -46,16 +46,53 @@ ErrorCode CPURelu::onExecute(const std::vector<Tensor*>& inputs, const std::vect
     auto& ob = outputs[0]->buffer();
 
     if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1) {
+        auto core = static_cast<CPUBackend*>(backend())->int8Functions();
+        auto gcore = static_cast<CPUBackend*>(backend())->functions();
         const int8_t* srcO = (const int8_t*)ib.host;
+        int8_t* dstO       = (int8_t*)ob.host;
         auto inInfo = TensorUtils::getQuantInfo(inputs[0]);
         auto outInfo = TensorUtils::getQuantInfo(outputs[0]);
-        if (inInfo != outInfo) {
-            MNN_PRINT("this relu int8 implementation has error when input output quant info mismatch\n");
-        }
-        int8_t zeroPoint = int8_t(outInfo[1]);
-        int8_t* dstO       = (int8_t*)ob.host;
         auto size         = mRealSize;
         auto numberThread = ((CPUBackend*)backend())->threadNumber();
+
+        auto inputscale = inInfo[0];
+        auto inputzero = (ssize_t)inInfo[1];
+        auto outputzero = (ssize_t)outInfo[1];
+        auto outputscale = outInfo[0] > 0.f ? 1.0f / outInfo[0] : 0.f;
+        QuanPrePostParameters params;
+        params.maxValue = static_cast<ssize_t>(inInfo[3]);
+        params.minValue = static_cast<ssize_t>(inInfo[2]);
+        params.inputScale = &inputscale;
+        params.inputZeroPoint = &inputzero;
+        params.outputScale = &outputscale;
+        params.outputZeroPoint = &outputzero;
+        
+        if (((float*)mSlope.get())[0] != 0.f) {
+            // PRelu Int8
+            int sizeQuad     = size / gcore->pack;
+            int remain       = size % gcore->pack;
+            int sizeDivide = UP_DIV(sizeQuad, numberThread);
+            
+            if (sizeQuad > 0) {
+                MNN_CONCURRENCY_BEGIN(tId, numberThread) {
+                    
+                    int number = sizeDivide;
+                    if (tId == numberThread - 1) {
+                        number = sizeQuad - tId * sizeDivide;
+                    }
+                    core->MNNReluWithSlopeChannelInt8((int8_t*)(dstO + tId * gcore->pack * sizeDivide), srcO + tId * sizeDivide * gcore->pack, (const float*)(mSlope.get()), number, 1, &params, gcore->pack);
+                                    
+                }
+                MNN_CONCURRENCY_END();
+            }
+            if (remain > 0) {
+                ::memcpy(mCacheSrc.get(), srcO + sizeQuad * gcore->pack, remain);
+                core->MNNReluWithSlopeChannelInt8((int8_t*)mCacheDst.get(), (const int8_t*)(mCacheSrc.get()), (const float*)mSlope.get(), 1, 1, &params, gcore->pack);
+                ::memcpy(dstO + sizeQuad * gcore->pack, mCacheDst.get(), remain);
+            }
+            return NO_ERROR;
+        }
+        int8_t zeroPoint = int8_t(outInfo[1]);
         int sizeQuad     = size / 16;
         int remain       = sizeQuad * 16;
         int sizeDivide = sizeQuad / numberThread;
@@ -187,10 +224,6 @@ ErrorCode CPUPRelu::onResize(const std::vector<Tensor*>& inputs, const std::vect
         mQuanScalesOutput = {outputScale};
         mQuanZerosInput = {inputZero};
         mQuanZerosOutput = {outputZero};
-        auto p = mSlope.host<float>();
-        for (int i = 0; i < mSlope.buffer().dim[0].extent; ++i) {
-            p[i] = p[i] * inputScale * outputScale;
-        }
     }
     return NO_ERROR;
 }
@@ -198,42 +231,53 @@ ErrorCode CPUPRelu::onResize(const std::vector<Tensor*>& inputs, const std::vect
 ErrorCode CPUPRelu::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
     auto& ib            = inputs[0]->buffer();
     auto& ob            = outputs[0]->buffer();
-    int sizeQuad = 1;
-    for (int i=2; i<ib.dimensions; ++i) {
-        sizeQuad *= ib.dim[i].extent;
-    }
     auto core = static_cast<CPUBackend*>(backend())->functions();
     auto coreInt8 = static_cast<CPUBackend*>(backend())->int8Functions();
     const int channel   = ib.dim[1].extent;
     const int batch     = ib.dim[0].extent;
-    int pack = 4;
-    int depthQuad = UP_DIV(channel, core->pack);
-    const uint8_t* srcO   = (const uint8_t*)ib.host;
+    int pack = core->pack;
+    
+    const int8_t* srcO   = (const int8_t*)ib.host;
     uint8_t* dstO         = (uint8_t*)ob.host;
+    auto depthQuad = UP_DIV(channel, core->pack);
     auto totalCount = batch * depthQuad;
     auto numberThread = ((CPUBackend*)backend())->threadNumber();
+    auto sizeQuad = UP_DIV(depthQuad, numberThread);
+    auto sizeCount = sizeQuad * batch * inputs[0]->width() * inputs[0]->height() * core->pack;
+    
     if (mUseInt8) {
-        depthQuad = UP_DIV(channel, pack);
+        auto inputInfo = TensorUtils::getDescribe(inputs[0])->quantAttr;
+        auto outputInfo = TensorUtils::getDescribe(outputs[0])->quantAttr;
+        auto inzero = (ssize_t)inputInfo->zero;
+        auto outzero = (ssize_t)outputInfo->zero;
+        auto outscale = outputInfo->scale > 0 ? 1.f / outputInfo->scale : 0.f;
+        QuanPrePostParameters params;
+        params.maxValue = static_cast<ssize_t>(outputInfo->max);
+        params.minValue = static_cast<ssize_t>(outputInfo->min);
+        params.inputScale = &inputInfo->scale;
+        params.inputZeroPoint = &inzero;
+        params.outputScale = &outscale;
+        params.outputZeroPoint = &outzero;
         MNN_CONCURRENCY_BEGIN(tId, numberThread) {
-            QuanPrePostParameters params;
-            params.maxValue = static_cast<ssize_t>(TensorUtils::getDescribe(inputs[0])->quantAttr->max);
-            params.minValue = static_cast<ssize_t>(TensorUtils::getDescribe(inputs[0])->quantAttr->min);
-            params.inputScale = mQuanScalesInput.data();
-            params.inputZeroPoint = mQuanZerosInput.data();
-            params.outputScale = mQuanScalesOutput.data();
-            params.outputZeroPoint = mQuanZerosOutput.data();
-            for (int b=tId; b<totalCount; b+=numberThread) {
-                auto c = b / batch;
-                coreInt8->MNNReluWithSlopeChannelInt8((int8_t*)(dstO + sizeQuad * pack * b), (const int8_t*)(srcO + sizeQuad * pack * b), (const float*)(mSlope.host<uint8_t>() + core->bytes * pack * c), sizeQuad, 1, &params);
+            
+            
+            auto number = ALIMIN(sizeQuad, depthQuad - tId * sizeQuad);
+            if (number > 0) {
+                auto sizeQ = number * batch * inputs[0]->width() * inputs[0]->height();
+                coreInt8->MNNReluWithSlopeChannelInt8((int8_t*)(dstO + tId * sizeCount), srcO + tId * sizeCount, (const float*)(mSlope.host<uint8_t>() + tId * sizeQuad * pack * core->bytes), sizeQ / number, number, &params, core->pack);
             }
         }
         MNN_CONCURRENCY_END();
         return NO_ERROR;
     }
+    int hw = 1;
+    for (int i=2; i<ib.dimensions; ++i) {
+        hw *= ib.dim[i].extent;
+    }
     MNN_CONCURRENCY_BEGIN(tId, numberThread) {
         for (int b=tId; b<totalCount; b+=numberThread) {
             auto c = b / batch;
-            core->MNNReluWithSlopeChannel((float*)(dstO + sizeQuad * core->bytes * core->pack * b), (const float*)(srcO + sizeQuad * core->pack * core->bytes * b), (const float*)(mSlope.host<uint8_t>() + core->bytes * core->pack * c), sizeQuad, 1);
+            core->MNNReluWithSlopeChannel((float*)(dstO + hw * core->bytes * core->pack * b), (const float*)(srcO + hw * core->pack * core->bytes * b), (const float*)(mSlope.host<uint8_t>() + core->bytes * core->pack * c), hw, 1);
         }
     }
     MNN_CONCURRENCY_END();
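For reference, the quantized PReLU math this path now implements end to end: dequantize with the input scale/zero point, scale negative values by the per-channel slope, requantize with the reciprocal output scale, and round half away from zero. A NumPy sketch (zero-point placement on the output side is simplified relative to the assembly):

```python
import numpy as np

def prelu_int8(src, slope, in_scale, in_zero, out_scale_inv, out_zero,
               min_val=-128, max_val=127):
    """src: int8 values; slope: per-channel float, broadcastable to src."""
    x = (src.astype(np.float32) - in_zero) * in_scale
    x = np.where(x < 0.0, x * slope, x)
    y = x * out_scale_inv + out_zero
    y = np.trunc(y + np.where(y > 0.0, 0.5, -0.5))   # round half away from zero
    return np.clip(y, min_val, max_val).astype(np.int8)
```

Note that `out_scale_inv` is the reciprocal (`1.0f / outInfo[0]`) exactly as computed in `onExecute` above, so a zero output scale degenerates to multiplying by 0 rather than dividing.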

+ 75 - 0
source/backend/cpu/CPUStft.cpp

@@ -0,0 +1,75 @@
+//
+//  CPUStft.cpp
+//  MNN
+//
+//  Created by MNN on 2024/11/26.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef MNN_BUILD_AUDIO
+
+#include "backend/cpu/CPUStft.hpp"
+#include "backend/cpu/CPUBackend.hpp"
+#include "core/Concurrency.h"
+#include "core/TensorUtils.hpp"
+#include "core/Macro.h"
+#include "compute/CommonOptFunction.h"
+
+namespace MNN {
+
+CPUStft::CPUStft(Backend* backend, int nfft, int hop_length, bool abs)
+    : Execution(backend), mNfft(nfft), mHopLength(hop_length), mAbs(abs) {
+    // nothing to do
+}
+
+ErrorCode CPUStft::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
+    auto cpuBn = static_cast<CPUBackend*>(backend());
+    mTmpFrames.buffer().dim[0].extent = cpuBn->threadNumber();
+    mTmpFrames.buffer().dim[1].extent = mNfft;
+    TensorUtils::getDescribe(&mTmpFrames)->dimensionFormat = MNN_DATA_FORMAT_NHWC;
+    mTmpFrames.buffer().dimensions    = 2;
+    mTmpFrames.buffer().type          = inputs[0]->getType();
+    backend()->onAcquireBuffer(&mTmpFrames, Backend::DYNAMIC);
+    backend()->onReleaseBuffer(&mTmpFrames, Backend::DYNAMIC);
+    return NO_ERROR;
+}
+
+ErrorCode CPUStft::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
+    const float* sample = inputs[0]->host<float>();
+    const float* window = inputs[1]->host<float>();
+    float* buffer = mTmpFrames.host<float>();
+    float* output = outputs[0]->host<float>();
+    auto outputShape = outputs[0]->shape();
+    int frames = outputShape[0];
+    int col = outputShape[1];
+    auto cpuBn = static_cast<CPUBackend*>(backend());
+    int threadNum = cpuBn->threadNumber();
+    // divide frames across threads
+    int threadNumber = std::min(threadNum, frames);
+    int sizeDivide = frames / threadNumber;
+    MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
+        int number = sizeDivide;
+        if (tId == threadNumber - 1) {
+            number = frames - tId * sizeDivide;
+        }
+        for (int i = tId * sizeDivide; i < tId * sizeDivide + number; ++i) {
+            MNNDftAbs(sample + i * mHopLength, window, output + i * col, buffer + tId * mNfft, mNfft);
+        }
+    };
+    MNN_CONCURRENCY_END();
+
+    return NO_ERROR;
+}
+
+class CPUStftCreator : public CPUBackend::Creator {
+public:
+    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
+                                const MNN::Op* op, Backend* backend) const {
+        auto stft = op->main_as_StftParam();
+        return new CPUStft(backend, stft->n_fft(), stft->hop_length(), stft->abs());
+    }
+};
+
+REGISTER_CPU_OP_CREATOR_AUDIO(CPUStftCreator, OpType_Stft);
+} // namespace MNN
+#endif // MNN_BUILD_AUDIO
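Each worker walks its share of frames and calls `MNNDftAbs` on a hop-spaced, windowed slice, using a per-thread `mNfft`-sized scratch row from `mTmpFrames`. Assuming `MNNDftAbs` produces the magnitude of a one-sided DFT (the `abs` flag and the output column count are taken on faith here), a NumPy reference per frame:

```python
import numpy as np

def stft_abs(sample, window, n_fft, hop_length, frames, cols):
    out = np.empty((frames, cols), dtype=np.float32)
    for i in range(frames):
        frame = sample[i * hop_length : i * hop_length + n_fft] * window
        out[i] = np.abs(np.fft.rfft(frame))[:cols]   # assumed one-sided magnitude
    return out
```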

+ 31 - 0
source/backend/cpu/CPUStft.hpp

@@ -0,0 +1,31 @@
+//
+//  CPUStft.hpp
+//  MNN
+//
+//  Created by MNN on 2024/11/26.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef MNN_BUILD_AUDIO
+#ifndef CPUStft_hpp
+#define CPUStft_hpp
+
+#include "core/Execution.hpp"
+
+namespace MNN {
+class CPUStft : public Execution {
+public:
+    CPUStft(Backend *backend, int nfft, int hop_length, bool abs);
+    virtual ~CPUStft() = default;
+    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+private:
+    int mNfft, mHopLength;
+    bool mAbs;
+    Tensor mTmpFrames;
+};
+
+} // namespace MNN
+
+#endif /* CPUStft.hpp */
+#endif // MNN_BUILD_AUDIO

+ 47 - 15
source/backend/cpu/arm/arm32/MNNReluWithSlopeChannelInt8.S

@@ -48,9 +48,9 @@ beq PReluEnd
 cmp r3, #0
 beq PReluEnd
 
-vmov.f32 q12, #0.5
-vmov.f32 q13, #-0.5
 .macro ROUND_TWO x0, x1
+    vmov.f32 q12, #0.5
+    vmov.f32 q13, #-0.5
     vcgt.f32 q10, \x0, #0
     vcgt.f32 q11, \x1, #0
     vbsl.f32 q10, q12, q13
@@ -62,6 +62,8 @@ vmov.f32 q13, #-0.5
 .endm
 
 .macro ROUND_ONE x0
+    vmov.f32 q12, #0.5
+    vmov.f32 q13, #-0.5
     vcgt.f32 q10, \x0, #0
     vbsl.f32 q10, q12, q13
     vadd.f32 \x0, q10, \x0
@@ -69,11 +71,13 @@ vmov.f32 q13, #-0.5
 .endm
 
 vld1.8 d30[0], [r8]
-vdup.8 d30, d30[0]  // inputZeroPoint
-
 vld1.8 d31[0], [r6]
+vdup.8 d30, d30[0]  // inputZeroPoint
 vdup.8 d31, d31[0]  // outputZeroPoint
 
+ldr r6, [r5, #0]    // inputScale
+ldr r8, [r5, #4]    // outputScale
+
 PReluZLoop:
 vld1.32 {q14}, [r2]!
 
@@ -93,17 +97,38 @@ vmovl.s16 q4, d3
 vmovl.s16 q5, d4
 vmovl.s16 q6, d5
 
-vclt.s8 q1, q0, #0
-
 vcvt.f32.s32 q3, q3
 vcvt.f32.s32 q4, q4
 vcvt.f32.s32 q5, q5
 vcvt.f32.s32 q6, q6
-
-vmul.f32 q3, q3, q14
-vmul.f32 q4, q4, q14
-vmul.f32 q5, q5, q14
-vmul.f32 q6, q6, q14
+// load input/output scales
+vld1.f32 {d14[0]}, [r6] // input scale
+vld1.f32 {d14[1]}, [r8] // output scale
+vmul.f32 q3, q3, d14[0]
+vmul.f32 q4, q4, d14[0]
+vmul.f32 q5, q5, d14[0]
+vmul.f32 q6, q6, d14[0]
+
+vclt.f32 q0, q3, #0
+vclt.f32 q1, q4, #0
+vclt.f32 q2, q5, #0
+vclt.f32 q12, q6, #0
+
+// *slope
+vmul.f32 q8, q3, q14
+vmul.f32 q9, q4, q14
+vmul.f32 q10, q5, q14
+vmul.f32 q11, q6, q14
+
+vbit.32 q3, q8, q0
+vbit.32 q4, q9, q1
+vbit.32 q5, q10, q2
+vbit.32 q6, q11, q12
+
+vmul.f32 q3, q3, d14[1]
+vmul.f32 q4, q4, d14[1]
+vmul.f32 q5, q5, d14[1]
+vmul.f32 q6, q6, d14[1]
 
 ROUND_TWO q3, q4
 ROUND_TWO q5, q6
@@ -122,8 +147,7 @@ vqmovn.s16 d19, q8
 vmax.s8 q9, q9, q10
 vmin.s8 q9, q9, q11
 
-vbit.8 q0, q9, q1
-vst1.8 {q0}, [r0]!
+vst1.8 {q9}, [r0]!
 
 sub r5, r5, #4
 cmp r5, #4
@@ -139,10 +163,18 @@ vmovl.s8 q1, d0
 vsubw.s8 q1, q1, d30
 
 vmovl.s16 q2, d2
-vclt.s8 d10, d0, #0
 
 vcvt.f32.s32 q2, q2
-vmul.f32 q2, q2, q14
+// load input/output scales
+vld1.f32 {d14[0]}, [r6] // input scale
+vld1.f32 {d14[1]}, [r8] // output scale
+vmul.f32 q2, q2, d14[0]
+vclt.f32 q4, q2, #0     // index
+// *slope
+vmul.f32 q3, q2, q14
+vbit q2, q3, q4
+// *output_scale
+vmul.f32 q2, q2, d14[1]
 
 ROUND_ONE q2
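
Note the ROUND_TWO/ROUND_ONE change above: the ±0.5 constants must now be re-materialized inside the macros, because the rewritten loop body reuses q12 as a comparison mask (vclt.f32 q12, q6, #0). The macro rounds half away from zero by adding ±0.5 and then truncating toward zero on the float-to-int convert; a scalar C++ equivalent of that rounding:

    #include <cstdio>

    // Mirrors the vcgt/vbsl/vadd sequence followed by a truncating vcvt:
    // add +0.5 for positive inputs, -0.5 otherwise, then truncate toward zero.
    static int roundHalfAwayFromZero(float x) {
        return static_cast<int>(x > 0.f ? x + 0.5f : x - 0.5f);
    }

    int main() {
        printf("%d %d %d %d\n",
               roundHalfAwayFromZero(2.5f),    //  3
               roundHalfAwayFromZero(-2.5f),   // -3
               roundHalfAwayFromZero(1.4f),    //  1
               roundHalfAwayFromZero(-1.4f));  // -1
        return 0;
    }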
 

+ 0 - 225
source/backend/cpu/arm/arm32/MNNWinogradMatrixProductLeft.S

@@ -1,225 +0,0 @@
-//
-//  MNNWinogradMatrixProductLeft.S
-//  MNN
-//
-//  Created by MNN on 2018/08/22.
-//  Copyright © 2018, Alibaba Group Holding Limited
-//
-
-#ifdef __arm__
-#ifndef __aarch64__
-
-#include "MNNAsmGlobal.h"
-
-.text
-.align 5
-
-asm_function MNNWinogradMatrixProductLeft
-//void MNNWinogradMatrixProductLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length);
-
-//Auto: r0: S, r1:B, r2: M, r3:w
-//Load From sp: r4:h, r5:k, r6:length
-
-push {r4-r8, r10, r11, lr} // avoid to touch platform-register r-9
-ldr r4, [sp, #32]
-ldr r5, [sp, #36]
-ldr r6, [sp, #40]
-
-//unitStepInFloat
-mov r8, #16 // 4*sizeof(float)
-mul r8, r6, r8
-
-//srcYUnitStep
-mul lr, r3, r8
-sub lr, lr, r8
-add r7, lr, r8
-
-//B's step
-mov r10, #4
-mul r10, r4, r10
-
-LoopY:
-    push {r0, r3}
-    LoopX:
-        push {r0, r1}
-        vmov.i32 q14, #0
-        mov r11, r6
-        LoopUnitSetZero:
-            vst1.32 {q14}, [r2]!
-            subs r11, r11, #1
-            bne LoopUnitSetZero
-        sub r2, r2, r8
-        mov r12, r5
-
-        LK7:
-        cmp r12, #7
-        blt LK4
-        push {r3-r7}
-        LoopK7:
-            vld1.32 {d0[0]}, [r1], r10
-            vld1.32 {d0[1]}, [r1], r10
-            vld1.32 {d1[0]}, [r1], r10
-            vld1.32 {d1[1]}, [r1], r10
-            vld1.32 {d2[0]}, [r1], r10
-            vld1.32 {d2[1]}, [r1], r10
-            vld1.32 {d3[0]}, [r1], r10
-            mov r11, r6
-            vmov.32 d30[0], r1
-
-            add r1, r0, r7
-            add r3, r1, r7
-            add r4, r3, r7
-            add r5, r4, r7
-            add r6, r5, r7
-            add r7, r6, r7
-
-            LoopUnitK7:
-                vld1.32 {q8}, [r2]
-                vld1.32 {q12}, [r0]!
-                vmla.f32 q8, q12, d0[0]
-                vld1.32 {q13}, [r1]!
-                vmul.f32 q9, q13, d0[1]
-                vld1.32 {q12}, [r3]!
-                vmla.f32 q8, q12, d1[0]
-                vld1.32 {q13}, [r4]!
-                vmla.f32 q9, q13, d1[1]
-                vld1.32 {q12}, [r5]!
-                vmla.f32 q8, q12, d2[0]
-                vld1.32 {q13}, [r6]!
-                vmla.f32 q9, q13, d2[1]
-                vld1.32 {q12}, [r7]!
-                vmla.f32 q8, q12, d3[0]
-
-                vadd.f32 q9, q8, q9
-                vst1.32 {q9}, [r2]!
-                subs r11, r11, #1
-                bne LoopUnitK7
-            sub r2, r2, r8
-            sub r12, r12, #7
-            add r0, r7, lr
-            vmov.32 r1, d30[0]
-            cmp r12, #7
-            bge LoopK7
-        pop {r3-r7}
-
-        LK4:
-        cmp r12, #4
-        blt LK3
-        vmov.32 d30[1], r3
-        vmov.32 d31[0], r4
-        LoopK4:
-            vld1.32 {d0[0]}, [r1], r10
-            vld1.32 {d0[1]}, [r1], r10
-            vld1.32 {d1[0]}, [r1], r10
-            vld1.32 {d1[1]}, [r1], r10
-            mov r11, r6
-            vmov.32 d30[0], r1
-
-            add r1, r0, r7
-            add r3, r1, r7
-            add r4, r3, r7
-
-            LoopUnitK4:
-                vld1.32 {q8}, [r2]
-                vld1.32 {q12}, [r0]!
-                vmla.f32 q8, q12, d0[0]
-                vld1.32 {q13}, [r1]!
-                vmul.f32 q9, q13, d0[1]
-                vld1.32 {q12}, [r3]!
-                vmla.f32 q8, q12, d1[0]
-                vld1.32 {q13}, [r4]!
-                vmla.f32 q9, q13, d1[1]
-
-                vadd.f32 q9, q8, q9
-                vst1.32 {q9}, [r2]!
-                subs r11, r11, #1
-                bne LoopUnitK4
-            sub r2, r2, r8
-            sub r12, r12, #4
-            add r0, r4, lr
-            vmov.32 r1, d30[0]
-            cmp r12, #4
-            bge LoopK4
-        vmov.32 r3, d30[1]
-        vmov.32 r4, d31[0]
-
-        LK3:
-        cmp r12, #3
-        blt LK1
-        vmov.32 d30[1], r3
-        vmov.32 d31[0], r4
-        LoopK3:
-            vld1.32 {d0[0]}, [r1], r10
-            vld1.32 {d0[1]}, [r1], r10
-            vld1.32 {d1[0]}, [r1], r10
-            mov r11, r6
-            vmov.32 d30[0], r1
-
-            add r1, r0, r7
-            add r3, r1, r7
-
-            LoopUnitK3:
-                vld1.32 {q8}, [r2]
-                vld1.32 {q12}, [r0]!
-                vmla.f32 q8, q12, d0[0]
-                vld1.32 {q13}, [r1]!
-                vmul.f32 q9, q13, d0[1]
-                vld1.32 {q12}, [r3]!
-                vmla.f32 q8, q12, d1[0]
-
-                vadd.f32 q9, q8, q9
-                vst1.32 {q9}, [r2]!
-                subs r11, r11, #1
-                bne LoopUnitK3
-            sub r2, r2, r8
-            sub r12, r12, #3
-            add r0, r3, lr
-            vmov.32 r1, d30[0]
-            cmp r12, #3
-            bge LoopK3
-        vmov.32 r3, d30[1]
-        vmov.32 r4, d31[0]
-
-
-
-        LK1:
-        cmp r12, #0
-        beq LKEnd
-
-        LoopK:
-            vld1.32 {d30[0]}, [r1], r10
-
-            vdup.32 q15, d30[0]
-            mov r11, r6
-            LoopUnit:
-                vld1.32 {q0}, [r2]
-                vld1.32 {q1}, [r0]!
-                vmla.f32 q0, q1, q15
-
-                vst1.32 {q0}, [r2]!
-                subs r11, r11, #1
-                bne LoopUnit
-            subs r12, r12, #1
-
-            sub r2, r2, r8
-            add r0, r0, lr
-            bne LoopK
-        LKEnd:
-        pop {r0, r1}
-        subs r3, r3, #1
-        add r0, r0, r8
-        add r2, r2, r8
-
-        bne LoopX
-    pop {r0, r3}
-    add r1, r1, #4 //sizeof(float)
-
-    subs r4, r4, #1
-    bne LoopY
-
-
-
-pop {r4-r8, r10, r11, pc}
-
-#endif
-#endif

+ 0 - 223
source/backend/cpu/arm/arm32/MNNWinogradMatrixProductRight.S

@@ -1,223 +0,0 @@
-//
-//  MNNWinogradMatrixProductRight.S
-//  MNN
-//
-//  Created by MNN on 2018/08/22.
-//  Copyright © 2018, Alibaba Group Holding Limited
-//
-
-#ifdef __arm__
-#ifndef __aarch64__
-
-#include "MNNAsmGlobal.h"
-
-.text
-.align 5
-
-asm_function MNNWinogradMatrixProductRight
-//void MNNWinogradMatrixProductRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length);
-
-//Auto: r0: S, r1:B, r2: M, r3:w
-//Load From sp: r4:h, r5:k, r6:length
-
-push {r4-r8, r10, r11, lr} // avoid to touch platform-register r-9
-ldr r4, [sp, #32]
-ldr r5, [sp, #36]
-ldr r6, [sp, #40]
-
-//unitStepInFloat
-mov r8, #16 // 4*sizeof(float)
-mul r8, r6, r8
-
-//srcYUnitStep
-mul lr, r5, r8
-
-//B's step
-mov r10, #4
-mul r10, r4, r10
-
-LoopY:
-    push {r1, r3}
-    LoopX:
-        push {r0, r1}
-        vmov.i32 q14, #0
-        mov r11, r6
-        LoopUnitSetZero:
-            vst1.32 {q14}, [r2]!
-            subs r11, r11, #1
-            bne LoopUnitSetZero
-        sub r2, r2, r8
-        mov r12, r5
-
-        LK7:
-        cmp r12, #7
-        blt LK4
-        push {r3-r7}
-        LoopK7:
-            vld1.32 {d0[0]}, [r1], r10
-            vld1.32 {d0[1]}, [r1], r10
-            vld1.32 {d1[0]}, [r1], r10
-            vld1.32 {d1[1]}, [r1], r10
-            vld1.32 {d2[0]}, [r1], r10
-            vld1.32 {d2[1]}, [r1], r10
-            vld1.32 {d3[0]}, [r1], r10
-            mov r11, r6
-            vmov.32 d30[0], r1
-
-            add r1, r0, r8
-            add r3, r1, r8
-            add r4, r3, r8
-            add r5, r4, r8
-            add r6, r5, r8
-            add r7, r6, r8
-
-            LoopUnitK7:
-                vld1.32 {q8}, [r2]
-                vld1.32 {q12}, [r0]!
-                vmla.f32 q8, q12, d0[0]
-                vld1.32 {q13}, [r1]!
-                vmul.f32 q9, q13, d0[1]
-                vld1.32 {q12}, [r3]!
-                vmla.f32 q8, q12, d1[0]
-                vld1.32 {q13}, [r4]!
-                vmla.f32 q9, q13, d1[1]
-                vld1.32 {q12}, [r5]!
-                vmla.f32 q8, q12, d2[0]
-                vld1.32 {q13}, [r6]!
-                vmla.f32 q9, q13, d2[1]
-                vld1.32 {q12}, [r7]!
-                vmla.f32 q8, q12, d3[0]
-
-                vadd.f32 q9, q8, q9
-                vst1.32 {q9}, [r2]!
-                subs r11, r11, #1
-                bne LoopUnitK7
-            sub r2, r2, r8
-            sub r12, r12, #7
-            mov r0, r7
-            vmov.32 r1, d30[0]
-            cmp r12, #7
-            bge LoopK7
-        pop {r3-r7}
-
-        LK4:
-        cmp r12, #4
-        blt LK3
-        vmov.32 d30[1], r3
-        vmov.32 d31[0], r4
-        LoopK4:
-            vld1.32 {d0[0]}, [r1], r10
-            vld1.32 {d0[1]}, [r1], r10
-            vld1.32 {d1[0]}, [r1], r10
-            vld1.32 {d1[1]}, [r1], r10
-            mov r11, r6
-            vmov.32 d30[0], r1
-
-            add r1, r0, r8
-            add r3, r1, r8
-            add r4, r3, r8
-
-            LoopUnitK4:
-                vld1.32 {q8}, [r2]
-                vld1.32 {q12}, [r0]!
-                vmla.f32 q8, q12, d0[0]
-                vld1.32 {q13}, [r1]!
-                vmul.f32 q9, q13, d0[1]
-                vld1.32 {q12}, [r3]!
-                vmla.f32 q8, q12, d1[0]
-                vld1.32 {q13}, [r4]!
-                vmla.f32 q9, q13, d1[1]
-
-                vadd.f32 q9, q8, q9
-                vst1.32 {q9}, [r2]!
-                subs r11, r11, #1
-                bne LoopUnitK4
-            sub r2, r2, r8
-
-            sub r12, r12, #4
-
-            mov r0, r4
-            vmov.32 r1, d30[0]
-            cmp r12, #4
-            bge LoopK4
-        vmov.32 r3, d30[1]
-        vmov.32 r4, d31[0]
-
-        LK3:
-        cmp r12, #3
-        blt LK1
-        vmov.32 d30[1], r3
-        LoopK3:
-            vld1.32 {d0[0]}, [r1], r10
-            vld1.32 {d0[1]}, [r1], r10
-            vld1.32 {d1[0]}, [r1], r10
-            mov r11, r6
-            vmov.32 d30[0], r1
-
-            add r1, r0, r8
-            add r3, r1, r8
-
-            LoopUnitK3:
-                vld1.32 {q8}, [r2]
-                vld1.32 {q12}, [r0]!
-                vmla.f32 q8, q12, d0[0]
-                vld1.32 {q13}, [r1]!
-                vmul.f32 q9, q13, d0[1]
-                vld1.32 {q12}, [r3]!
-                vmla.f32 q8, q12, d1[0]
-
-                vadd.f32 q9, q8, q9
-                vst1.32 {q9}, [r2]!
-                subs r11, r11, #1
-                bne LoopUnitK3
-            sub r2, r2, r8
-
-            sub r12, r12, #3
-
-            mov r0, r3
-            vmov.32 r1, d30[0]
-            cmp r12, #3
-            bge LoopK3
-        vmov.32 r3, d30[1]
-
-
-        LK1:
-        cmp r12, #0
-        beq LKEnd
-
-        LoopK:
-            vld1.32 {d30[0]}, [r1], r10
-
-            vdup.32 q15, d30[0]
-            mov r11, r6
-            LoopUnit:
-                vld1.32 {q0}, [r2]
-                vld1.32 {q1}, [r0]!
-                vmla.f32 q0, q1, q15
-
-                vst1.32 {q0}, [r2]!
-                subs r11, r11, #1
-                bne LoopUnit
-            subs r12, r12, #1
-
-            sub r2, r2, r8
-            bne LoopK
-        LKEnd:
-        pop {r0, r1}
-        subs r3, r3, #1
-        add r2, r2, r8
-        add r1, r1, #4 //sizeof(float)
-
-        bne LoopX
-    pop {r1, r3}
-    add r0, r0, lr
-
-    subs r4, r4, #1
-    bne LoopY
-
-
-
-pop {r4-r8, r10, r11, pc}
-
-#endif
-#endif

+ 46 - 15
source/backend/cpu/arm/arm64/MNNReluWithSlopeChannelInt8.S

@@ -25,8 +25,10 @@ asm_function MNNReluWithSlopeChannelInt8
 // MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params)
 // Auto load:
 // x0: dst, x1: src, x2: slope, x3: planeNumber, x4: depthQuad, x5: params
-// Load from x5:  x8: inputZeroPoint, x9: outputZeroPoint, x10: minValue, x11: maxValue
+// Load from x5: x12: inputScale, x13: outputScale, x8: inputZeroPoint, x9: outputZeroPoint, x10: minValue, x11: maxValue
 
+ldr x12, [x5, #0]
+ldr x13, [x5, #8]
 ldr x8, [x5, #16]
 ldr x9, [x5, #24]
 ldr x10, [x5, #32]
@@ -43,10 +45,12 @@ beq End
 cmp x4, #0
 beq End
 
-ld1r {v29.8b}, [x8] // inputZeroPoint
-ld1r {v28.8b}, [x9] // outputZeroPoint
+ld1r {v29.16b}, [x8] // inputZeroPoint
+ld1r {v28.16b}, [x9] // outputZeroPoint
 dup v26.16b, w10
 dup v27.16b, w11
+ld1r {v24.4s}, [x12] // input scale
+ld1r {v25.4s}, [x13] // output scale
 /*
 Quant parameters
 */
@@ -60,7 +64,6 @@ ble PReluL1
 
 PReluL4Loop:
 ld1 {v0.16b}, [x1], #16
-cmlt v30.16b, v0.16b, #0 // mask0: x<0
 
 sxtl v1.8h, v0.8b
 sxtl2 v2.8h, v0.16b
@@ -76,10 +79,33 @@ scvtf v4.4s, v4.4s
 scvtf v5.4s, v5.4s
 scvtf v6.4s, v6.4s
 
-fmul v3.4s, v3.4s, v31.4s
-fmul v4.4s, v4.4s, v31.4s
-fmul v5.4s, v5.4s, v31.4s
-fmul v6.4s, v6.4s, v31.4s
+// input_scale
+fmul v3.4s, v3.4s, v24.4s
+fmul v4.4s, v4.4s, v24.4s
+fmul v5.4s, v5.4s, v24.4s
+fmul v6.4s, v6.4s, v24.4s
+
+fcmle v7.4s, v3.4s, #0
+fcmle v8.4s, v4.4s, #0
+fcmle v9.4s, v5.4s, #0
+fcmle v10.4s, v6.4s, #0
+
+// *slope
+fmul v11.4s, v3.4s, v31.4s
+fmul v12.4s, v4.4s, v31.4s
+fmul v13.4s, v5.4s, v31.4s
+fmul v14.4s, v6.4s, v31.4s
+
+bit v3.16b, v11.16b, v7.16b
+bit v4.16b, v12.16b, v8.16b
+bit v5.16b, v13.16b, v9.16b
+bit v6.16b, v14.16b, v10.16b
+
+// *output_scale
+fmul v3.4s, v3.4s, v25.4s
+fmul v4.4s, v4.4s, v25.4s
+fmul v5.4s, v5.4s, v25.4s
+fmul v6.4s, v6.4s, v25.4s
 
 fcvtas v3.4s, v3.4s
 fcvtas v4.4s, v4.4s
@@ -99,8 +125,7 @@ sqxtn2 v9.16b, v8.8h
 smax v9.16b, v9.16b, v26.16b
 smin v9.16b, v9.16b, v27.16b
 
-bit v0.16b, v9.16b, v30.16b
-st1 {v0.16b}, [x0], #16
+st1 {v9.16b}, [x0], #16
 
 sub x5, x5, #4
 cmp x5, #4
@@ -113,13 +138,20 @@ beq PReluL1End
 
 PReluL1Loop:
 ld1 {v0.s}[0], [x1], #4
-cmlt v30.8b, v0.8b, #0
 
 sxtl v1.8h, v0.8b
 ssubw v1.8h, v1.8h, v29.8b
 sxtl v1.4s, v1.4h
 scvtf v1.4s, v1.4s
-fmul v1.4s, v1.4s, v31.4s
+// *input_scale
+fmul v1.4s, v1.4s, v24.4s
+fcmle v7.4s, v1.4s, #0
+// *slope
+fmul v11.4s, v1.4s, v31.4s
+bit v1.16b, v11.16b, v7.16b
+// *output_scale
+fmul v1.4s, v1.4s, v25.4s
+
 fcvtas v1.4s, v1.4s
 sqxtn v1.4h, v1.4s
 saddw v1.8h, v1.8h, v28.8b
@@ -127,8 +159,7 @@ sqxtn v1.8b, v1.8h
 smax v1.8b, v1.8b, v26.8b
 smin v1.8b, v1.8b, v27.8b
 
-bit v0.8b, v1.8b, v30.8b
-st1 {v0.s}[0], [x0], #4
+st1 {v1.s}[0], [x0], #4
 subs x5, x5, #1
 bne PReluL1Loop
 
@@ -144,4 +175,4 @@ End:
     ldp d14, d15, [sp], #64
     ret
 
-#endif
+#endif

+ 0 - 171
source/backend/cpu/arm/arm64/MNNWinogradMatrixProductLeft.S

@@ -1,171 +0,0 @@
-//
-//  MNNWinogradMatrixProductLeft.S
-//  MNN
-//
-//  Created by MNN on 2018/08/22.
-//  Copyright © 2018, Alibaba Group Holding Limited
-//
-
-#ifdef __aarch64__
-
-#include "MNNAsmGlobal.h"
-
-.text
-.align 5
-
-asm_function MNNWinogradMatrixProductLeft
-//void MNNWinogradMatrixProductLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length);
-
-//Auto: x0: S, x1:B, x2: M, x3:w, x4:h, x5:k, x6:length
-
-//unitStepInFloat
-mov x8, #16 // 4*sizeof(float)
-mul x8, x6, x8
-
-//srcYUnitStep
-mul x9, x3, x8
-sub x9, x9, x8
-add x7, x9, x8
-
-//B's step
-mov x10, #4
-mul x10, x4, x10
-
-LoopY:
-    mov v4.d[0], x0
-    mov v4.d[1], x3
-    LoopX:
-        mov v5.d[0], x0
-        mov v5.d[1], x1
-        movi v30.4s, #0
-        mov x11, x6
-        LoopUnitSetZero:
-            st1 {v30.4s}, [x2], #16
-            subs x11, x11, #1
-            bne LoopUnitSetZero
-        sub x2, x2, x8
-        mov x12, x5
-
-        LK4:
-        cmp x12, #4
-        blt LK3
-        mov v6.d[0], x3
-        mov v6.d[1], x4
-        LoopK4:
-            ld1 {v0.s}[0], [x1], x10
-            ld1 {v0.s}[1], [x1], x10
-            ld1 {v0.s}[2], [x1], x10
-            ld1 {v0.s}[3], [x1], x10
-            mov x11, x6
-            mov v7.d[0], x1
-
-            add x1, x0, x7
-            add x3, x1, x7
-            add x4, x3, x7
-
-            LoopUnitK4:
-                ld1 {v16.4s}, [x2]
-                ld1 {v20.4s}, [x0], #16
-                fmla v16.4s, v20.4s, v0.s[0]
-                ld1 {v21.4s}, [x1], #16
-                fmul v17.4s, v21.4s, v0.s[1]
-                ld1 {v20.4s}, [x3], #16
-                fmla v16.4s, v20.4s, v0.s[2]
-                ld1 {v21.4s}, [x4], #16
-                fmla v17.4s, v21.4s, v0.s[3]
-
-                fadd v17.4s, v16.4s, v17.4s
-                st1 {v17.4s}, [x2], #16
-                subs x11, x11, #1
-                bne LoopUnitK4
-            sub x2, x2, x8
-
-            sub x12, x12, #4
-
-            add x0, x4, x9
-            mov x1, v7.d[0]
-            cmp x12, #4
-            bge LoopK4
-        mov x3, v6.d[0]
-        mov x4, v6.d[1]
-
-        LK3:
-        cmp x12, #3
-        blt LK1
-        mov v6.d[0], x3
-        LoopK3:
-            ld1 {v0.s}[0], [x1], x10
-            ld1 {v0.s}[1], [x1], x10
-            ld1 {v0.s}[2], [x1], x10
-            mov x11, x6
-            mov v7.d[0], x1
-
-            add x1, x0, x7
-            add x3, x1, x7
-
-            LoopUnitK3:
-                ld1 {v16.4s}, [x2]
-                ld1 {v20.4s}, [x0], #16
-                fmla v16.4s, v20.4s, v0.s[0]
-                ld1 {v21.4s}, [x1], #16
-                fmul v17.4s, v21.4s, v0.s[1]
-                ld1 {v20.4s}, [x3], #16
-                fmla v16.4s, v20.4s, v0.s[2]
-
-                fadd v17.4s, v16.4s, v17.4s
-                st1 {v17.4s}, [x2], #16
-                subs x11, x11, #1
-                bne LoopUnitK3
-            sub x2, x2, x8
-
-            sub x12, x12, #3
-
-            add x0, x3, x9
-            mov x1, v7.d[0]
-            cmp x12, #3
-            bge LoopK3
-        mov x3, v6.d[0]
-
-
-        LK1:
-        cmp x12, #0
-        beq LKEnd
-
-        LoopK:
-            ld1 {v31.s}[0], [x1], x10
-
-            dup v31.4s, v31.s[0]
-            mov x11, x6
-            LoopUnit:
-                ld1 {v0.4s}, [x2]
-                ld1 {v1.4s}, [x0], #16
-                fmla v0.4s, v1.4s, v31.4s
-
-                st1 {v0.4s}, [x2], #16
-                subs x11, x11, #1
-                bne LoopUnit
-            subs x12, x12, #1
-
-            sub x2, x2, x8
-            add x0, x0, x9
-            bne LoopK
-        LKEnd:
-        mov x0, v5.d[0]
-        mov x1, v5.d[1]
-        subs x3, x3, #1
-        add x0, x0, x8
-        add x2, x2, x8
-
-        bne LoopX
-    mov x0, v4.d[0]
-    mov x3, v4.d[1]
-    add x1, x1, #4 //sizeof(float)
-
-    subs x4, x4, #1
-    bne LoopY
-
-
-
-    ret
-
-#endif

+ 0 - 164
source/backend/cpu/arm/arm64/MNNWinogradMatrixProductRight.S

@@ -1,164 +0,0 @@
-//
-//  MNNWinogradMatrixProductRight.S
-//  MNN
-//
-//  Created by MNN on 2018/08/22.
-//  Copyright © 2018, Alibaba Group Holding Limited
-//
-
-#ifdef __aarch64__
-
-#include "MNNAsmGlobal.h"
-
-.text
-.align 5
-
-asm_function MNNWinogradMatrixProductRight
-//void MNNWinogradMatrixProductRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length);
-
-//Auto: x0: S, x1:B, x2: M, x3:w, x4:h, x5:k, x6:length
-
-//unitStepInFloat
-mov x8, #16 // 4*sizeof(float)
-mul x8, x6, x8
-
-//srcYUnitStep
-mul x9, x5, x8
-
-//B's step
-mov x10, #4
-mul x10, x4, x10
-
-LoopY:
-    mov v4.d[0], x1
-    mov v4.d[1], x3
-    LoopX:
-        mov v5.d[0], x0
-        mov v5.d[1], x1
-        movi v30.4s, #0
-        mov x11, x6
-        LoopUnitSetZero:
-            st1 {v30.4s}, [x2], #16
-            subs x11, x11, #1
-            bne LoopUnitSetZero
-        sub x2, x2, x8
-        mov x12, x5
-
-        LK4:
-        cmp x12, #4
-        blt LK3
-        mov v6.d[0], x3
-        mov v6.d[1], x4
-        LoopK4:
-            ld1 {v0.s}[0], [x1], x10
-            ld1 {v0.s}[1], [x1], x10
-            ld1 {v0.s}[2], [x1], x10
-            ld1 {v0.s}[3], [x1], x10
-            mov x11, x6
-            mov v7.d[0], x1
-
-            add x1, x0, x8
-            add x3, x1, x8
-            add x4, x3, x8
-
-            LoopUnitK4:
-                ld1 {v16.4s}, [x2]
-                ld1 {v20.4s}, [x0], #16
-                fmla v16.4s, v20.4s, v0.s[0]
-                ld1 {v21.4s}, [x1], #16
-                fmul v17.4s, v21.4s, v0.s[1]
-                ld1 {v20.4s}, [x3], #16
-                fmla v16.4s, v20.4s, v0.s[2]
-                ld1 {v21.4s}, [x4], #16
-                fmla v17.4s, v21.4s, v0.s[3]
-
-                fadd v17.4s, v16.4s, v17.4s
-                st1 {v17.4s}, [x2], #16
-                subs x11, x11, #1
-                bne LoopUnitK4
-            sub x2, x2, x8
-            sub x12, x12, #4
-            mov x0, x4
-
-            mov x1, v7.d[0]
-            cmp x12, #4
-            bge LoopK4
-        mov x3, v6.d[0]
-        mov x4, v6.d[1]
-
-        LK3:
-        cmp x12, #3
-        blt LK1
-        mov v6.d[0], x3
-        LoopK3:
-            ld1 {v0.s}[0], [x1], x10
-            ld1 {v0.s}[1], [x1], x10
-            ld1 {v0.s}[2], [x1], x10
-            mov x11, x6
-            mov v7.d[0], x1
-
-            add x1, x0, x8
-            add x3, x1, x8
-
-            LoopUnitK3:
-                ld1 {v16.4s}, [x2]
-                ld1 {v20.4s}, [x0], #16
-                fmla v16.4s, v20.4s, v0.s[0]
-                ld1 {v21.4s}, [x1], #16
-                fmul v17.4s, v21.4s, v0.s[1]
-                ld1 {v20.4s}, [x3], #16
-                fmla v16.4s, v20.4s, v0.s[2]
-
-                fadd v17.4s, v16.4s, v17.4s
-                st1 {v17.4s}, [x2], #16
-                subs x11, x11, #1
-                bne LoopUnitK3
-            sub x2, x2, x8
-            sub x12, x12, #3
-            mov x0, x4
-            mov x1, v7.d[0]
-            cmp x12, #3
-            bge LoopK3
-        mov x3, v6.d[0]
-
-        LK1:
-        cmp x12, #0
-        beq LKEnd
-
-        LoopK:
-            ld1 {v31.s}[0], [x1], x10
-
-            dup v31.4s, v31.s[0]
-            mov x11, x6
-            LoopUnit:
-                ld1 {v0.4s}, [x2]
-                ld1 {v1.4s}, [x0], #16
-                fmla v0.4s, v1.4s, v31.4s
-
-                st1 {v0.4s}, [x2], #16
-                subs x11, x11, #1
-                bne LoopUnit
-            subs x12, x12, #1
-
-            sub x2, x2, x8
-            bne LoopK
-        LKEnd:
-        mov x0, v5.d[0]
-        mov x1, v5.d[1]
-        subs x3, x3, #1
-        add x2, x2, x8
-        add x1, x1, #4 //sizeof(float)
-
-        bne LoopX
-    mov x1, v4.d[0]
-    mov x3, v4.d[1]
-    add x0, x0, x9
-
-    subs x4, x4, #1
-    bne LoopY
-
-
-
-    ret
-
-#endif

+ 20 - 2
source/backend/cpu/compute/CommonOptFunction.cpp

@@ -23,6 +23,9 @@
 #include "../CPUBinary.hpp"
 #include "../CPUUnary.hpp"
 #include "../CPUPool.hpp"
+#ifndef M_PI
+#define M_PI 3.141592654
+#endif
 #define PACK 4
 #define FLOAT float
 using Vec = MNN::Math::Vec<float, 4>;
@@ -314,7 +317,7 @@ static void MNNSumByAxisLForMatmul_A(float* dest, int8_t* source, const float* s
         dest += (step * blockNum);
         realDstCount -= step;
         srcInt8 += col_buffer_unit_size;
-    } while(realDstCount > 0); 
+    } while(realDstCount > 0);
 }
 
 template<typename T>
@@ -3099,6 +3102,21 @@ void MNNSiLuLowp(float* dst, const float* src, size_t dataSize) {
 #endif
 }
 
+void MNNDftAbs(const float* input, const float* window, float* output, float* buffer, int nfft) {
+    for (int i = 0; i < nfft; ++i) {
+        buffer[i] = input[i] * window[i];
+    }
+    for (int k = 0; k < nfft / 2 + 1; ++k) {
+        float real_sum = 0.f, imag_sum = 0.f;
+        for (int n = 0; n < nfft; ++n) {
+            float angle = 2 * M_PI * k * n / nfft;
+            real_sum += buffer[n] * std::cos(angle);
+            imag_sum -= buffer[n] * std::sin(angle);
+        }
+        output[k] = std::sqrt(real_sum * real_sum + imag_sum * imag_sum);
+    }
+}
+
 static void _MNNAdjustOptimalSparseKernel(int& sparseBlockOC, MNN::CoreFunctions::MNNPackedSparseMatMul& packedSparseMatMul) {
     if(sparseBlockOC == 4) {
         packedSparseMatMul = MNNPackedSparseMatMulEpx4;
@@ -3202,7 +3220,7 @@ void MNNCoreFunctionInit() {
     gCoreFunction->MNNFp16ToFp8 = MNNFp16ToFp8;
     gCoreFunction->MNNFp8ToFp32 = MNNFp8ToFp32;
     gCoreFunction->MNNFp8ToFp16 = MNNFp8ToFp16;
-    
+
     // MatMul
     gCoreFunction->MNNGetMatMulPackMode = MNNGetMatMulPackMode;
     gCoreFunction->MNNPackC4ForMatMul_A = MNNPackC4ForMatMul_A;
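
MNNDftAbs above is a direct O(nfft²) DFT over the windowed frame. A self-contained sketch of the same math with a quick sanity check (a one-cycle cosine under a rectangular window concentrates its energy, magnitude nfft/2, in bin 1):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Same math as the MNNDftAbs added above: window the frame, then take
    // DFT magnitudes for the nfft/2 + 1 non-negative frequency bins.
    static void dftAbs(const float* input, const float* window, float* output,
                       float* buffer, int nfft) {
        for (int i = 0; i < nfft; ++i) buffer[i] = input[i] * window[i];
        for (int k = 0; k < nfft / 2 + 1; ++k) {
            float re = 0.f, im = 0.f;
            for (int n = 0; n < nfft; ++n) {
                float angle = 2.f * 3.14159265f * k * n / nfft;
                re += buffer[n] * std::cos(angle);
                im -= buffer[n] * std::sin(angle);
            }
            output[k] = std::sqrt(re * re + im * im);
        }
    }

    int main() {
        const int nfft = 8;
        std::vector<float> in(nfft), win(nfft, 1.f), buf(nfft), out(nfft / 2 + 1);
        for (int n = 0; n < nfft; ++n) in[n] = std::cos(2.f * 3.14159265f * n / nfft);
        dftAbs(in.data(), win.data(), out.data(), buf.data(), nfft);
        for (int k = 0; k <= nfft / 2; ++k) printf("bin %d: %.3f\n", k, out[k]);
        // Expect ~4.0 in bin 1 (nfft/2) and ~0 elsewhere.
        return 0;
    }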

+ 2 - 1
source/backend/cpu/compute/CommonOptFunction.h

@@ -101,6 +101,7 @@ void MNNGeluCommon(float* dst, const float* src, size_t size);
 void MNNGeluStandardCommon(float* dst, const float* src, size_t size);
 void MNNSoftmax(float* dest, const float* source, size_t size);
 void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm = false);
+void MNNDftAbs(const float* input, const float* window, float* output, float* buffer, int nfft);
 
 // Get Pack for MatMul's e , l , h , the pack number must be 1 or 4 * n
 void MNNGetMatMulPackMode(int* eP, int *lP, int* hP);
@@ -313,7 +314,7 @@ struct CoreFunctions {
     void(*MNNPoolingMax)(const void* channelInput, int inputWidth, int inputHeight, void *channelOutput,
                            int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, int strideWidth,
                            int strideHeight, int padWidth, int padHeight, int padType, int countType);
-    
+
     void(*MNNPoolingMaxWithRedice)(const void* channelInput, int inputWidth, int inputHeight, void *channelOutput,
                            int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, int strideWidth,
                            int strideHeight, int padWidth, int padHeight, int padType, int countType, int *RediceOutput);

+ 11 - 13
source/backend/cpu/compute/Int8FunctionsOpt.cpp

@@ -29,7 +29,7 @@ void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const
 void MNNMaxPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx);
 
 void MNNAvgPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx, ssize_t paddingx, ssize_t factor);
-void MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params);
+void MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, const QuanPrePostParameters *params, size_t pack);
 #if defined(__aarch64__) // aarch32 sdot workaround
 void MNNGemmInt8AddBiasScale_ARMV82_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,
                                         const QuanPostTreatParameters* post, size_t realDstCount);
@@ -1543,7 +1543,7 @@ static void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src,
     }
 }
 
-static void MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params) {
+static void MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, const QuanPrePostParameters *params, size_t pack) {
 #ifdef MNN_USE_SSE
 float offset = 128.f;
 uint8_t* srcPtr = (uint8_t*)src;
@@ -1554,24 +1554,22 @@ const int8_t* srcPtr = src;
 int8_t* dstPtr = dst;
 #endif
     float mulVal = 0.f;
-    float inputScale = params->inputScale[0];
-    float outputScale = params->outputScale[0];
     float inputZero = static_cast<float>(params->inputZeroPoint[0]) + offset;
     float outputZero = static_cast<float>(params->outputZeroPoint[0]) + offset;
     int32_t minval = params->minValue + offset;
     int32_t maxval = params->maxValue + offset;
     for (int j = 0;j < depthQuad; ++j) {
-        const float* slopeZ = slope + 4 * j;
-        const auto srcZ = srcPtr + 4 * j * planeNumber;
-        auto dstZ = dstPtr + 4 * j * planeNumber;
+        const float* slopeZ = slope + pack * j;
+        const auto srcZ = srcPtr + pack * j * planeNumber;
+        auto dstZ = dstPtr + pack * j * planeNumber;
         for (int i = 0; i < planeNumber; ++i) {
-            for (int c = 0; c < 4; ++c) {
-                if ((float)srcZ[4 * i + c] < inputZero) {
-                    mulVal = (srcZ[4 * i + c] - inputZero) * slopeZ[c];
-                    dstZ[4 * i + c] = ALIMIN(ALIMAX(static_cast<int32_t>(roundf(mulVal)) + outputZero, minval), maxval);
-                } else {
-                    dstZ[4 * i + c] = srcZ[4 * i + c];
+            for (int c = 0; c < pack; ++c) {
+                float valInput = (static_cast<float>(srcZ[pack * i + c]) - inputZero) * params->inputScale[0];
+                if (valInput < 0) {
+                    valInput *= slopeZ[c];
                 }
+                auto mulVal = valInput * params->outputScale[0] + outputZero;
+                dstZ[pack * i + c] = ALIMIN(ALIMAX(static_cast<int32_t>(roundf(mulVal)), minval), maxval);
             }
         }
     }
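
The rewritten reference path now matches the assembly above: dequantize with the input scale, apply the slope only where the dequantized value is negative, then requantize with the output scale (which here acts as a multiplier back into the int8 domain). A worked example with illustrative parameters, not values from the source:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical channel parameters: dequant scale 0.05, requant
        // multiplier 10 (i.e. output quant scale 0.1), zero points 0.
        float inputScale = 0.05f, outputScale = 10.f, slope = 0.25f;
        float inputZero = 0.f, outputZero = 0.f;
        int minval = -128, maxval = 127;

        int8_t src = -40;                                             // dequantizes to -2.0
        float v = (static_cast<float>(src) - inputZero) * inputScale; // -2.0
        if (v < 0) {
            v *= slope;                                               // -0.5
        }
        float q = v * outputScale + outputZero;                       // -5.0
        int8_t dst = (int8_t)std::min(std::max((int)std::round(q), minval), maxval);
        printf("dst = %d\n", dst);                                    // -5
        return 0;
    }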

+ 1 - 1
source/backend/cpu/compute/Int8FunctionsOpt.h

@@ -113,7 +113,7 @@ struct CoreInt8Functions {
     void (*MNNAvgPoolInt8)(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx, ssize_t paddingx, ssize_t factor);
     
     // Relu
-    void (*MNNReluWithSlopeChannelInt8)(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params);
+    void (*MNNReluWithSlopeChannelInt8)(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, const QuanPrePostParameters *params, size_t pack);
 };
 void MNNCoreInt8FunctionInit();
 CoreInt8Functions* MNNGetInt8CoreFunctions();

+ 0 - 67
source/backend/cpu/compute/WinogradOptFunction.cpp

@@ -16,77 +16,10 @@
 
 using Vec4 = MNN::Math::Vec<float, 4>;
 #define DEFAULT_UNIT 8
-extern "C" {
-void MNNWinogradMatrixProductLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k,
-                                  size_t length);
-void MNNWinogradMatrixProductRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k,
-                                   size_t length);
-}
-
-#ifndef MNN_USE_NEON
-
-// M = BT * S , M = w*h * l, S = w*k * l, B = h*k
-void MNNWinogradMatrixProductLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k,
-                                  size_t length) {
-    auto unitStep = 4 * length;
-    for (int y = 0; y < h; ++y) {
-        auto dstY = M + y * w * unitStep;
-        for (int x = 0; x < w; ++x) {
-            auto dstX = dstY + x * unitStep;
-            auto srcX = S + x * unitStep;
-            ::memset(dstX, 0, unitStep * sizeof(float));
-            for (int i = 0; i < k; ++i) {
-                auto b    = B[i * h + y];
-                auto srcY = srcX + i * w * unitStep;
-                if (0.0f == b) {
-                    continue;
-                }
-                for (int j = 0; j < unitStep; ++j) {
-                    dstX[j] += srcY[j] * b;
-                }
-            }
-        }
-    }
-}
 
-// M = S * B , M = w*h * l, S = k*h * l, B = w*k
-void MNNWinogradMatrixProductRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k,
-                                   size_t length) {
-    auto unitStep = 4 * length;
-    for (int y = 0; y < h; ++y) {
-        auto dstY = M + y * w * unitStep;
-        auto srcY = S + y * k * unitStep;
-
-        for (int x = 0; x < w; ++x) {
-            auto dstX = dstY + x * unitStep;
-            ::memset(dstX, 0, unitStep * sizeof(float));
-            for (int i = 0; i < k; ++i) {
-                auto srcX = srcY + i * unitStep;
-                auto b    = B[i * h + x];
-                if (0.0f == b) {
-                    continue;
-                }
-                for (int j = 0; j < unitStep; ++j) {
-                    dstX[j] += srcX[j] * b;
-                }
-            }
-        }
-    }
-}
-#endif
 
 namespace MNN {
 
-
-void WinogradFunction::productLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k,
-                                   size_t length) {
-    MNNWinogradMatrixProductLeft(S, B, M, w, h, k, length);
-}
-
-void WinogradFunction::productRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k,
-                                    size_t length) {
-    MNNWinogradMatrixProductRight(S, B, M, w, h, k, length);
-}
 int WinogradFunction::getPreferNumber() {
     return DEFAULT_UNIT;
 }

+ 0 - 3
source/backend/cpu/compute/WinogradOptFunction.hpp

@@ -15,9 +15,6 @@
 namespace MNN {
 class WinogradFunction {
 public:
-    static void productLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length);
-    static void productRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length);
-
     static int getPreferNumber();
 
     typedef void (*TransformFunc)(const float* srcBlock, float* dstStart, size_t srcStep, size_t dstStep);

+ 0 - 1
source/backend/cpu/x86_x64/FunctionDispatcher.cpp

@@ -132,7 +132,6 @@ void MNNInt8FunctionInit() {
     auto core = MNN::MNNGetInt8CoreFunctions();
     core->MNNAvgPoolInt8 = MNNAvgPoolUint8;
     core->MNNMaxPoolInt8 = MNNMaxPoolInt8_;
-    core->MNNReluWithSlopeChannelInt8 = _SSE_MNNReluWithSlopeChannelInt8;
     if (cpuFlags & libyuv::kCpuHasSSE41) {
         core->MNNFloat2Int8 = _SSE_MNNFloat2Int8;
         core->MNNInt8ScaleToFloat = _SSE_MNNInt8ScaleToFloat;

+ 1 - 1
source/backend/cpu/x86_x64/sse/FunctionSummary.hpp

@@ -36,7 +36,7 @@ void _SSE_MNNAddC4WithStride(const float* source, float* dest, size_t srcStride,
 void _SSE_MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad);
 
 void _SSE_MNNGelu(float* dst, const float* src, size_t size, float* parameters);
-void _SSE_MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params);
+void _SSE_MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, const QuanPrePostParameters *params, size_t pack);
 
 void _SSE_MNNHardSwish(float* dst, const float* src, size_t size);
 

+ 0 - 43
source/backend/cpu/x86_x64/sse/MathFunctions.cpp

@@ -290,46 +290,3 @@ void _SSE_MNNNorm(float *dst, const float *src, const float *gamma, const float
     }
 }
 
-void _SSE_MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params) {
-    uint8_t* dstO = (uint8_t*)dst;
-    uint8_t* srcO = (uint8_t*)src;
-    auto outputZero = _mm_set1_ps(static_cast<float>(params->outputZeroPoint[0]));
-    __m128 maxValue = _mm_set1_ps(params->maxValue);
-    __m128 minValue = _mm_set1_ps(params->minValue);
-    auto offset = _mm_set1_epi32(128);
-    auto zero = _mm_set1_epi32(0);
-    __m128 plus = _mm_set1_ps(0.5f);
-    __m128 minus = _mm_set1_ps(-0.5f);
-    __m128i zeroPointValue = _mm_set1_epi32(static_cast<int32_t>(params->inputZeroPoint[0]) + 128);
-    for (int j = 0;j < depthQuad; ++j) {
-        auto slopeZ = _mm_loadu_ps(slope + 4 * j);
-        const uint8_t* srcZ = srcO + 4 * j * planeNumber;
-        uint8_t* dstZ = dstO + 4 * j * planeNumber;
-        int32_t srcZ_ext[4] = {*(int32_t*)srcZ, 0, 0, 0};
-        for (int i = 0; i < planeNumber; ++i) {
-            // auto srcData8 = _mm_loadu_si32(srcZ);
-            auto srcData8 = _mm_castps_si128(_mm_loadu_ps((float*)srcZ_ext));
-            auto srcData16 = _mm_unpacklo_epi8(srcData8, zero);
-            auto srcData32 = _mm_unpacklo_epi16(srcData16, zero);
-            srcData32 = _mm_sub_epi32(srcData32, zeroPointValue);
-            auto srcDataf  = _mm_cvtepi32_ps(srcData32);
-            auto mask1 = _mm_cmplt_ps(srcDataf, _mm_castsi128_ps(zero));
-            auto mask0 = _mm_cmpge_ps(srcDataf, _mm_castsi128_ps(zero));
-            auto f = _mm_mul_ps(srcDataf, slopeZ);
-            f = _mm_add_ps(f, outputZero);
-            f = _mm_min_ps(f, maxValue);
-            f = _mm_max_ps(f, minValue);
-            auto r = _mm_add_ps(_mm_and_ps(srcDataf, mask0), _mm_and_ps(f, mask1));
-            auto m0 = _mm_cmplt_ps(r, _mm_castsi128_ps(zero));
-            m0 = _mm_blendv_ps(plus, minus, m0);
-            r = _mm_add_ps(r, m0);
-            // Round to zero
-            auto d0 = _mm_cvtps_epi32(_mm_round_ps(r, 3));
-            d0 = _mm_add_epi32(d0, offset);
-            d0 = _mm_packs_epi32(d0, d0);
-            d0 = _mm_packus_epi16(d0, d0);
-            *((int*)dstZ + i) = _mm_cvtsi128_si32(d0);
-        }
-    }
-}
-

+ 32 - 34
source/backend/metal/AllShader.cpp

@@ -1428,7 +1428,6 @@ const char* shader_MetalDeconvolution_metal =
 " int output_height;\n"
 " int output_size;\n"
 " int output_slice;\n"
-" \n"
 " int kernel_x;\n"
 " int kernel_y;\n"
 " int kernel_size;\n"
@@ -1438,12 +1437,10 @@ const char* shader_MetalDeconvolution_metal =
 " int pad_y;\n"
 " int dilation_x;\n"
 " int dilation_y;\n"
-" \n"
 " int delta_ky;\n"
 " int delta_kx;\n"
 " int delta_iy;\n"
 " int delta_ix;\n"
-" int has_bias;\n"
 " int batch;\n"
 " conv_activation_type activation;\n"
 "};\n"
@@ -1494,8 +1491,8 @@ const char* shader_MetalDeconvolution_metal =
 " const device M4 *biasTerms [[buffer(4)]],\n"
 " uint3 gid [[thread_position_in_grid]]) {\n"
 " if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
-" \n"
-" FLOAT4 result=FLOAT4(biasTerms[(int)(gid.z/cst.batch)]);\n"
+" int oz=(int)gid.z/cst.batch;\n"
+" FLOAT4 result=FLOAT4(biasTerms[oz]);\n"
 " \n"
 " int oy=(int)gid.y+cst.pad_y;\n"
 " int ox=(int)gid.x+cst.pad_x;\n"
@@ -1512,7 +1509,7 @@ const char* shader_MetalDeconvolution_metal =
 " int min_iy=(oy-max_ky*cst.dilation_y)/cst.stride_y;\n"
 " int min_ix=(ox-max_kx*cst.dilation_x)/cst.stride_x;\n"
 " \n"
-" auto z_wt=wt+(int)gid.z*cst.kernel_size;\n"
+" auto z_wt=wt+oz*cst.kernel_size;\n"
 " auto z_in=in+(int)gid.z*cst.input_size;\n"
 " for (auto ky=max_ky,iy=min_iy; ky >= min_ky; ky -= cst.delta_ky,iy += cst.delta_iy) {\n"
 " for (auto kx=max_kx,ix=min_ix; kx >= min_kx; kx -= cst.delta_kx,ix += cst.delta_ix) {\n"
@@ -1670,6 +1667,7 @@ const char* shader_MetalConvolution1x1_metal =
 " int batch;\n"
 " int block_size;\n"
 " conv_activation_type activation;\n"
+" float scale_coef;\n"
 "};\n"
 "kernel void conv1x1_g1z4(const device M4 *in [[buffer(0)]],\n"
 " device M4 *out [[buffer(1)]],\n"
@@ -1711,7 +1709,7 @@ const char* shader_MetalConvolution1x1_metal =
 " constant conv1x1_constants& cst [[buffer(2)]],\n"
 " const device MNN::char4x4 *wt [[buffer(3)]],\n"
 " const device M4 *biasTerms [[buffer(4)]],\n"
-" const device float4 *dequantScale [[buffer(5)]],\n"
+" const device M4 *dequantScale [[buffer(5)]],\n"
 " uint3 gid [[thread_position_in_grid]]) {\n"
 " if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
 " int rx=gid.x*CONV_UNROLL;\n"
@@ -1724,8 +1722,8 @@ const char* shader_MetalConvolution1x1_metal =
 " int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
 " int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n"
 " for (int bi=0; bi<cst.block_size; ++bi) {\n"
-" FLOAT4 bs0=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0]);\n"
-" FLOAT4 bs1=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1]);\n"
+" FLOAT4 bs0=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0])/(FLOAT)cst.scale_coef;\n"
+" FLOAT4 bs1=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1])/(FLOAT)cst.scale_coef;\n"
 " FLOAT4 scale=bs0;\n"
 " FLOAT4 dequant_bias=bs1;\n"
 " int zmin=bi*block;\n"
@@ -1759,7 +1757,7 @@ const char* shader_MetalConvolution1x1_metal =
 " constant conv1x1_constants& cst [[buffer(2)]],\n"
 " const device uchar2 *wt [[buffer(3)]],\n"
 " const device M4 *biasTerms [[buffer(4)]],\n"
-" const device float4 *dequantScale [[buffer(5)]],\n"
+" const device M4 *dequantScale [[buffer(5)]],\n"
 " uint3 gid [[threadgroup_position_in_grid]],\n"
 " uint tiitg[[thread_index_in_threadgroup]],\n"
 " uint sgitg[[simdgroup_index_in_threadgroup]]) {\n"
@@ -1793,8 +1791,8 @@ const char* shader_MetalConvolution1x1_metal =
 " int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n"
 " for (int bi=0; bi<cst.block_size; ++bi) {\n"
 " // [N/4,cst.block_size,2/*scale_bias*/,N4]\n"
-" FLOAT4 scale=FLOAT4(dequantScale[2*(idx_n4*cst.block_size+bi)+0]);\n"
-" FLOAT4 dequant_bias=FLOAT4(dequantScale[2*(idx_n4*cst.block_size+bi)+1]);\n"
+" FLOAT4 scale=FLOAT4(dequantScale[2*(idx_n4*cst.block_size+bi)+0])/(FLOAT)cst.scale_coef;\n"
+" FLOAT4 dequant_bias=FLOAT4(dequantScale[2*(idx_n4*cst.block_size+bi)+1])/(FLOAT)cst.scale_coef;\n"
 " int zmin=bi*block;\n"
 " int zmax=min(zmin+block,cst.input_slice);\n"
 " for (int z=zmin+kl; z<zmax; z += 8) {\n"
@@ -1849,7 +1847,7 @@ const char* shader_MetalConvolution1x1_metal =
 " constant conv1x1_constants& cst [[buffer(2)]],\n"
 " const device uchar2 *wt [[buffer(3)]],\n"
 " const device M4 *biasTerms [[buffer(4)]],\n"
-" const device float4 *dequantScale [[buffer(5)]],\n"
+" const device M4 *dequantScale [[buffer(5)]],\n"
 " uint3 gid [[threadgroup_position_in_grid]],\n"
 " uint tiitg[[thread_index_in_threadgroup]],\n"
 " uint sgitg[[simdgroup_index_in_threadgroup]]) {\n"
@@ -1884,8 +1882,8 @@ const char* shader_MetalConvolution1x1_metal =
 " int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n"
 " for (int bi=0; bi<cst.block_size; ++bi) {\n"
 " // [N/4,cst.block_size,2/*scale_bias*/,N4]\n"
-" FLOAT4 scale=FLOAT4(dequantScale[2*(idx_n4*cst.block_size+bi)+0]);\n"
-" FLOAT4 dequant_bias=FLOAT4(dequantScale[2*(idx_n4*cst.block_size+bi)+1]);\n"
+" FLOAT4 scale=FLOAT4(dequantScale[2*(idx_n4*cst.block_size+bi)+0])/(FLOAT)cst.scale_coef;\n"
+" FLOAT4 dequant_bias=FLOAT4(dequantScale[2*(idx_n4*cst.block_size+bi)+1])/(FLOAT)cst.scale_coef;\n"
 " int zmin=bi*block;\n"
 " int zmax=min(zmin+block,cst.input_slice);\n"
 " for (int z=zmin+kl; z<zmax; z += 2) {\n"
@@ -1945,7 +1943,7 @@ const char* shader_MetalConvolution1x1_metal =
 " constant conv1x1_constants& cst [[buffer(2)]],\n"
 " const device uchar2 *wt [[buffer(3)]],\n"
 " const device M4 *biasTerms [[buffer(4)]],\n"
-" const device float4 *dequantScale [[buffer(5)]],\n"
+" const device M4 *dequantScale [[buffer(5)]],\n"
 " uint3 gid [[threadgroup_position_in_grid]],\n"
 " uint tiitg[[thread_index_in_threadgroup]],\n"
 " uint sgitg[[simdgroup_index_in_threadgroup]]) {\n"
@@ -1980,10 +1978,10 @@ const char* shader_MetalConvolution1x1_metal =
 " int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n"
 " for (int bi=0; bi<cst.block_size; ++bi) {\n"
 " // [N/4,cst.block_size,2/*scale_bias*/,N4]\n"
-" FLOAT4 scale0=FLOAT4(dequantScale[2*(idx_n40*cst.block_size+bi)+0]);\n"
-" FLOAT4 dequant_bias0=FLOAT4(dequantScale[2*(idx_n40*cst.block_size+bi)+1]);\n"
-" FLOAT4 scale1=FLOAT4(dequantScale[2*(idx_n41*cst.block_size+bi)+0]);\n"
-" FLOAT4 dequant_bias1=FLOAT4(dequantScale[2*(idx_n41*cst.block_size+bi)+1]);\n"
+" FLOAT4 scale0=FLOAT4(dequantScale[2*(idx_n40*cst.block_size+bi)+0])/(FLOAT)cst.scale_coef;\n"
+" FLOAT4 dequant_bias0=FLOAT4(dequantScale[2*(idx_n40*cst.block_size+bi)+1])/(FLOAT)cst.scale_coef;\n"
+" FLOAT4 scale1=FLOAT4(dequantScale[2*(idx_n41*cst.block_size+bi)+0])/(FLOAT)cst.scale_coef;\n"
+" FLOAT4 dequant_bias1=FLOAT4(dequantScale[2*(idx_n41*cst.block_size+bi)+1])/(FLOAT)cst.scale_coef;\n"
 " int zmin=bi*block;\n"
 " int zmax=min(zmin+block,cst.input_slice);\n"
 " for (int z=zmin+kl; z<zmax; z += 2) {\n"
@@ -2048,7 +2046,7 @@ const char* shader_MetalConvolution1x1_metal =
 " constant conv1x1_constants& cst [[buffer(2)]],\n"
 " const device uchar2 *wt [[buffer(3)]],\n"
 " const device M4 *biasTerms [[buffer(4)]],\n"
-" const device float4 *dequantScale [[buffer(5)]],\n"
+" const device M4 *dequantScale [[buffer(5)]],\n"
 " uint3 gid [[threadgroup_position_in_grid]],\n"
 " uint tiitg[[thread_index_in_threadgroup]],\n"
 " uint tiisg[[thread_index_in_simdgroup]],\n"
@@ -2106,8 +2104,8 @@ const char* shader_MetalConvolution1x1_metal =
 " int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n"
 " for (int bi=0; bi<cst.block_size; ++bi) {\n"
 " // [N/4,cst.block_size,2/*scale_bias*/,N4]\n"
-" FLOAT4 scale0=FLOAT4(dequantScale[2*(idx_n40*cst.block_size+bi)+0]);\n"
-" FLOAT4 dequant_bias0=FLOAT4(dequantScale[2*(idx_n40*cst.block_size+bi)+1]);\n"
+" FLOAT4 scale0=FLOAT4(dequantScale[2*(idx_n40*cst.block_size+bi)+0])/(FLOAT)cst.scale_coef;\n"
+" FLOAT4 dequant_bias0=FLOAT4(dequantScale[2*(idx_n40*cst.block_size+bi)+1])/(FLOAT)cst.scale_coef;\n"
 " int zmin=bi*block;\n"
 " int zmax=min(zmin+block,cst.input_slice);\n"
 " \n"
@@ -2174,7 +2172,7 @@ const char* shader_MetalConvolution1x1_metal =
 " constant conv1x1_constants& cst [[buffer(2)]],\n"
 " const device MNN::uchar4x2 *wt [[buffer(3)]],\n"
 " const device M4 *biasTerms [[buffer(4)]],\n"
-" const device float4 *dequantScale [[buffer(5)]],\n"
+" const device M4 *dequantScale [[buffer(5)]],\n"
 " uint3 gid [[thread_position_in_grid]]) {\n"
 " if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
 " int rx=gid.x*CONV_UNROLL;\n"
@@ -2187,8 +2185,8 @@ const char* shader_MetalConvolution1x1_metal =
 " int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
 " int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n"
 " for (int bi=0; bi<cst.block_size; ++bi) {\n"
-" FLOAT4 scale=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0]);\n"
-" FLOAT4 dequant_bias=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1]);\n"
+" FLOAT4 scale=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0])/(FLOAT)cst.scale_coef;\n"
+" FLOAT4 dequant_bias=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1])/(FLOAT)cst.scale_coef;\n"
 " int zmin=bi*block;\n"
 " int zmax=min(zmin+block,cst.input_slice);\n"
 " for (int z=zmin; z<zmax; z++) {\n"
@@ -2226,7 +2224,7 @@ const char* shader_MetalConvolution1x1_metal =
 " constant conv1x1_constants& cst [[buffer(2)]],\n"
 " const device MNN::uchar4x2 *wt [[buffer(3)]],\n"
 " const device M4 *biasTerms [[buffer(4)]],\n"
-" const device float4 *dequantScale [[buffer(5)]],\n"
+" const device M4 *dequantScale [[buffer(5)]],\n"
 " uint3 gid[[threadgroup_position_in_grid]],\n"
 " uint tiisg[[thread_index_in_simdgroup]],\n"
 " uint sgitg[[simdgroup_index_in_threadgroup]]) {\n"
@@ -2250,8 +2248,8 @@ const char* shader_MetalConvolution1x1_metal =
 " int outer_index=(tiisg)/middle_step;\n"
 " \n"
 " for (int bi= outer_index; bi<cst.block_size; bi += outer_step) {\n"
-" FLOAT4 scale=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0]);\n"
-" FLOAT4 dequant_bias=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1]);\n"
+" FLOAT4 scale=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0])/(FLOAT)cst.scale_coef;\n"
+" FLOAT4 dequant_bias=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1])/(FLOAT)cst.scale_coef;\n"
 " int zmin=bi*block;\n"
 " int zmax=min(zmin+block,cst.input_slice);\n"
 " for (int z=zmin+middle_index; z<zmax; z += middle_step) {\n"
@@ -2279,7 +2277,7 @@ const char* shader_MetalConvolution1x1_metal =
 " constant conv1x1_constants& cst [[buffer(2)]],\n"
 " const device MNN::uchar4x2 *wt [[buffer(3)]],\n"
 " const device M4 *biasTerms [[buffer(4)]],\n"
-" const device float4 *dequantScale [[buffer(5)]],\n"
+" const device M4 *dequantScale [[buffer(5)]],\n"
 " uint3 gid[[threadgroup_position_in_grid]],\n"
 " uint tiisg[[thread_index_in_simdgroup]],\n"
 " uint sgitg[[simdgroup_index_in_threadgroup]]) {\n"
@@ -2306,10 +2304,10 @@ const char* shader_MetalConvolution1x1_metal =
 " \n"
 " for (int bi= outer_index; bi<cst.block_size; bi += outer_step) {\n"
 " const int quant_offset=2*(uz*cst.block_size+bi);\n"
-" FLOAT4 scale0=FLOAT4(dequantScale[quant_offset+0]);\n"
-" FLOAT4 dequant_bias0=FLOAT4(dequantScale[quant_offset+1]);\n"
-" FLOAT4 scale1=FLOAT4(dequantScale[quant_offset+(cst.block_size << 1)]);\n"
-" FLOAT4 dequant_bias1=FLOAT4(dequantScale[quant_offset+(cst.block_size << 1)+1]);\n"
+" FLOAT4 scale0=FLOAT4(dequantScale[quant_offset+0])/(FLOAT)cst.scale_coef;\n"
+" FLOAT4 dequant_bias0=FLOAT4(dequantScale[quant_offset+1])/(FLOAT)cst.scale_coef;\n"
+" FLOAT4 scale1=FLOAT4(dequantScale[quant_offset+(cst.block_size << 1)])/(FLOAT)cst.scale_coef;\n"
+" FLOAT4 dequant_bias1=FLOAT4(dequantScale[quant_offset+(cst.block_size << 1)+1])/(FLOAT)cst.scale_coef;\n"
 " int zmin=bi*block;\n"
 " int zmax=min(zmin+block,cst.input_slice);\n"
 " for (int z=zmin+middle_index; z<zmax; z += middle_step) {\n"

File diff suppressed because it is too large
+ 164 - 668
source/backend/metal/MetalAttention.mm


+ 636 - 0
source/backend/metal/MetalAttentionShader.hpp

@@ -0,0 +1,636 @@
+//
+//  MetalAttentionShader.hpp
+//  MNN
+//
+//  Created by MNN on 2024/12/03.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#if MNN_METAL_ENABLED
+#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
+
+const char* gMatMulDivMask = R"metal(
+#include <metal_stdlib>
+#include <simd/simd.h>
+using namespace metal;
+struct Param {
+    int query_seq_len;
+    int key_seq_len;
+    int head_num;
+    int group;
+    int head_dim;
+    float scale;
+    int max_kv_len;
+};
+#define SIMD_GROUP_WIDTH 32
+
+kernel void prefill_qk(const device T* input0 [[buffer(0)]],
+    device T* output [[buffer(1)]],
+    device T* past_key [[buffer(2)]],
+#ifdef FLOAT_MASK
+    const device T* mask [[buffer(3)]],
+#else
+    const device int* mask [[buffer(3)]],
+#endif
+    constant Param& param [[buffer(4)]],
+#ifdef SIMD_GROUP_MATRIX
+    uint3 gid[[threadgroup_position_in_grid]],
+    uint tiitg[[thread_index_in_threadgroup]],
+    uint tiisg[[thread_index_in_simdgroup]],
+    uint sgitg[[simdgroup_index_in_threadgroup]]
+#else
+    uint3 gid[[thread_position_in_grid]]
+#endif
+) {
+#ifdef SIMD_GROUP_MATRIX
+
+    /*
+     // Read:
+     ftype 0~127   ---> input: [M16, K8]
+     ftype 128~255 ---> input: [K8, N16]
+     // Write:
+     ftype 0~255 ---> output: [N2, M2, M8, N8]
+     */
+    
+    simdgroup_float8x8 sga[2];
+    simdgroup_float8x8 sgb[2];
+    simdgroup_float8x8 sgd[4];
+    for (int i = 0; i < 4; i++){
+        sgd[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    int kl = tiitg % 2;// 0~1
+    int rcl = tiitg / 2;// 0~15
+
+    const int slq = gid.x; // q_seq_len/16 -> M/16
+    const int slk = gid.y; // k_seq_len/16 -> N/16
+    const int z = gid.z; // head_num
+
+    /** Q:
+     threadgroup: [M16, K8]
+     each thread: K4
+     layout: [M, B, K] -> [M/16, M16, B, K/8, K2, K4]
+     index : [slq, rcl, z, 0, kl, K4]
+     offset: ((slq * 16 + rcl) * B + z) * K + (0 * 2 + kl) * 4 + 0
+     */
+    /** K:
+     threadgroup: [K8, N16]
+     each thread: N4
+     layout: [N, B/G, K] -> [N/16, N16, B/G, K/8, K2, K4]
+     index : [slk, rcl, B/G, 0, kl, 0]
+     offset: ((slk * 16 + rcl) * B/G + z/G) * K + 0 * 8 + kl * 4 + 0
+     */
+    /** output:
+     threadgroup: [M16, N16]
+     each thread: N8
+     layout: [B, M, N] -> [B, M/16, M16, N/16, N2, N8]
+     index : [z, sl, rcl, kl, 0]
+     offset: (z * M + sl * 16 + rcl) * N + slk * 16 + kl * 8 + 0
+     */
+
+    int group = param.group;
+    int zin = z / param.group;
+    int q_seq_len = param.query_seq_len;
+    int k_seq_len = param.key_seq_len;
+    int head_num = param.head_num;
+    int head_dim = param.head_dim;
+
+    threadgroup float sdata[256] = {0.f};
+
+    int idx_slq = slq * 16 + rcl < q_seq_len ? slq * 16 + rcl : q_seq_len - 1;
+    int idx_slk = slk * 16 + rcl < k_seq_len ? slk * 16 + rcl : k_seq_len - 1;
+
+    auto A_offset = input0 + (idx_slq * head_num + z) * head_dim + (0 * 2 + kl) * 4 + 0;
+    auto B_offset = past_key + (idx_slk * head_num / group + zin) * head_dim + 0 * 8 + kl * 4 + 0;
+
+    for(int i = 0; i < head_dim; i += 8){
+        sdata[rcl * 8 + kl * 4 + 0] = A_offset[i + 0];
+        sdata[rcl * 8 + kl * 4 + 1] = A_offset[i + 1];
+        sdata[rcl * 8 + kl * 4 + 2] = A_offset[i + 2];
+        sdata[rcl * 8 + kl * 4 + 3] = A_offset[i + 3];
+        
+        sdata[128 + (kl * 4 + 0) * 16 + rcl] = B_offset[i + 0];
+        sdata[128 + (kl * 4 + 1) * 16 + rcl] = B_offset[i + 1];
+        sdata[128 + (kl * 4 + 2) * 16 + rcl] = B_offset[i + 2];
+        sdata[128 + (kl * 4 + 3) * 16 + rcl] = B_offset[i + 3];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        simdgroup_load(sga[0], (const threadgroup float*)sdata, 8);
+        simdgroup_load(sga[1], ((const threadgroup float*)sdata) + 64, 8);
+        
+        simdgroup_load(sgb[0], ((const threadgroup float*)sdata) + 128, 16);
+        simdgroup_load(sgb[1], ((const threadgroup float*)sdata) + 136, 16);
+        
+        simdgroup_multiply_accumulate(sgd[0], sga[0], sgb[0], sgd[0]);
+        simdgroup_multiply_accumulate(sgd[1], sga[1], sgb[0], sgd[1]);
+        simdgroup_multiply_accumulate(sgd[2], sga[0], sgb[1], sgd[2]);
+        simdgroup_multiply_accumulate(sgd[3], sga[1], sgb[1], sgd[3]);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    simdgroup_store(sgd[0], (threadgroup float*)sdata, 8);
+    simdgroup_store(sgd[1], (threadgroup float*)sdata + 64, 8);
+    simdgroup_store(sgd[2], (threadgroup float*)sdata + 128, 8);
+    simdgroup_store(sgd[3], (threadgroup float*)sdata + 192, 8);
+    
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // [N2, M2, M8, N8]
+    float Vscale = (float)param.scale;
+
+    auto xy_out = output + (z * q_seq_len + slq * 16 + rcl) * k_seq_len + slk * 16 + kl * 8 + 0;
+    if(slq * 16 + rcl < q_seq_len) {
+        // Same bounds check and mask logic for each of the 8 columns this thread owns.
+        for (int n = 0; n < 8; ++n) {
+            if(slk * 16 + kl * 8 + n < k_seq_len) {
+                auto out0 = sdata[(kl * 16 + rcl) * 8 + n] * Vscale;
+                #ifdef FLOAT_MASK
+                    out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + n))] + out0;
+                #else
+                    out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + n))] == 0 ? -FLT_MAX : out0;
+                #endif
+                xy_out[n] = out0;
+            }
+        }
+    }
+
+#else
+    const int x = gid.x; // query_seq_len
+    const int y = gid.y; // head_num
+    const int z = gid.z; // key_seq_len
+
+    if (x >= param.query_seq_len || y >= param.head_num || z >= param.key_seq_len) {
+        return;
+    }
+    int group = param.group;
+    int query_seq_len = param.query_seq_len;
+    int key_seq_len = param.key_seq_len;
+    int head_num = param.head_num;
+    int head_dim = param.head_dim;
+    
+    const int offset = head_num * head_dim;
+    const int offset_head = y * head_dim;
+    const int offset_head_kv = (y / group) * head_dim;
+    const device T* A_offset = input0 + x * offset + offset_head;
+
+    float Vscale = (float)param.scale;
+
+    device const T* B_offset = past_key + z * offset / group + offset_head_kv;
+    const int output_offset = y * query_seq_len * key_seq_len;
+    float out0 = 0.0;
+    
+    for(int i = 0; i < head_dim; ++i){
+        float A = (float)(A_offset[i]);
+        float B = (float)(B_offset[i]);
+        out0 += B * A;
+    }
+    
+    out0 *= Vscale;
+    
+#ifdef FLOAT_MASK
+    out0 = mask[((x + 0) * key_seq_len + (z + 0))] + out0;
+#else
+    out0 = mask[((x + 0) * key_seq_len + (z + 0))] == 0 ? -FLT_MAX : out0;
+#endif
+    output[output_offset + x * key_seq_len + z] = (T)out0;
+#endif
+}
+
+kernel void decode_qk(const device T* input0 [[buffer(0)]],
+    device T* output [[buffer(1)]],
+    device T* past_key [[buffer(2)]],
+#ifdef FLOAT_MASK
+    const device T* mask [[buffer(3)]],
+#else
+    const device int* mask [[buffer(3)]],
+#endif
+    constant Param& param [[buffer(4)]],
+#ifdef SIMD_GROUP_REDUCE
+    uint3 gid[[threadgroup_position_in_grid]],
+    uint  tiisg[[thread_index_in_simdgroup]],
+    uint  sgitg[[simdgroup_index_in_threadgroup]]
+#else
+    uint3 gid[[thread_position_in_grid]]
+#endif
+) {
+    int x = gid.x; // query_seq_len
+    int y = gid.y; // head_num
+    int z = gid.z; // key_seq_len
+
+#ifdef HEAD_NUM_2
+    y = y * 2;
+#endif
+    if (x >= param.query_seq_len || y >= param.head_num || z >= param.key_seq_len) {
+        return;
+    }
+    int group = param.group;
+
+    int key_seq_len = param.key_seq_len;
+    int head_num = param.head_num;
+    int head_dim = param.head_dim;
+    
+    const int offset = head_num * head_dim;
+    const int offset_head = y * head_dim;
+    const int offset_head_kv = (y / param.group) * head_dim;
+    const device T* A_offset = input0 + x * offset + offset_head;
+    device T* Pastkey_offset = past_key + z * offset / group + offset_head_kv;
+    float Vscale = (float)param.scale;
+    float out = 0.0;
+
+#ifdef HEAD_NUM_2
+    const device T* A_offset_1 = A_offset + head_dim;
+    device T* Pastkey_offset_1 = past_key + z * offset / group + ((y+1) / param.group) * head_dim;
+    float out_1 = 0.0;
+#endif
+
+#ifdef SIMD_GROUP_REDUCE
+    for(int i = tiisg; i < head_dim; i+=SIMD_GROUP_WIDTH){
+        float A = A_offset[i];
+        float B = (float)Pastkey_offset[i];
+        
+        out += A * B;
+    }
+
+#ifdef HEAD_NUM_2
+    if(y + 1 < param.head_num) {
+        for(int i = tiisg; i < head_dim; i+=SIMD_GROUP_WIDTH){
+            float A = A_offset_1[i];
+            float B = (float)Pastkey_offset_1[i];
+            
+            out_1 += A * B;
+        }
+    }
+#endif
+    out = simd_sum(out);
+
+#ifdef HEAD_NUM_2
+    if(y + 1 < param.head_num) {
+        out_1 = simd_sum(out_1);
+        if(tiisg == 1) {
+            out_1 *= Vscale;
+            output[(y+1) * key_seq_len + z] = (T)out_1;
+        }
+    }
+#endif
+    if(tiisg == 0) {
+        out *= Vscale;
+        output[y * key_seq_len + z] = (T)out;
+    }
+
+#else
+    {
+        for(int i = 0; i < head_dim; i++){
+            float A = A_offset[i];
+            float B = (float)Pastkey_offset[i];
+            
+            out += A * B;
+        }
+    }
+    out *= Vscale;
+    output[y * key_seq_len + z] = (T)out;
+
+#ifdef HEAD_NUM_2
+    if(y + 1 < param.head_num) {
+        for(int i = 0; i < head_dim; i++){
+            float A = A_offset_1[i];
+            float B = (float)Pastkey_offset_1[i];
+            
+            out_1 += A * B;
+        }
+        out_1 *= Vscale;
+        output[(y+1) * key_seq_len + z] = (T)out_1;
+    }
+#endif
+
+#endif
+}
+
+)metal";
+
+const char* gCopyPastKV = R"metal(
+#include <metal_stdlib>
+using namespace metal;
+struct Param {
+    int head_count;
+    int q_seq_len;
+    int max_kv_len;
+    int dst_k_offset;
+    int dst_v_offset;
+};
+kernel void copy(const device T* input0 [[buffer(0)]],
+    const device T* input1 [[buffer(1)]],
+    device T* output0 [[buffer(2)]],
+    device T* output1 [[buffer(3)]],
+    constant Param& param [[buffer(4)]],
+    uint3 gid[[thread_position_in_grid]]
+) {
+    const int x = gid.x; // head_num / group * head_dim
+    const int y = gid.y; // q_seq_len
+    if (x >= param.head_count || y >= param.q_seq_len) {
+        return;
+    }
+    const int index = y * param.head_count + x;
+    output0[param.dst_k_offset + index] = input0[index];
+    output1[param.dst_v_offset + x * param.max_kv_len + y] = input1[index];
+}
+)metal";
+
+const char* gMatMulQKV = R"metal(
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+using namespace metal;
+struct Param {
+    int query_seq_len;
+    int key_seq_len;
+    int head_num;
+    int group;
+    int head_dim;
+    float scale;
+    int max_kv_len;
+};
+#define SIMD_GROUP_WIDTH 32
+kernel void prefill_qkv(const device T* input0 [[buffer(0)]],
+    device T* output [[buffer(1)]],
+    device T* past_value [[buffer(2)]],
+    constant Param& param [[buffer(3)]],
+#ifdef SIMD_GROUP_MATRIX
+    uint3 gid[[threadgroup_position_in_grid]],
+    uint tiitg[[thread_index_in_threadgroup]],
+    uint tiisg[[thread_index_in_simdgroup]],
+    uint sgitg[[simdgroup_index_in_threadgroup]]
+#else
+    uint3 gid[[thread_position_in_grid]]
+#endif
+) {
+#ifdef SIMD_GROUP_MATRIX
+    /*
+     // Read:
+     ftype 0~127   ---> input: [M16, K8]
+     ftype 128~255 ---> input: [K8, N16]
+     // Write:
+     ftype 0~255 ---> output: [N2, M2, M8, N8]
+     */
+    
+    simdgroup_float8x8 sga[2];
+    simdgroup_float8x8 sgb[2];
+    simdgroup_float8x8 sgd[4];
+    for (int i = 0; i < 4; i++){
+        sgd[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    int kl = tiitg % 2;// 0~1
+    int rcl = tiitg / 2;// 0~15
+
+    int nl = tiitg % 4;// 0~3
+    int kcl = tiitg / 4;// 0~7
+
+    const int sl = gid.x; // q_seq_len/16 -> M/16
+    const int hm = gid.y; // head_dim/16 -> N/16
+    const int z = gid.z; // head_num
+
+    /** QK:
+     threadgroup: [M16, K8]
+     each thread: K4
+     layout: [B, M, K] -> [B, M/16, M16, K/8, K2, K4]
+     index : [z, sl, rcl, ml, kl, K4]
+     offset: (z * M + sl * 16 + rcl) * K + (0 * 2 + kl) * 4 + 0
+     */
+    /** V:
+     threadgroup: [K8, N16]
+     each thread: N4
+     layout: [B/G, N, K_max] -> [B/G, N/16, N4, N4, K_max]
+     index : [z/G, hm, nl, 0, kcl]
+     offset: (z/G * N + hm * 16 + nl * 4 + 0) * K_max + (0 * 8 + kcl)
+     */
+    /** output:
+     threadgroup: [M16, N16]
+     each thread: N8
+     layout: [M, B, N] -> [M/16, M16, B, N/16, N2, N8]
+     index : [sl, rcl, B, kl, 0]
+     offset: ((sl * 16 + rcl) * B + z) * N + hm * 16 + kl * 8 + 0
+     */
+
+    int group = param.group;
+    int zin = z / group;
+    int q_seq_len = param.query_seq_len;
+    int value_seq_len = param.key_seq_len;
+    int head_num = param.head_num;
+    int head_dim = param.head_dim;
+
+    threadgroup float sdata[256] = {0.f};
+
+    int idx_qk_sl = sl * 16 + rcl < q_seq_len ? (sl * 16 + rcl) : q_seq_len - 1;
+
+    auto A_offset = input0 + (z * q_seq_len + idx_qk_sl) * value_seq_len + (0 * 2 + kl) * 4 + 0;
+    auto B_offset = past_value + (zin * head_dim + hm * 16 + nl * 4 + 0) * param.max_kv_len + (0 * 8 + kcl);
+    
+
+    for(int i = 0; i < value_seq_len; i += 8){
+        sdata[rcl * 8 + kl * 4 + 0] = (i + kl * 4 + 0 < value_seq_len) ? A_offset[i + 0] : 0.0;
+        sdata[rcl * 8 + kl * 4 + 1] = (i + kl * 4 + 1 < value_seq_len) ? A_offset[i + 1] : 0.0;
+        sdata[rcl * 8 + kl * 4 + 2] = (i + kl * 4 + 2 < value_seq_len) ? A_offset[i + 2] : 0.0;
+        sdata[rcl * 8 + kl * 4 + 3] = (i + kl * 4 + 3 < value_seq_len) ? A_offset[i + 3] : 0.0;
+        
+        sdata[128 + kcl * 16 + nl * 4 + 0] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 0 < head_dim) ? B_offset[i + 0 * param.max_kv_len] : 0.0;
+        sdata[128 + kcl * 16 + nl * 4 + 1] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 1 < head_dim) ? B_offset[i + 1 * param.max_kv_len] : 0.0;
+        sdata[128 + kcl * 16 + nl * 4 + 2] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 2 < head_dim) ? B_offset[i + 2 * param.max_kv_len] : 0.0;
+        sdata[128 + kcl * 16 + nl * 4 + 3] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 3 < head_dim) ? B_offset[i + 3 * param.max_kv_len] : 0.0;
+
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        simdgroup_load(sga[0], (const threadgroup float*)sdata, 8);
+        simdgroup_load(sga[1], ((const threadgroup float*)sdata) + 64, 8);
+        
+        simdgroup_load(sgb[0], ((const threadgroup float*)sdata) + 128, 16);
+        simdgroup_load(sgb[1], ((const threadgroup float*)sdata) + 136, 16);
+        
+        simdgroup_multiply_accumulate(sgd[0], sga[0], sgb[0], sgd[0]);
+        simdgroup_multiply_accumulate(sgd[1], sga[1], sgb[0], sgd[1]);
+        simdgroup_multiply_accumulate(sgd[2], sga[0], sgb[1], sgd[2]);
+        simdgroup_multiply_accumulate(sgd[3], sga[1], sgb[1], sgd[3]);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    simdgroup_store(sgd[0], (threadgroup float*)sdata, 8);
+    simdgroup_store(sgd[1], (threadgroup float*)sdata + 64, 8);
+    simdgroup_store(sgd[2], (threadgroup float*)sdata + 128, 8);
+    simdgroup_store(sgd[3], (threadgroup float*)sdata + 192, 8);
+    
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // [N2, M2, M8, N8]
+    auto xy_out = output + ((sl * 16 + rcl) * head_num + z) * head_dim + hm * 16 + kl * 8 + 0;
+    if(sl * 16 + rcl < q_seq_len) {
+        for (int n = 0; n < 8; ++n) {
+            if(hm * 16 + kl * 8 + n < head_dim) {
+                xy_out[n] = sdata[(kl * 16 + rcl) * 8 + n];
+            }
+        }
+    }
+
+#else
+    const int x = gid.x; // query_seq_len
+    const int y = gid.y; // head_num
+    const int z = gid.z; // head_dim
+    if (x >= param.query_seq_len || y >= param.head_num || z >= param.head_dim) {
+        return;
+    }
+    int group = param.group;
+    int yin = y / group;
+    int q_seq_len = param.query_seq_len;
+    int value_seq_len = param.key_seq_len;
+    int head_num = param.head_num;
+    int head_dim = param.head_dim;
+    const int stride = head_num * head_dim / group;
+    const int offset_head = yin * head_dim + z;
+
+    device const T *A_offset = input0 + (y * q_seq_len + x) * value_seq_len;
+    device const T *B_offset = past_value + offset_head * param.max_kv_len;
+    float out = 0.0;
+    
+    for(int i = 0; i < value_seq_len; ++i){
+        float A0 = (float)A_offset[i];
+        float B = (float)B_offset[i];
+        out += A0 * B;
+    }
+    output[x * stride * group + (y * head_dim + z)] = out;
+#endif
+}
+
+kernel void decode_qkv(const device T* input0 [[buffer(0)]],
+    device T* output [[buffer(1)]],
+    device T* past_value [[buffer(2)]],
+    constant Param& param [[buffer(3)]],
+#ifdef SIMD_GROUP_REDUCE
+    uint3 gid[[threadgroup_position_in_grid]],
+    uint  tiisg[[thread_index_in_simdgroup]],
+    uint  sgitg[[simdgroup_index_in_threadgroup]]
+#else
+    uint3 gid[[thread_position_in_grid]]
+#endif
+) {
+    const int x = gid.x; // query_seq_len
+    const int y = gid.y; // head_num
+    const int z = gid.z; // head_dim
+    if (x >= param.query_seq_len || y >= param.head_num || z >= param.head_dim) {
+        return;
+    }
+
+    int yin = y / param.group;
+    int value_seq_len = param.key_seq_len;
+
+    int head_dim = param.head_dim;
+
+    const int offset_head = (yin * head_dim + z) * param.max_kv_len;
+
+    device const T *A_offset = input0 + y * value_seq_len;
+    device T *Pastvalue_offset = past_value + offset_head;
+    float out = 0;
+    
+#ifdef SIMD_GROUP_REDUCE
+    for(int i = tiisg; i < value_seq_len; i+=SIMD_GROUP_WIDTH){
+        float A = (float)A_offset[i];
+        float B = (float)Pastvalue_offset[i];
+        
+        out += A * B;
+    }
+    out = simd_sum(out);
+    if(tiisg == 0) {
+        output[(y * head_dim + z)] = (T)out;
+    }
+#else
+    for(int i = 0; i < value_seq_len; i++){
+        float A = (float)A_offset[i];
+        float B = (float)Pastvalue_offset[i];
+        
+        out += A * B;
+    }
+    output[(y * head_dim + z)] = (T)out;
+#endif
+}
+)metal";
+
+#endif/* MNN_SUPPORT_TRANSFORMER_FUSE */
+#endif
+

+ 1 - 1
source/backend/metal/MetalConvolution1x1.hpp

@@ -23,7 +23,7 @@ public:
     virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
     virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder)  override;
 private:
-    MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias, std::shared_ptr<MNN::Tensor> dequantScale, int dequantBits);
+    MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias, std::shared_ptr<MNN::Tensor> dequantScale, int dequantBits, float scaleCoef);
     id<MTLComputePipelineState> mPipeline;
     std::pair<MTLSize, MTLSize> mThreads;
 };

+ 21 - 6
source/backend/metal/MetalConvolution1x1.mm

@@ -31,11 +31,12 @@ MetalConvolution1x1::MetalConvolution1x1(Backend *backend, const MNN::Op *op) :
     loadWeight(op, ldInt8Weight);
 }
 
-MetalConvolution1x1::MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias, std::shared_ptr<MNN::Tensor> dequantScale, int dequantBits) : MetalConvolutionCommon(backend, op, bias) {
+MetalConvolution1x1::MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias, std::shared_ptr<MNN::Tensor> dequantScale, int dequantBits, float scaleCoef) : MetalConvolutionCommon(backend, op, bias) {
     mWeight = weight;
     mBias = bias;
     mDequantScaleBias = dequantScale;
     mDequantBits = dequantBits;
+    mScaleCoef = scaleCoef;
 }
 
 
@@ -46,7 +47,7 @@ bool MetalConvolution1x1::onClone(Backend* bn, const Op* op, Execution** dst) {
     if (nullptr == dst) {
         return true;
     }
-    *dst = new MetalConvolution1x1(bn, op, mWeight, mBias, mDequantScaleBias, mDequantBits);
+    *dst = new MetalConvolution1x1(bn, op, mWeight, mBias, mDequantScaleBias, mDequantBits, mScaleCoef);
     return true;
 }
 
@@ -72,12 +73,26 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
     auto context = (__bridge MNNMetalContext *)backend->context();
     int blockSize = 1;
     if (mDequantScaleBias.get()) {
-        blockSize = (int)(mDequantScaleBias->usize() /sizeof(float) / oc_4 / 2 / 4);
+        int bytes = sizeof(float);
+        if(backend->useFp16InsteadFp32()) {
+            bytes = sizeof(__fp16);
+        }
+        blockSize = (int)(mDequantScaleBias->usize() / bytes / oc_4 / 2 / 4);
     }
     // create const buffer
-    int constants[] = {is, ic_4, ow, oh, os, oc_4, oc, ob, blockSize, mActivationType};
-    mConstBuffer = backend->getConstBuffer(sizeof(constants));
-    ::memcpy(mConstBuffer.contents, constants, sizeof(constants));
+    mConstBuffer = backend->getConstBuffer(sizeof(Param));
+    auto param = (Param *)mConstBuffer.contents;
+    param->input_size = is;
+    param->input_slice = ic_4;
+    param->output_width = ow;
+    param->output_height = oh;
+    param->output_size = os;
+    param->output_slice = oc_4;
+    param->output_channel = oc;
+    param->batch = ob;
+    param->block_size = blockSize;
+    param->activation = mActivationType;
+    param->scale_coef = mScaleCoef;
 
     MetalRuntime* rt = (MetalRuntime *)backend->runtime();
     if (mDequantScaleBias.get()) {
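
The divisor chain above recovers block_size from the dequant buffer's layout, [oc_4, block_size, 2 (scale + bias), 4], now that the element width depends on the precision mode. A worked check with illustrative numbers:

    #include <cstddef>

    // Layout: [oc_4, block_size, 2 /*scale + bias*/, 4] elements of `bytes` each.
    static int blockSizeFromUsize(size_t usize, int bytes, int oc_4) {
        return (int)(usize / bytes / oc_4 / 2 / 4);
    }
    // e.g. oc = 32 (oc_4 = 8), block_size = 4, fp16 (bytes = 2):
    // usize = 8 * 4 * 2 * 4 * 2 = 512 bytes -> 512 / 2 / 8 / 2 / 4 = 4.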

+ 16 - 2
source/backend/metal/MetalConvolutionCommon.hpp

@@ -26,8 +26,21 @@ protected:
 
     virtual std::shared_ptr<MNN::Tensor> weightTransform(int group, int oc, int ic, int kh, int kw, const float *src, bool int8Weight = false, bool int4Weight = false);
 
-private:
-
+protected:
+    struct Param {
+        int input_size;
+        int input_slice;
+        int output_width;
+        int output_height;
+        int output_size;
+        int output_slice;
+        int output_channel;
+        int batch;
+        int block_size;
+        int activation;
+        float scale_coef;
+    };
+    
 protected:
     int mKernelX        = 0;
     int mKernelY        = 0;
@@ -42,6 +55,7 @@ protected:
     std::shared_ptr<MNN::Tensor> mBias;
     std::shared_ptr<MNN::Tensor> mDequantScaleBias;
     int mDequantBits;
+    float mScaleCoef;
     id<MTLBuffer> mConstBuffer = nil;
 };
 

+ 37 - 11
source/backend/metal/MetalConvolutionCommon.mm

@@ -97,7 +97,8 @@ void weightInBlock(int group, int oc, int ic, int kh, int kw, const FType *src,
     }
 }
 
-static std::shared_ptr<MNN::Tensor> getDequantScale(const float* scale, int size, MetalBackend *backend, bool asymmetric, int oc) {
+template<typename DType>
+static std::pair<std::shared_ptr<MNN::Tensor>, float> getDequantScale(const float* scale, int size, MetalBackend *backend, bool asymmetric, int oc) {
     int totalCount = 0;
     if (asymmetric) {
         totalCount = size / 2;
@@ -106,15 +107,32 @@ static std::shared_ptr<MNN::Tensor> getDequantScale(const float* scale, int size
     }
     int blockSize = totalCount / oc;
     int alignOutputCount = ALIGN_UP4(oc);
-    std::shared_ptr<MNN::Tensor> dequantScale(MNN::Tensor::createDevice<uint8_t>({alignOutputCount, blockSize, (int)(sizeof(float) * 2)}));
+    std::shared_ptr<MNN::Tensor> dequantScale(MNN::Tensor::createDevice<uint8_t>({alignOutputCount, blockSize, (int)(sizeof(DType) * 2)}));
     bool res = backend->onAcquireBuffer(dequantScale.get(), Backend::STATIC);
     if (!res) {
         MNN_ERROR("Buffer allocated error!\n");
-        return nullptr;
+        return std::make_pair(nullptr, 1.0);
     }
     auto buffer0 = MetalBackend::getBuffer(dequantScale.get());
-    auto dst_scale = (float*)((uint8_t*)[buffer0.first contents] + buffer0.second);
+    DType* dst_scale = (DType*)((uint8_t*)[buffer0.first contents] + buffer0.second);
     ::memset(dst_scale, 0, dequantScale->usize());
+    
+    float coef = 1.0;
+    if(std::is_same<DType, __fp16>::value) {
+        float max_data = 0.0;
+        for (int z=0; z<oc; ++z) {
+            auto srcZ = scale + z * blockSize * 2;
+            for (int bi=0; bi<blockSize; ++bi) {
+                float s = fabs(srcZ[2*bi+1]);
+                float b = fabs(srcZ[2*bi+0]);
+                float temp = ALIMAX(s, b);
+                if(temp > max_data) {
+                    max_data = temp;
+                }
+            }
+        }
+        coef = max_data > 0.0f ? (65504.0 / max_data) : 1.0;
+    }
     if (asymmetric) {
         for (int z=0; z<oc; ++z) {
             int zo = z / 4;
@@ -125,8 +143,8 @@ static std::shared_ptr<MNN::Tensor> getDequantScale(const float* scale, int size
             for (int bi=0; bi<blockSize; ++bi) {
                 float s = srcZ[2*bi+1];
                 float b = srcZ[2*bi+0];
-                dstSZ[bi * 8] = s;
-                dstBZ[bi * 8] = b;
+                dstSZ[bi * 8] = (DType)(s * coef);
+                dstBZ[bi * 8] = (DType)(b * coef);
             }
         }
     } else {
@@ -139,12 +157,12 @@ static std::shared_ptr<MNN::Tensor> getDequantScale(const float* scale, int size
             for (int bi=0; bi<blockSize; ++bi) {
                 float s = srcZ[bi];
                 float b = 0.0f;
-                dstSZ[bi * 8] = s;
+                dstSZ[bi * 8] = (DType)(s * coef);
                 dstBZ[bi * 8] = b;
             }
         }
     }
-    return dequantScale;
+    return std::make_pair(dequantScale, coef);
 }
 void MetalConvolutionCommon::loadWeight(const MNN::Op *op, bool loadWeightInt8) {
     auto conv = op->main_as_Convolution2D();
@@ -166,12 +184,20 @@ void MetalConvolutionCommon::loadWeight(const MNN::Op *op, bool loadWeightInt8)
         ic = size / kw / kh / (oc / group);
     }
 
-    // convert
+    // convert
     if (loadWeightInt8 && qnt->weight.get() != nullptr) {
         auto backend = static_cast<MetalBackend *>(this->backend());
         mWeight = weightTransform(group, oc, ic, kh, kw, (float*)qnt->weight.get(), !qnt->canUseInt4, qnt->canUseInt4);
-        auto dequantParams = getDequantScale(qnt->alpha.get(), qnt->alpha.size(), backend, qnt->asymmetric, oc);
-        mDequantScaleBias = dequantParams;
+        if(backend->useFp16InsteadFp32()) {
+            auto dequantParams = getDequantScale<__fp16>(qnt->alpha.get(), qnt->alpha.size(), backend, qnt->asymmetric, oc);
+            mDequantScaleBias = dequantParams.first;
+            mScaleCoef = dequantParams.second;
+        } else {
+            auto dequantParams = getDequantScale<float>(qnt->alpha.get(), qnt->alpha.size(), backend, qnt->asymmetric, oc);
+            mDequantScaleBias = dequantParams.first;
+            mScaleCoef = dequantParams.second;
+        }
+
         mDequantBits = qnt->canUseInt4 ? 4:8;
     } else if (qnt && qnt->weightFloat.size() > 0) {
         mWeight = weightTransform(group, oc, ic, kh, kw, qnt->weightFloat.get(), false, false);
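
This branch is where the new scale coefficient earns its keep: __fp16 saturates at 65504, so before casting, every per-channel scale/bias is multiplied by coef = 65504 / max|value|, which maps the largest magnitude onto the fp16 maximum and keeps small scales away from underflow. The conv1x1 shaders further down in this patch divide by cst.scale_coef to undo it. Round-trip sketch (assumes clang's __fp16, as used in this file):

    // Store-and-recover one dequant scale s under the coef normalization.
    float storeFp16Scaled(float s, float coef) {
        __fp16 stored = (__fp16)(s * coef);  // what lands in mDequantScaleBias
        return (float)stored / coef;         // what the kernels reconstruct via scale_coef
    }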

+ 0 - 9
source/backend/metal/MetalDeconvolution.hpp

@@ -24,16 +24,7 @@ public:
 private:
     bool mDepthwise  = false;
     int mGroup       = 0;
-    int mKernelX     = 0;
-    int mKernelY     = 0;
     PadMode mPadMode = PadMode_CAFFE;
-    int mPadX        = 0;
-    int mPadY        = 0;
-    int mStrideX     = 0;
-    int mStrideY     = 0;
-    int mDilateX     = 0;
-    int mDilateY     = 0;
-    int mActivationType = 0;
 
     const MNN::Op *mOp = nullptr;
 

+ 64 - 44
source/backend/metal/MetalDeconvolution.mm

@@ -14,7 +14,33 @@
 
 #if MNN_METAL_ENABLED
 namespace MNN {
-
+struct deconv_constants {
+    int input_width;
+    int input_height;
+    int input_size;
+    int input_slice;
+    int output_width;
+    int output_height;
+    int output_size;
+    int output_slice;
+    
+    int kernel_x;
+    int kernel_y;
+    int kernel_size;
+    int stride_x;
+    int stride_y;
+    int pad_x;
+    int pad_y;
+    int dilation_x;
+    int dilation_y;
+    
+    int delta_ky;
+    int delta_kx;
+    int delta_iy;
+    int delta_ix;
+    int batch;
+    int activation;
+};
 static int leastCommonMultiple(int m, int n) {
     int a = m, b = n;
     while(a != b){
@@ -130,17 +156,7 @@ MetalDeconvolution::MetalDeconvolution(Backend *backend, const MNN::Op *op) : Me
     auto common  = deconv->common();
     mOp          = op;
     mDepthwise   = op->type() == MNN::OpType_DeconvolutionDepthwise;
-    mGroup       = common->group();
-    mKernelX     = common->kernelX();
-    mKernelY     = common->kernelY();
     mPadMode     = common->padMode();
-    mPadX        = common->padX();
-    mPadY        = common->padY();
-    mStrideX     = common->strideX();
-    mStrideY     = common->strideY();
-    mDilateX     = common->dilateX();
-    mDilateY     = common->dilateY();
-    mActivationType = common->relu() ? 1 : (common->relu6() ? 2 : 0);
 
     // force downgrade to float, as the CPU backend does
     std::shared_ptr<ConvolutionCommon::Int8Common> qnt = NULL;
@@ -167,9 +183,13 @@ MetalDeconvolution::MetalDeconvolution(Backend *backend, const MNN::Op *op) : Me
         mValid = false;
         return;
     }
+    auto weightBuffer = MetalBackend::getBuffer(mWeight.get());
+    auto ptr = (uint8_t*)weightBuffer.first.contents + weightBuffer.second;
     if (mtbn->useFp16InsteadFp32()) {
+        ::memset(ptr, 0, weightSize * sizeof(int16_t));
         weightForDeconv<__fp16>(mWeight, mDepthwise, deconv, qnt.get());
     } else {
+        ::memset(ptr, 0, weightSize * sizeof(float));
         weightForDeconv<float>(mWeight, mDepthwise, deconv, qnt.get());
     }
     mBias = biasForDeconv(backend, deconv, mtbn->useFp16InsteadFp32());
@@ -182,6 +202,24 @@ MetalDeconvolution::MetalDeconvolution(Backend *backend, const MNN::Op *op) : Me
     } else {
         mPipeline = [context pipelineWithName:@"deconv" fp16:mtbn->useFp16InsteadFp32()];
     }
+    mConstBuffer = [context newDeviceBuffer:sizeof(deconv_constants) access:CPUWriteOnly];
+    auto param = (deconv_constants*)mConstBuffer.contents;
+    
+    mGroup       = common->group();
+    param->kernel_x = common->kernelX();
+    param->kernel_y = common->kernelY();
+    param->kernel_size = common->kernelX() * common->kernelY();
+    param->stride_x = common->strideX();
+    param->stride_y = common->strideY();
+    param->dilation_x = common->dilateX();
+    param->dilation_y = common->dilateY();
+    param->activation = common->relu() ? 1 : (common->relu6() ? 2 : 0);
+    auto deltaKy = leastCommonMultiple(common->dilateY(), common->strideY()) / common->dilateY();
+    auto deltaKx = leastCommonMultiple(common->dilateX(), common->strideX()) / common->dilateX();
+    param->delta_kx = deltaKx;
+    param->delta_ky = deltaKy;
+    param->delta_iy = deltaKy * common->dilateY() / common->strideY();
+    param->delta_ix = deltaKx * common->dilateX() / common->strideX();
 }
 
 ErrorCode MetalDeconvolution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
@@ -197,46 +235,28 @@ ErrorCode MetalDeconvolution::onResize(const std::vector<Tensor *> &inputs, cons
     const int padY = pad.second;
 
     // const buffer
-    auto deltaKy = leastCommonMultiple(mDilateY, mStrideY) / mDilateY;
-    auto deltaKx = leastCommonMultiple(mDilateX, mStrideX) / mDilateX;
-
-    int consts[] = {
-        iw,
-        ih,
-        iw * ih,
-        iz,
-        ow,
-        oh,
-        ow * oh,
-        oz,
-        mKernelX,
-        mKernelY,
-        mKernelX * mKernelY,
-        mStrideX,
-        mStrideY,
-        padX,
-        padY,
-        mDilateX,
-        mDilateY,
-        deltaKy,
-        deltaKx,
-        deltaKy * mDilateY / mStrideY,
-        deltaKx * mDilateX / mStrideX,
-        1,
-        ob,
-        mActivationType
-    };
-    mConstBuffer = [context newDeviceBuffer:sizeof(consts) bytes:consts access:CPUWriteOnly];
+    auto param = (deconv_constants*)mConstBuffer.contents;
+    param->input_width = iw;
+    param->input_height = ih;
+    param->input_size = iw * ih;
+    param->input_slice = iz;
+    param->output_width = ow;
+    param->output_height = oh;
+    param->output_size = ow * oh;
+    param->output_slice = oz;
+    param->batch = ob;
+    param->pad_x = padX;
+    param->pad_y = padY;
 
     mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake((NSUInteger) ow, (NSUInteger)oh, (NSUInteger)oz * ob)];
     return NO_ERROR;
 }
 
 void MetalDeconvolution::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
-        auto input = inputs[0], output = outputs[0];
+    auto input = inputs[0], output = outputs[0];
     [encoder setComputePipelineState:mPipeline];
-    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
-    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
+    MetalBackend::setTensor(input, encoder, 0);
+    MetalBackend::setTensor(output, encoder, 1);
     [encoder setBuffer:mConstBuffer offset:0 atIndex:2];
     MetalBackend::setTensor(mWeight.get(), encoder, 3);
     MetalBackend::setTensor(mBias.get(), encoder, 4);
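
The delta_* constants moved into the constructor drive the deconv kernels' tap walk: starting from a kernel tap ky where (oy - ky * dilation_y) is divisible by stride_y, the next valid tap is delta_ky lower and the matching input row delta_iy higher, with delta_ky = lcm(dilation_y, stride_y) / dilation_y. A quick host-side check of that invariant (illustrative):

    #include <cassert>

    static int lcmByWalk(int m, int n) {  // same walk as leastCommonMultiple above
        int a = m, b = n;
        while (a != b) { a < b ? a += m : b += n; }
        return a;
    }

    void checkDeltaInvariant(int dilation_y, int stride_y, int oy, int ky) {
        const int delta_ky = lcmByWalk(dilation_y, stride_y) / dilation_y;
        const int delta_iy = delta_ky * dilation_y / stride_y;
        if ((oy - ky * dilation_y) % stride_y == 0 && ky >= delta_ky) {
            // Stepping ky down by delta_ky keeps the input index integral...
            assert((oy - (ky - delta_ky) * dilation_y) % stride_y == 0);
            // ...and advances it by exactly delta_iy.
            assert((oy - (ky - delta_ky) * dilation_y) / stride_y
                   == (oy - ky * dilation_y) / stride_y + delta_iy);
        }
    }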

+ 29 - 28
source/backend/metal/shader/MetalConvolution1x1.metal

@@ -32,6 +32,7 @@ struct conv1x1_constants {
     int batch;
     int block_size;
     conv_activation_type activation;
+    float scale_coef;
 };
 
 kernel void conv1x1_g1z4(const device ftype4 *in            [[buffer(0)]],
@@ -76,7 +77,7 @@ kernel void conv1x1_g1z4_w8(const device ftype4 *in            [[buffer(0)]],
                             constant conv1x1_constants& cst    [[buffer(2)]],
                             const device MNN::char4x4 *wt      [[buffer(3)]],
                             const device ftype4 *biasTerms     [[buffer(4)]],
-                            const device float4 *dequantScale  [[buffer(5)]],
+                            const device ftype4 *dequantScale  [[buffer(5)]],
                             uint3 gid                          [[thread_position_in_grid]]) {
     if ((int)gid.x * CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;
 
@@ -90,8 +91,8 @@ kernel void conv1x1_g1z4_w8(const device ftype4 *in            [[buffer(0)]],
     int computeSize = min(cst.output_size - rx, CONV_UNROLL);
     int block = (cst.input_slice + cst.block_size - 1) / cst.block_size;
     for (int bi=0; bi<cst.block_size; ++bi) {
-        FLOAT4 bs0 = FLOAT4(dequantScale[2 * (uz * cst.block_size + bi) + 0]);
-        FLOAT4 bs1 = FLOAT4(dequantScale[2 * (uz * cst.block_size + bi) + 1]);
+        FLOAT4 bs0 = FLOAT4(dequantScale[2 * (uz * cst.block_size + bi) + 0]) / (FLOAT)cst.scale_coef;
+        FLOAT4 bs1 = FLOAT4(dequantScale[2 * (uz * cst.block_size + bi) + 1]) / (FLOAT)cst.scale_coef;
         FLOAT4 scale = bs0;
         FLOAT4 dequant_bias = bs1;
         int zmin = bi * block;
@@ -127,7 +128,7 @@ kernel void conv1x1_gemm_16x16_w4(const device ftype4 *in            [[buffer(0)
                             constant conv1x1_constants& cst    [[buffer(2)]],
                             const device uchar2 *wt      [[buffer(3)]],
                             const device ftype4 *biasTerms     [[buffer(4)]],
-                            const device float4 *dequantScale  [[buffer(5)]],
+                            const device ftype4 *dequantScale  [[buffer(5)]],
                             uint3 gid                          [[threadgroup_position_in_grid]],
                             uint                  tiitg[[thread_index_in_threadgroup]],
                             uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -161,8 +162,8 @@ kernel void conv1x1_gemm_16x16_w4(const device ftype4 *in            [[buffer(0)
     int block = (cst.input_slice + cst.block_size - 1) / cst.block_size;
     for (int bi=0; bi<cst.block_size; ++bi) {
         // [N/4, cst.block_size, 2/*scale_bias*/, N4]
-        FLOAT4 scale = FLOAT4(dequantScale[2 * (idx_n4 * cst.block_size + bi) + 0]);
-        FLOAT4 dequant_bias = FLOAT4(dequantScale[2 * (idx_n4 * cst.block_size + bi) + 1]);
+        FLOAT4 scale = FLOAT4(dequantScale[2 * (idx_n4 * cst.block_size + bi) + 0]) / (FLOAT)cst.scale_coef;
+        FLOAT4 dequant_bias = FLOAT4(dequantScale[2 * (idx_n4 * cst.block_size + bi) + 1]) / (FLOAT)cst.scale_coef;
         int zmin = bi * block;
         int zmax = min(zmin + block, cst.input_slice);
 
@@ -220,7 +221,7 @@ kernel void conv1x1_gemm_32x16_w4(const device ftype4 *in            [[buffer(0)
                             constant conv1x1_constants& cst    [[buffer(2)]],
                             const device uchar2 *wt      [[buffer(3)]],
                             const device ftype4 *biasTerms     [[buffer(4)]],
-                            const device float4 *dequantScale  [[buffer(5)]],
+                            const device ftype4 *dequantScale  [[buffer(5)]],
                             uint3 gid                          [[threadgroup_position_in_grid]],
                             uint                  tiitg[[thread_index_in_threadgroup]],
                             uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -258,8 +259,8 @@ kernel void conv1x1_gemm_32x16_w4(const device ftype4 *in            [[buffer(0)
     int block = (cst.input_slice + cst.block_size - 1) / cst.block_size;
     for (int bi=0; bi<cst.block_size; ++bi) {
         // [N/4, cst.block_size, 2/*scale_bias*/, N4]
-        FLOAT4 scale = FLOAT4(dequantScale[2 * (idx_n4 * cst.block_size + bi) + 0]);
-        FLOAT4 dequant_bias = FLOAT4(dequantScale[2 * (idx_n4 * cst.block_size + bi) + 1]);
+        FLOAT4 scale = FLOAT4(dequantScale[2 * (idx_n4 * cst.block_size + bi) + 0]) / (FLOAT)cst.scale_coef;
+        FLOAT4 dequant_bias = FLOAT4(dequantScale[2 * (idx_n4 * cst.block_size + bi) + 1]) / (FLOAT)cst.scale_coef;
         int zmin = bi * block;
         int zmax = min(zmin + block, cst.input_slice);
 
@@ -324,7 +325,7 @@ kernel void conv1x1_gemm_16x32_w4(const device ftype4 *in            [[buffer(0)
                             constant conv1x1_constants& cst    [[buffer(2)]],
                             const device uchar2 *wt      [[buffer(3)]],
                             const device ftype4 *biasTerms     [[buffer(4)]],
-                            const device float4 *dequantScale  [[buffer(5)]],
+                            const device ftype4 *dequantScale  [[buffer(5)]],
                             uint3 gid                          [[threadgroup_position_in_grid]],
                             uint                  tiitg[[thread_index_in_threadgroup]],
                             uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -360,10 +361,10 @@ kernel void conv1x1_gemm_16x32_w4(const device ftype4 *in            [[buffer(0)
     int block = (cst.input_slice + cst.block_size - 1) / cst.block_size;
     for (int bi=0; bi<cst.block_size; ++bi) {
         // [N/4, cst.block_size, 2/*scale_bias*/, N4]
-        FLOAT4 scale0 = FLOAT4(dequantScale[2 * (idx_n40 * cst.block_size + bi) + 0]);
-        FLOAT4 dequant_bias0 = FLOAT4(dequantScale[2 * (idx_n40 * cst.block_size + bi) + 1]);
-        FLOAT4 scale1 = FLOAT4(dequantScale[2 * (idx_n41 * cst.block_size + bi) + 0]);
-        FLOAT4 dequant_bias1 = FLOAT4(dequantScale[2 * (idx_n41 * cst.block_size + bi) + 1]);
+        FLOAT4 scale0 = FLOAT4(dequantScale[2 * (idx_n40 * cst.block_size + bi) + 0]) / (FLOAT)cst.scale_coef;
+        FLOAT4 dequant_bias0 = FLOAT4(dequantScale[2 * (idx_n40 * cst.block_size + bi) + 1]) / (FLOAT)cst.scale_coef;
+        FLOAT4 scale1 = FLOAT4(dequantScale[2 * (idx_n41 * cst.block_size + bi) + 0]) / (FLOAT)cst.scale_coef;
+        FLOAT4 dequant_bias1 = FLOAT4(dequantScale[2 * (idx_n41 * cst.block_size + bi) + 1]) / (FLOAT)cst.scale_coef;
         int zmin = bi * block;
         int zmax = min(zmin + block, cst.input_slice);
 
@@ -434,7 +435,7 @@ kernel void conv1x1_gemm_32x64_w4(const device ftype2 *in            [[buffer(0)
                             constant conv1x1_constants& cst    [[buffer(2)]],
                             const device uchar2 *wt      [[buffer(3)]],
                             const device ftype4 *biasTerms     [[buffer(4)]],
-                            const device float4 *dequantScale  [[buffer(5)]],
+                            const device ftype4 *dequantScale  [[buffer(5)]],
                             uint3 gid                          [[threadgroup_position_in_grid]],
                             uint                  tiitg[[thread_index_in_threadgroup]],
                             uint                  tiisg[[thread_index_in_simdgroup]],
@@ -494,8 +495,8 @@ kernel void conv1x1_gemm_32x64_w4(const device ftype2 *in            [[buffer(0)
     int block = (cst.input_slice + cst.block_size - 1) / cst.block_size;
     for (int bi=0; bi<cst.block_size; ++bi) {
         // [N/4, cst.block_size, 2/*scale_bias*/, N4]
-        FLOAT4 scale0 = FLOAT4(dequantScale[2 * (idx_n40 * cst.block_size + bi) + 0]);
-        FLOAT4 dequant_bias0 = FLOAT4(dequantScale[2 * (idx_n40 * cst.block_size + bi) + 1]);
+        FLOAT4 scale0 = FLOAT4(dequantScale[2 * (idx_n40 * cst.block_size + bi) + 0]) / (FLOAT)cst.scale_coef;
+        FLOAT4 dequant_bias0 = FLOAT4(dequantScale[2 * (idx_n40 * cst.block_size + bi) + 1]) / (FLOAT)cst.scale_coef;
 
         int zmin = bi * block;
         int zmax = min(zmin + block, cst.input_slice);
@@ -566,7 +567,7 @@ kernel void conv1x1_g1z4_w4(const device ftype4 *in            [[buffer(0)]],
                             constant conv1x1_constants& cst    [[buffer(2)]],
                             const device MNN::uchar4x2 *wt      [[buffer(3)]],
                             const device ftype4 *biasTerms     [[buffer(4)]],
-                            const device float4 *dequantScale  [[buffer(5)]],
+                            const device ftype4 *dequantScale  [[buffer(5)]],
                             uint3 gid                          [[thread_position_in_grid]]) {
     if ((int)gid.x * CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;
 
@@ -580,8 +581,8 @@ kernel void conv1x1_g1z4_w4(const device ftype4 *in            [[buffer(0)]],
     int computeSize = min(cst.output_size - rx, CONV_UNROLL);
     int block = (cst.input_slice + cst.block_size - 1) / cst.block_size;
     for (int bi=0; bi<cst.block_size; ++bi) {
-        FLOAT4 scale = FLOAT4(dequantScale[2 * (uz * cst.block_size + bi) + 0]);
-        FLOAT4 dequant_bias = FLOAT4(dequantScale[2 * (uz * cst.block_size + bi) + 1]);
+        FLOAT4 scale = FLOAT4(dequantScale[2 * (uz * cst.block_size + bi) + 0]) / (FLOAT)cst.scale_coef;
+        FLOAT4 dequant_bias = FLOAT4(dequantScale[2 * (uz * cst.block_size + bi) + 1]) / (FLOAT)cst.scale_coef;
         int zmin = bi * block;
         int zmax = min(zmin + block, cst.input_slice);
         for (int z = zmin; z < zmax; z++) {
@@ -621,7 +622,7 @@ kernel void conv1x1_gemv_g8_w4(const device ftype4 *in            [[buffer(0)]],
                             constant conv1x1_constants& cst    [[buffer(2)]],
                             const device MNN::uchar4x2 *wt      [[buffer(3)]],
                             const device ftype4 *biasTerms     [[buffer(4)]],
-                            const device float4 *dequantScale  [[buffer(5)]],
+                            const device ftype4 *dequantScale  [[buffer(5)]],
                             uint3 gid[[threadgroup_position_in_grid]],
                             uint  tiisg[[thread_index_in_simdgroup]],
                             uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -647,8 +648,8 @@ kernel void conv1x1_gemv_g8_w4(const device ftype4 *in            [[buffer(0)]],
     int outer_index  = (tiisg) / middle_step;
     
     for (int bi= outer_index; bi<cst.block_size; bi += outer_step) {
-        FLOAT4 scale = FLOAT4(dequantScale[2 * (uz * cst.block_size + bi) + 0]);
-        FLOAT4 dequant_bias = FLOAT4(dequantScale[2 * (uz * cst.block_size + bi) + 1]);
+        FLOAT4 scale = FLOAT4(dequantScale[2 * (uz * cst.block_size + bi) + 0]) / (FLOAT)cst.scale_coef;
+        FLOAT4 dequant_bias = FLOAT4(dequantScale[2 * (uz * cst.block_size + bi) + 1]) / (FLOAT)cst.scale_coef;
         int zmin = bi * block;
         int zmax = min(zmin + block, cst.input_slice);
         for (int z = zmin + middle_index; z < zmax; z += middle_step) {
@@ -683,7 +684,7 @@ kernel void conv1x1_gemv_g16_w4(const device ftype4 *in            [[buffer(0)]]
                             constant conv1x1_constants& cst    [[buffer(2)]],
                             const device MNN::uchar4x2 *wt      [[buffer(3)]],
                             const device ftype4 *biasTerms     [[buffer(4)]],
-                            const device float4 *dequantScale  [[buffer(5)]],
+                            const device ftype4 *dequantScale  [[buffer(5)]],
                             uint3 gid[[threadgroup_position_in_grid]],
                             uint  tiisg[[thread_index_in_simdgroup]],
                             uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -712,10 +713,10 @@ kernel void conv1x1_gemv_g16_w4(const device ftype4 *in            [[buffer(0)]]
     
     for (int bi= outer_index; bi<cst.block_size; bi += outer_step) {
         const int quant_offset = 2 * (uz * cst.block_size + bi);
-        FLOAT4 scale0 = FLOAT4(dequantScale[quant_offset + 0]);
-        FLOAT4 dequant_bias0 = FLOAT4(dequantScale[quant_offset + 1]);
-        FLOAT4 scale1 = FLOAT4(dequantScale[quant_offset + (cst.block_size << 1)]);
-        FLOAT4 dequant_bias1 = FLOAT4(dequantScale[quant_offset + (cst.block_size << 1) + 1]);
+        FLOAT4 scale0 = FLOAT4(dequantScale[quant_offset + 0]) / (FLOAT)cst.scale_coef;
+        FLOAT4 dequant_bias0 = FLOAT4(dequantScale[quant_offset + 1]) / (FLOAT)cst.scale_coef;
+        FLOAT4 scale1 = FLOAT4(dequantScale[quant_offset + (cst.block_size << 1)]) / (FLOAT)cst.scale_coef;
+        FLOAT4 dequant_bias1 = FLOAT4(dequantScale[quant_offset + (cst.block_size << 1) + 1]) / (FLOAT)cst.scale_coef;
         int zmin = bi * block;
         int zmax = min(zmin + block, cst.input_slice);
         for (int z = zmin + middle_index; z < zmax; z += middle_step) {

+ 3 - 6
source/backend/metal/shader/MetalDeconvolution.metal

@@ -8,7 +8,6 @@ struct deconv_constants {
     int output_height;
     int output_size;
     int output_slice;
-    
     int kernel_x;
     int kernel_y;
     int kernel_size;
@@ -18,12 +17,10 @@ struct deconv_constants {
     int pad_y;
     int dilation_x;
     int dilation_y;
-    
     int delta_ky;
     int delta_kx;
     int delta_iy;
     int delta_ix;
-    int has_bias;
     int batch;
     conv_activation_type activation;
 };
@@ -77,8 +74,8 @@ kernel void deconv_depthwise(const device ftype4 *in        [[buffer(0)]],
                              const device ftype4 *biasTerms [[buffer(4)]],
                              uint3 gid                    [[thread_position_in_grid]]) {
     if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
-    
-    FLOAT4 result = FLOAT4(biasTerms[(int)(gid.z / cst.batch)]);
+    int oz = (int)gid.z / cst.batch;
+    FLOAT4 result = FLOAT4(biasTerms[oz]);
     
     int oy = (int)gid.y + cst.pad_y;
     int ox = (int)gid.x + cst.pad_x;
@@ -95,7 +92,7 @@ kernel void deconv_depthwise(const device ftype4 *in        [[buffer(0)]],
         int min_iy = (oy - max_ky * cst.dilation_y) / cst.stride_y;
         int min_ix = (ox - max_kx * cst.dilation_x) / cst.stride_x;
         
-        auto z_wt = wt + (int)gid.z * cst.kernel_size;
+        auto z_wt = wt + oz * cst.kernel_size;
         auto z_in = in + (int)gid.z * cst.input_size;
         for (auto ky = max_ky, iy = min_iy; ky >= min_ky; ky -= cst.delta_ky, iy += cst.delta_iy) {
             for (auto kx = max_kx, ix = min_ix; kx >= min_kx; kx -= cst.delta_kx, ix += cst.delta_ix) {

+ 74 - 0
source/backend/opencl/core/BufferConvertor.cpp

@@ -574,6 +574,80 @@ bool convertBufferToBuffer(Tensor *input, Tensor *output, OpenCLRuntime *runtime
     return true;
 }
 
+bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCLRuntime *runtime, int memType, bool toDevice, bool toHost) {
+    std::set<std::string> buildOptions;
+    auto srcDimensionFormat = TensorUtils::getDescribe(input)->dimensionFormat;
+    auto dstDimensionFormat = TensorUtils::getDescribe(output)->dimensionFormat;
+    if(runtime->getGpuMemType() == IMAGE){
+        buildOptions.emplace("-DUSE_IMAGE");
+    }
+    
+    buildOptions.emplace("-DINPUT_FORMAT=" + std::to_string(srcDimensionFormat));
+    buildOptions.emplace("-DOUTPUT_FORMAT=" + std::to_string(dstDimensionFormat));
+    std::vector<int> outputShape;
+    std::shared_ptr<KernelWrap> kernelW;
+    if(toDevice){
+        buildOptions.emplace("-DSHARED_TO_CL");
+        kernelW = runtime->buildKernelWithCache("glmem_convert", "gl_to_cl", buildOptions, nullptr, output);
+        outputShape = tensorShapeFormat(output);
+    } else if(toHost){
+        buildOptions.emplace("-DCL_TO_SHARED");
+        kernelW = runtime->buildKernelWithCache("glmem_convert", "cl_to_gl", buildOptions, input, nullptr);
+        outputShape = tensorShapeFormat(input);
+    }else{
+        MNN_PRINT("convertGLMemBetweenCLmem only support toDevice or toHost!\n");
+        return false;
+    }
+    
+    int shape[4] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]}; // N C H W
+    uint32_t gws[3] = {static_cast<uint32_t>(UP_DIV(shape[3], 4)),
+                                  static_cast<uint32_t>(UP_DIV(shape[1], 4)),
+                                  static_cast<uint32_t>(shape[0] * shape[2])};
+    auto Kernel = kernelW->get();
+    uint32_t idx = 0;
+    cl_int ret = CL_SUCCESS;
+    ret |= Kernel.setArg(idx++, gws[0]);
+    ret |= Kernel.setArg(idx++, gws[1]);
+    ret |= Kernel.setArg(idx++, gws[2]);
+    if(toDevice){
+        ret |= Kernel.setArg(idx++, *((CLSharedMemReleaseBuffer*)TensorUtils::getSharedMem(input))->getMem());
+    }else{
+        if(runtime->getGpuMemType() == IMAGE) {
+            ret |= Kernel.setArg(idx++, openCLImage(input));
+        }
+        else {
+            ret |= Kernel.setArg(idx++, openCLBuffer(input));
+        }
+    }
+    if (toHost){
+        ret |= Kernel.setArg(idx++, *((CLSharedMemReleaseBuffer*)TensorUtils::getSharedMem(output))->getMem());
+    }else{
+        if(runtime->getGpuMemType() == IMAGE) {
+            ret |= Kernel.setArg(idx++, openCLImage(output));
+        } else {
+            ret |= Kernel.setArg(idx++, openCLBuffer(output));
+        }
+    }
+    ret |= Kernel.setArg(idx++, sizeof(shape), shape);
+    MNN_CHECK_CL_SUCCESS(ret, "setArg glmem_convert");
+    
+    const uint32_t maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(kernelW));
+    const std::vector<uint32_t> lws = {16, std::max((uint32_t)1, maxWorkGroupSize / 16), 1};
+    cl::Event event;
+    cl_int res;
+    std::vector<uint32_t> roundUpGroupWorkSize(lws.size());
+    for (size_t i = 0; i < lws.size(); ++i) {
+        roundUpGroupWorkSize[i] = ROUND_UP(gws[i], lws[i]);
+    }
+    
+    res = runtime->commandQueue().enqueueNDRangeKernel(Kernel, cl::NullRange,
+                                                       cl::NDRange(roundUpGroupWorkSize[0], roundUpGroupWorkSize[1], roundUpGroupWorkSize[2]),
+                                                       cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
+    event.wait();
+    MNN_CHECK_CL_SUCCESS(res, "glmem_convert");
+    return true;
+}
+
 } // namespace OpenCL
 } // namespace MNN
 #endif /* MNN_OPENCL_BUFFER_CLOSED */
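
convertBetweenAHDandCLmem follows the usual MNN OpenCL dispatch pattern: the global size is rounded up to a multiple of the chosen local size so every workgroup is full, and gws[0..2] are passed as the first kernel arguments so the kernel itself is expected to bounds-check the padding threads. The rounding reduces to (sketch):

    // What ROUND_UP(gws[i], lws[i]) computes above: x rounded up to a multiple of m.
    static unsigned roundUpTo(unsigned x, unsigned m) {
        return (x + m - 1) / m * m;
    }
    // e.g. gws = {37, 5, 12}, lws = {16, 4, 1} -> enqueued range {48, 8, 12}.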

+ 2 - 0
source/backend/opencl/core/BufferConvertor.hpp

@@ -14,6 +14,7 @@
 #include "core/Macro.h"
 #include <MNN/Tensor.hpp>
 #include "backend/opencl/core/OpenCLRunningUtils.hpp"
+#include "backend/opencl/core/OpenCLBackend.hpp"
 
 namespace MNN {
 namespace OpenCL {
@@ -33,6 +34,7 @@ bool convertNC4HW4BufferBetweenNC16HW16Buffer(const Tensor *input, Tensor *outpu
 #endif
 
 bool convertBufferToBuffer(Tensor *input, Tensor *output, OpenCLRuntime *runtime, bool toDevice, bool toHost, bool needWait = false, bool svmFlag = false);
+bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCLRuntime *runtime, int memType, bool toDevice, bool toHost);
                                        
 class BufferConvertor {
 public:

+ 0 - 1
source/backend/opencl/core/BufferPool.cpp

@@ -28,7 +28,6 @@ cl::Buffer* BufferPool::alloc(size_t size, bool separate) {
         return nullptr;
     }
     mAllBuffer.insert(std::make_pair(node->buffer.get(), node));
-
     return node->buffer.get();
 }
 

+ 67 - 85
source/backend/opencl/core/OpenCLBackend.cpp

@@ -333,8 +333,8 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp
     if(mOpenCLRuntime->getGpuMemType() == BUFFER) {
         size_t size;
         float typeSize = getBytes(nativeTensor);
-        if (nativeTensor->dimensions() >= 2) {
-            auto alignC = ROUND_UP(C, 8);
+        if (MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(nativeTensor)->dimensionFormat && nativeTensor->dimensions() >= 2) {
+            auto alignC = ROUND_UP(C, 4);
             // increment of height and width
             auto hR = ROUND_UP(H + 3, 4) - H;
             auto wR = ROUND_UP(W + 3, 4) - W;
@@ -353,7 +353,6 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp
         }
         // Align when int4 memory
         size = ROUND_UP(size, 2);
-        
         if (storageType == DYNAMIC_SEPERATE) {
             auto buffer = mBufferPool->alloc(size*typeSize, true);
             ((Tensor*)nativeTensor)->buffer().device = (uint64_t)buffer;
@@ -593,32 +592,53 @@ bool OpenCLBackend::isCreateError() const {
     return mIsCreateError;
 }
 
-void OpenCLBackend::_allocHostBuffer(int length, const Tensor* srcTensor) const {
+bool OpenCLBackend::_allocHostBuffer(int length, const Tensor* srcTensor) const {
     auto memType = srcTensor->buffer().flags;
-    if (nullptr != mHostBuffer.second && length <= mHostBuffer.first && memType != MNN_FORWARD_OPENCL && memType != MNN_FORWARD_OPENGL) {
-        return;
-    }
-    if(memType == MNN_FORWARD_OPENCL){
-        mDeviceBuffer = (cl::Buffer*)srcTensor->buffer().device;
+    if (nullptr != mHostBuffer.second && length <= mHostBuffer.first && memType != MNN_MEMORY_AHARDWAREBUFFER) {
+        return true;
     }
+    cl_int error;
 #ifdef  __ANDROID__
-    else if(memType == MNN_FORWARD_OPENGL && mOpenCLRuntime->isSupportGL()){
-        cl_int error;
-        mDeviceTexture.reset(new cl::ImageGL(mOpenCLRuntime->context(), CL_MEM_READ_WRITE, GL_TEXTURE_2D, 0, (cl_GLuint)srcTensor->buffer().device, &error));
-        std::vector<cl::Memory> map = {*mDeviceTexture.get()};
-        mOpenCLRuntime->commandQueue().enqueueAcquireGLObjects(&map, NULL);
-    }
+    if(MNN_MEMORY_AHARDWAREBUFFER == memType){
+        if (mOpenCLRuntime->isSupportAHD()){
+            CLSharedMemReleaseBuffer *sharedMem = (CLSharedMemReleaseBuffer*)TensorUtils::getSharedMem(srcTensor);
+            if(sharedMem == nullptr || (sharedMem != nullptr && srcTensor->buffer().device != sharedMem->getSharedId())){
+                if(mOpenCLRuntime->getGpuType() == MALI){
+                    const cl_import_properties_arm properties[] = {CL_IMPORT_TYPE_ARM, CL_IMPORT_TYPE_ANDROID_HARDWARE_BUFFER_ARM, 0};
+                    Backend::MemObj* SharedTmp = new CLSharedMemReleaseBuffer(srcTensor->buffer().device, new cl::Buffer(mOpenCLRuntime->context(), (cl_mem_flags)CL_MEM_READ_WRITE, properties, (void*)srcTensor->buffer().device, CL_IMPORT_MEMORY_WHOLE_ALLOCATION_ARM, &error));
+                    TensorUtils::setSharedMem(srcTensor, SharedTmp);
+                }else if(mOpenCLRuntime->getGpuType() == ADRENO){
+                    cl_mem_ahardwarebuffer_host_ptr myAHBmem = {0};
+                    myAHBmem.ext_host_ptr.allocation_type = CL_MEM_ANDROID_AHARDWAREBUFFER_HOST_PTR_QCOM;
+                    myAHBmem.ext_host_ptr.host_cache_policy = CL_MEM_HOST_WRITEBACK_QCOM;
+                    myAHBmem.ahb_ptr = (AHardwareBuffer*)srcTensor->buffer().device;
+                    Backend::MemObj* SharedTmp = new CLSharedMemReleaseBuffer(srcTensor->buffer().device, new cl::Buffer(mOpenCLRuntime->context(), (cl_mem_flags)(CL_MEM_USE_HOST_PTR | CL_MEM_EXT_HOST_PTR_QCOM), 0, &myAHBmem, &error));
+                    TensorUtils::setSharedMem(srcTensor, SharedTmp);
+                } else{
+                    MNN_ERROR("This device not support AHardWareBuffer\n");
+                    return false;
+                }
+                if (error != CL_SUCCESS) {
+                    MNN_ERROR("Alloc mAHardWareBuffer error, code:%d \n", error);
+                    return false;
+                }
+            }
+        } else{
+            MNN_ERROR("This device not support AHardWareBuffer\n");
+            return false;
+        }
+    } else
 #endif
-    else{
+    {
         MNN_ASSERT(length > 0);
-        cl_int res;
         mHostBuffer.first = length;
-        mHostBuffer.second.reset(new cl::Buffer(mOpenCLRuntime->context(), (cl_mem_flags)(CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR), (size_t)length, NULL, &res));
-        if (nullptr == mHostBuffer.second.get() || res != CL_SUCCESS) {
-            MNN_ERROR("Alloc mHostBuffer %d error, code:%d \n", length, res);
-            return;
+        mHostBuffer.second.reset(new cl::Buffer(mOpenCLRuntime->context(), (cl_mem_flags)(CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR), (size_t)length, NULL, &error));
+        if (nullptr == mHostBuffer.second.get() || error != CL_SUCCESS) {
+            MNN_ERROR("Alloc mHostBuffer %d error, code:%d \n", length, error);
+            return false;
         }
     }
+    return true;
 }
 
 void OpenCLBackend::copyFromDeviceInt8(const Tensor* srcTensor, const Tensor* dstTensor) const{
@@ -674,15 +694,15 @@ int OpenCLBackend::onSync(Tensor::MapType mtype, bool toCpu, const Tensor* dstTe
 }
 
 void CLRuntime::convertFromDevice(const Tensor* srcTensor, const Tensor* dstTensor, MNN_DATA_FORMAT data_format, bool svmFlag, int memtype) const {
+#ifdef  __ANDROID__
+    if(MNN_MEMORY_AHARDWAREBUFFER == memtype){
+        convertBetweenAHDandCLmem(const_cast<Tensor*>(srcTensor), const_cast<Tensor*>(dstTensor), mOpenCLRuntime.get(), memtype, false, true);
+        return;
+    }
+#endif
 #ifndef MNN_OPENCL_BUFFER_CLOSED
     if(mOpenCLRuntime->getGpuMemType() == BUFFER)
     {
-        if(MNN_FORWARD_OPENGL == memtype && mOpenCLRuntime->isSupportGL()){
-            OpenCL::convertNC4HW4BufferToImage(srcTensor, const_cast<Tensor*>(dstTensor), mOpenCLRuntime.get(), false, svmFlag);
-            std::vector<cl::Memory> map = {openCLImage(dstTensor)};
-            mOpenCLRuntime->commandQueue().enqueueReleaseGLObjects(&map, NULL);
-            return;
-        }
 #ifdef MNN_SUPPORT_INTEL_SUBGROUP
         int cPack = TensorUtils::getTensorChannelPack(srcTensor);
         if (cPack == 16 && mOpenCLRuntime->isSupportedIntelSubgroup()) {
@@ -710,17 +730,6 @@ void CLRuntime::convertFromDevice(const Tensor* srcTensor, const Tensor* dstTens
     else
 #endif /* MNN_OPENCL_BUFFER_CLOSED */
     {
-        if(MNN_FORWARD_OPENGL == memtype && mOpenCLRuntime->isSupportGL()){
-            std::vector<int> bufferShape = MNN::OpenCL::tensorShapeFormat(srcTensor);
-
-            mOpenCLRuntime.get()->commandQueue().enqueueCopyImage(
-                    openCLImage(srcTensor), openCLImage(dstTensor),
-                    {0, 0, 0}, {0, 0, 0},
-                    {(size_t)bufferShape[2]* UP_DIV(bufferShape[3], 4), (size_t)bufferShape[0]*bufferShape[1], 1});
-            std::vector<cl::Memory> map = {openCLImage(dstTensor)};
-            mOpenCLRuntime->commandQueue().enqueueReleaseGLObjects(&map, NULL);
-            return;
-        }
         switch (data_format) {
             case MNN_DATA_FORMAT_NHWC:
                 OpenCL::convertImageToNHWCBuffer(srcTensor, const_cast<Tensor*>(dstTensor), mOpenCLRuntime.get(), false, svmFlag);
@@ -748,8 +757,7 @@ void OpenCLBackend::copyFromDevice(const Tensor* srcTensor, const Tensor* dstTen
                        && (srcDimensionFormat == dstDimensionFormat || srcTensor->dimensions() <= 1)
                        && MNN::MNN_DATA_FORMAT_NC4HW4 != dstDimensionFormat && MNN_DATA_FORMAT_NC4HW4 != srcDimensionFormat
                        && (getDataType(srcTensor) == getDataType(dstTensor))
-                       && memType != MNN_FORWARD_OPENCL 
-                       && memType != MNN_FORWARD_OPENGL;
+                       && memType != MNN_MEMORY_AHARDWAREBUFFER;
     if (mOpenCLRuntime->isSupportedFP16()) { // Fp16
         if (dstTensor->getType().code == halide_type_float) {
             directCopy = false;
@@ -792,15 +800,15 @@ void OpenCLBackend::copyFromDevice(const Tensor* srcTensor, const Tensor* dstTen
 
 void CLRuntime::convertToDevice(const Tensor* srcTensor, const Tensor* dstTensor, MNN_DATA_FORMAT data_format, bool svmFlag, int memtype) const {
     // Format: Host -> OpenCL
+#ifdef  __ANDROID__
+    if(MNN_MEMORY_AHARDWAREBUFFER == memtype){
+        convertBetweenAHDandCLmem(const_cast<Tensor*>(srcTensor), const_cast<Tensor*>(dstTensor), mOpenCLRuntime.get(), memtype, true, false);
+        return;
+    }
+#endif
     #ifndef MNN_OPENCL_BUFFER_CLOSED
     if(mOpenCLRuntime->getGpuMemType() == BUFFER)
     {
-        if(MNN_FORWARD_OPENGL == memtype && mOpenCLRuntime->isSupportGL()){
-            OpenCL::convertImageToNC4HW4Buffer(srcTensor, const_cast<Tensor*>(dstTensor),mOpenCLRuntime.get(), false, svmFlag);
-            std::vector<cl::Memory> map = {openCLImage(srcTensor)};
-            mOpenCLRuntime->commandQueue().enqueueReleaseGLObjects(&map, NULL);
-            return;
-        }
 #ifdef MNN_SUPPORT_INTEL_SUBGROUP
         int cPack = TensorUtils::getTensorChannelPack(dstTensor);
         if (cPack == 16 && mOpenCLRuntime->isSupportedIntelSubgroup()) {
@@ -821,17 +829,6 @@ void CLRuntime::convertToDevice(const Tensor* srcTensor, const Tensor* dstTensor
     else
     #endif /* MNN_OPENCL_BUFFER_CLOSED */
     {
-        if(MNN_FORWARD_OPENGL == memtype && mOpenCLRuntime->isSupportGL()){
-            std::vector<int> bufferShape = MNN::OpenCL::tensorShapeFormat(dstTensor);
-
-            mOpenCLRuntime.get()->commandQueue().enqueueCopyImage(
-                    openCLImage(srcTensor), openCLImage(dstTensor),
-                    {0, 0, 0}, {0, 0, 0},
-                    {(size_t)bufferShape[2]* UP_DIV(bufferShape[3], 4), (size_t)bufferShape[0]*bufferShape[1], 1});
-            std::vector<cl::Memory> map = {openCLImage(srcTensor)};
-            mOpenCLRuntime->commandQueue().enqueueReleaseGLObjects(&map, NULL);
-            return;
-        }
         if (MNN_DATA_FORMAT_NHWC == data_format) {
             OpenCL::convertNHWCBufferToImage(srcTensor, const_cast<Tensor*>(dstTensor), mOpenCLRuntime.get(), false, svmFlag);
         } else if (MNN_DATA_FORMAT_NCHW == data_format) {
@@ -868,8 +865,7 @@ void OpenCLBackend::copyToDevice(const Tensor* srcTensor, const Tensor* dstTenso
                        && (srcDimensionFormat == dstDimensionFormat || srcTensor->dimensions() <= 1)
                        && MNN_DATA_FORMAT_NC4HW4 != dstDimensionFormat && MNN_DATA_FORMAT_NC4HW4 != srcDimensionFormat
                        && (getDataType(srcTensor) == getDataType(dstTensor))
-                       && memType != MNN_FORWARD_OPENCL
-                       && memType != MNN_FORWARD_OPENGL;
+                       && memType != MNN_MEMORY_AHARDWAREBUFFER;
     if (mOpenCLRuntime->isSupportedFP16()) { // Fp16
         if (dstTensor->getType().code == halide_type_float) {
             directCopy = false;
@@ -901,15 +897,13 @@ void OpenCLBackend::copyToDevice(const Tensor* srcTensor, const Tensor* dstTenso
     #else
     auto res = mOpenCLRuntime->commandQueue().enqueueWriteBuffer(*mHostBuffer.second, CL_TRUE, 0, needSize, hostPtr);
     if(res != CL_SUCCESS) {
-	MNN_ERROR("OpenCL enqueue write error:%d\n", res);
-	return;
+        MNN_ERROR("OpenCL enqueue write error:%d\n", res);
+        return;
     }
     #endif
 
     //Covert format
     mCLRuntime->convertToDevice((const Tensor*)&interTensor, dstTensor, srcDimensionFormat, false);
-
-    return;
 }
 
 void OpenCLBackend::copyBetweenDevice(const Tensor* srcTensor, const Tensor* dstTensor) const{
@@ -918,33 +912,21 @@ void OpenCLBackend::copyBetweenDevice(const Tensor* srcTensor, const Tensor* dst
     if(MNN_FORWARD_CPU == srcMemtype && MNN_FORWARD_CPU == dstMemtype){
         mCLRuntime->copyBetweenDevice(srcTensor, dstTensor);
     } else {
-        const Tensor* copyTensor = MNN_FORWARD_CPU != srcMemtype ? srcTensor : dstTensor;
-        MNN_DATA_FORMAT data_format = TensorUtils::getDescribe(copyTensor)->dimensionFormat;
-        int memType = MNN_FORWARD_CPU != srcMemtype ? srcMemtype : dstMemtype;
-        if(MNN_FORWARD_OPENCL != memType && MNN_FORWARD_OPENGL != memType){
-            MNN_PRINT("Unsupport ForwardType %d for OpenCL backend!\n", memType);
-            return;
-        }
-        if(mOpenCLRuntime->isSupportGL() && MNN_FORWARD_OPENGL == memType){
-            MNN_PRINT("This Device can not find OpenCL GL_EXTENTION function!\n");
+        const Tensor* hostTensor = MNN_FORWARD_CPU != srcMemtype ? srcTensor : dstTensor;
+        const Tensor* deviceTensor = MNN_FORWARD_CPU == srcMemtype ? srcTensor : dstTensor;
+        MNN_DATA_FORMAT data_format = TensorUtils::getDescribe(deviceTensor)->dimensionFormat;
+        
+        bool allocSuccess = _allocHostBuffer(0, hostTensor);
+        if(false == allocSuccess){
+            MNN_ERROR("_allocHostBuffer failed\n");
             return;
         }
-        _allocHostBuffer(0, copyTensor);
-
-        MNN::Tensor interTensor(copyTensor, copyTensor->getDimensionType(), false);
-        TensorUtils::getDescribe(&interTensor)->dimensionFormat = data_format;
-        if(MNN_FORWARD_OPENCL == memType ){
-            interTensor.buffer().device = (uint64_t)mDeviceBuffer;
-        }else if(MNN_FORWARD_OPENGL == memType){
-            interTensor.buffer().device = (uint64_t)mDeviceTexture.get();
-        }else{
-            interTensor.buffer().device = (uint64_t)mHostBuffer.second.get();
-        }
+        
         //Covert format
         if(MNN_FORWARD_CPU != srcMemtype){
-            mCLRuntime->convertToDevice((const Tensor*)&interTensor, dstTensor, data_format, false, srcMemtype);
+            mCLRuntime->convertToDevice(hostTensor, deviceTensor, data_format, false, srcMemtype);
         }else{
-            mCLRuntime->convertFromDevice(srcTensor, (const Tensor*)&interTensor, data_format, false, dstMemtype);
+            mCLRuntime->convertFromDevice(deviceTensor, hostTensor, data_format, false, dstMemtype);
         }
     }
 }

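To make the Adreno branch above easier to follow, here is a self-contained sketch of the import it performs, assuming a valid cl::Context and an AHardwareBuffer* stored in the tensor's device field; the helper name is illustrative, not part of the patch:

    #include <CL/cl2.hpp>
    #include <CL/cl_ext_qcom.h>
    #include <android/hardware_buffer.h>

    // Wrap an existing AHardwareBuffer as a zero-copy cl::Buffer via
    // CL_MEM_EXT_HOST_PTR_QCOM. Size is 0, following the patch's own usage;
    // the extension takes the size from the AHardwareBuffer itself.
    cl::Buffer importAHB(cl::Context& context, AHardwareBuffer* ahb, cl_int* err) {
        cl_mem_ahardwarebuffer_host_ptr ahbMem = {0};
        ahbMem.ext_host_ptr.allocation_type   = CL_MEM_ANDROID_AHARDWAREBUFFER_HOST_PTR_QCOM;
        ahbMem.ext_host_ptr.host_cache_policy = CL_MEM_HOST_WRITEBACK_QCOM;
        ahbMem.ahb_ptr = ahb;
        return cl::Buffer(context,
                          (cl_mem_flags)(CL_MEM_USE_HOST_PTR | CL_MEM_EXT_HOST_PTR_QCOM),
                          0, &ahbMem, err);
    }
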
+ 21 - 3
source/backend/opencl/core/OpenCLBackend.hpp

@@ -153,7 +153,7 @@ private:
     void copyToDeviceInt8(const Tensor* srcTensor, const Tensor* dstTensor) const;
     void copyBetweenDevice(const Tensor* srcTensor, const Tensor* dstTensor) const;
 
-    void _allocHostBuffer(int length, const Tensor* srcTensor) const;
+    bool _allocHostBuffer(int length, const Tensor* srcTensor) const;
 
     const CLRuntime* mCLRuntime;
 
@@ -171,8 +171,6 @@ private:
     std::shared_ptr<OpenCLRuntime> mOpenCLRuntime;
 
     mutable std::pair<int, std::shared_ptr<cl::Buffer>> mHostBuffer;
-    mutable cl::Buffer *mDeviceBuffer = nullptr;
-    mutable std::shared_ptr<cl::Image> mDeviceTexture;
     BackendConfig::PrecisionMode mPrecision;
     BackendConfig::MemoryMode mMemory;
     bool mIsCreateError{false};
@@ -233,6 +231,26 @@ public:
     }
 };
 
+class CLSharedMemReleaseBuffer : public Backend::MemObj {
+public:
+    CLSharedMemReleaseBuffer(uint64_t sharedId, cl::Buffer *bId) {
+        mSharedId = sharedId;
+        mBuffer = bId;
+    }
+    virtual ~ CLSharedMemReleaseBuffer() {
+        delete mBuffer;
+    }
+    uint64_t getSharedId(){
+        return mSharedId;
+    }
+    cl::Buffer *getMem(){
+        return mBuffer;
+    }
+private:
+    uint64_t mSharedId;
+    cl::Buffer *mBuffer;
+};
+
 } // namespace OpenCL
 } // namespace MNN
 #endif  /* OpenCLBackend_hpp */

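A standalone analogue of the new CLSharedMemReleaseBuffer, to make the ownership split explicit (a sketch, not the class itself): the wrapper owns and deletes the cl::Buffer, while the shared id (the AHardwareBuffer handle) is only recorded and remains owned by the caller:

    #include <CL/cl2.hpp>
    #include <cstdint>

    // Ownership contract sketch: the CL wrapper is owned and released with
    // this object; the shared handle is borrowed and never freed here.
    struct SharedMemRelease {
        uint64_t    sharedId; // AHardwareBuffer handle, caller-owned
        cl::Buffer* buffer;   // owned; deleted when the MemObj is released
        ~SharedMemRelease() { delete buffer; }
    };
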
+ 32 - 51
source/backend/opencl/core/runtime/OpenCLRuntime.cpp

@@ -159,62 +159,43 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const
             }
             const std::string extensions = platforms[0].getInfo<CL_PLATFORM_EXTENSIONS>();
             bool isPriorityHint = (extensions.find("cl_khr_priority_hints") != std::string::npos);
-
+            std::vector<cl_context_properties> context_properties;
+            if(mGpuType == ADRENO && !isPriorityHint){
+                context_properties.push_back(CL_CONTEXT_PERF_HINT_QCOM);
+                context_properties.push_back(CL_PERF_HINT_HIGH_QCOM);
+                context_properties.push_back(CL_CONTEXT_PRIORITY_HINT_QCOM);
+                context_properties.push_back(CL_PRIORITY_HINT_LOW_QCOM);
+                mIsDeviceSupportedLowPower = true;
+            }
+            #ifdef ARM_OPENCL_PRINTF_DEBUG
+            context_properties.push_back(CL_PRINTF_CALLBACK_ARM);
+            context_properties.push_back((cl_context_properties)callback);
+            context_properties.push_back(CL_PRINTF_BUFFERSIZE_ARM);
+            context_properties.push_back(0x1000);
+            #endif
+            std::string deviceextensions = mFirstGPUDevicePtr.get()->getInfo<CL_DEVICE_EXTENSIONS>();
+#ifdef MNN_USE_LIB_WRAPPER
+            mIsSupportAHD = (getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_arm_import_memory_android_hardware_buffer")
+                 && mGpuType == MALI && OpenCLSymbolsOperator::getOpenclSymbolsPtr()->getFuncAddress(platforms[platformId](), "clImportMemoryARM"))
+                 || (mGpuType == ADRENO && getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_qcom_android_ahardwarebuffer_host_ptr"));
+#endif
             if(nullptr != contextPtr){
-                if(nullptr != glShared && getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_khr_gl_sharing")){
-                    std::vector<cl_context_properties> context_properties;
-                    context_properties.reserve(7);
-                    context_properties.push_back(CL_GL_CONTEXT_KHR);
-                    context_properties.push_back((cl_context_properties)contextPtr);
-                    context_properties.push_back(CL_EGL_DISPLAY_KHR);
-                    context_properties.push_back((cl_context_properties)glShared);
-                    context_properties.push_back(CL_CONTEXT_PLATFORM);
-                    context_properties.push_back((cl_context_properties)platforms[platformId]());
-                    context_properties.push_back(0);
-                    mContext = std::shared_ptr<cl::Context>(new cl::Context(std::vector<cl::Device>({*mFirstGPUDevicePtr}), context_properties.data(), nullptr, nullptr, &res));
-                }
-                else{
-                    mContext = std::shared_ptr<cl::Context>((cl::Context*)contextPtr, [](void* ptr) {
-                        // Do nothing
-                    });
-                }
+                mContext = std::shared_ptr<cl::Context>((cl::Context*)contextPtr, [](void* ptr) {
+                    // Do nothing
+                });
             }else{
-                if(mGpuType == ADRENO && !isPriorityHint){
-                    std::vector<cl_context_properties> context_properties;
-                    context_properties.reserve(5);
-                    context_properties.push_back(CL_CONTEXT_PERF_HINT_QCOM);
-                    context_properties.push_back(CL_PERF_HINT_HIGH_QCOM);
-                    context_properties.push_back(CL_CONTEXT_PRIORITY_HINT_QCOM);
-                    context_properties.push_back(CL_PRIORITY_HINT_LOW_QCOM);
-                    context_properties.push_back(0);
-                    mContext = std::shared_ptr<cl::Context>(new cl::Context(std::vector<cl::Device>({*mFirstGPUDevicePtr}), context_properties.data(), nullptr, nullptr, &res));
-                    mIsDeviceSupportedLowPower = true;
-                }else{
-                    #ifdef ARM_OPENCL_PRINTF_DEBUG
-                    cl_context_properties context_properties[] =
-                    {
-                        CL_CONTEXT_PLATFORM, (cl_context_properties)platforms[platformId](),
-                        CL_PRINTF_CALLBACK_ARM, (cl_context_properties)callback,
-                        CL_PRINTF_BUFFERSIZE_ARM, 0x1000,
-                        0
-                    };
-                    mContext = std::shared_ptr<cl::Context>(new cl::Context(std::vector<cl::Device>({*mFirstGPUDevicePtr}), context_properties, nullptr, nullptr, &res));
-                    #else
-                    mContext = std::shared_ptr<cl::Context>(new cl::Context(std::vector<cl::Device>({*mFirstGPUDevicePtr}), nullptr, nullptr, nullptr, &res));
-                    #endif
-                }
-                
-                MNN_CHECK_CL_SUCCESS(res, "context");
-                if (res != CL_SUCCESS) {
-                    mIsCreateError = true;
-                    return;
-                }
+                context_properties.push_back(0);
+                mContext = std::shared_ptr<cl::Context>(new cl::Context(std::vector<cl::Device>({*mFirstGPUDevicePtr}), context_properties.data(), nullptr, nullptr, &res));
+            }
+            MNN_CHECK_CL_SUCCESS(res, "context");
+            if (res != CL_SUCCESS) {
+                mIsCreateError = true;
+                return;
             }
             
             mIsDeviceSupportedLowPower = (mIsDeviceSupportedLowPower || isPriorityHint);
             
             #ifdef MNN_USE_LIB_WRAPPER
-            mIsSupportGL = !OpenCLSymbolsOperator::getOpenclSymbolsPtr()->isGlError();
             if(isPriorityHint)
             {
                 if(true == OpenCLSymbolsOperator::getOpenclSymbolsPtr()->isPropError())
@@ -646,7 +627,7 @@ std::shared_ptr<KernelWrap> OpenCLRuntime::buildKernelWithCache(const std::strin
                 buildOptionsStr += " -DCONVERT_OUTPUT16=convert_int16";
                 buildOptionsStr += " -DWI_DATA=write_imagei";
             } else {
-                MNN_PRINT("opencl input datatype not support, bit:%d\n", output->getType().bits);
+                MNN_PRINT("opencl output datatype not support, bit:%d\n", output->getType().bits);
                 MNN_ASSERT(false);
             }
         } else if(output->getType().code == halide_type_uint){
@@ -668,7 +649,7 @@ std::shared_ptr<KernelWrap> OpenCLRuntime::buildKernelWithCache(const std::strin
                 buildOptionsStr += " -DCONVERT_OUTPUT16=convert_uint16";
                 buildOptionsStr += " -DWI_DATA=write_imageui";
             } else {
-                MNN_PRINT("opencl input datatype not support, bit:%d\n", output->getType().bits);
+                MNN_PRINT("opencl output datatype not support, bit:%d\n", output->getType().bits);
                 MNN_ASSERT(false);
             }
         } else {

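The constructor refactor above folds every case into a single context_properties vector. A minimal sketch of the Adreno low-priority list it produces (constants from cl_ext_qcom.h; the ARM printf branch is compile-time only and omitted here):

    #include <vector>
    #include <CL/cl.h>
    #include <CL/cl_ext_qcom.h>

    // Assemble the Qualcomm perf/priority hints the refactored constructor
    // builds; the list must be zero-terminated before context creation.
    std::vector<cl_context_properties> buildAdrenoProps() {
        std::vector<cl_context_properties> props;
        props.push_back(CL_CONTEXT_PERF_HINT_QCOM);
        props.push_back(CL_PERF_HINT_HIGH_QCOM);
        props.push_back(CL_CONTEXT_PRIORITY_HINT_QCOM);
        props.push_back(CL_PRIORITY_HINT_LOW_QCOM);
        props.push_back(0); // terminator
        return props;
    }
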
+ 4 - 4
source/backend/opencl/core/runtime/OpenCLRuntime.hpp

@@ -110,9 +110,9 @@ public:
         return mCLVersion;
     }
 	uint32_t getPrecisionLevel() const;
-    bool isSupportGL(){
-    	return mIsSupportGL;
-	}
+    bool isSupportAHD(){
+        return mIsSupportAHD;
+    }
 #ifdef MNN_OPENCL_SVM_ENABLE
     cl_device_svm_capabilities getSvmCapabilities() {
         return mSvmCapabilities;
@@ -215,7 +215,7 @@ private:
     bool mSupportDotInt8 = false;
     bool mSupportDotAccInt8 = false;
     bool mSupportedIntelSubgroup = false;
-    bool mIsSupportGL = true;
+    bool mIsSupportAHD = false;
     GpuType mGpuType;
     MaliAr mMaliAr;
     float mCLVersion = 1.0f;

+ 36 - 60
source/backend/opencl/core/runtime/OpenCLWrapper.cpp

@@ -121,12 +121,24 @@ bool OpenCLSymbols::isPropError() {
 bool OpenCLSymbols::isQcomError() {
     return mQcomError;
 }
-
-bool OpenCLSymbols::isGlError() {
-    return mGlError;
+    
+bool OpenCLSymbols::getFuncAddress(cl_platform_id platform, const char *func_name){
+    if(clGetExtensionFunctionAddressForPlatform != nullptr){
+        clImportMemoryARM = reinterpret_cast<clImportMemoryARMFunc>(clGetExtensionFunctionAddressForPlatform(platform, "clImportMemoryARM"));
+        if(clImportMemoryARM == nullptr){
+            return false;
+        }
+    }else if(clGetExtensionFunctionAddress != nullptr){
+        clImportMemoryARM = reinterpret_cast<clImportMemoryARMFunc>(clGetExtensionFunctionAddress("clImportMemoryARM"));
+        if(clImportMemoryARM == nullptr){
+            return false;
+        }
+    } else{
+        return false;
+    }
+    return true;
 }
 
-
 bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
 #if defined(_WIN32)
     handle_ = LoadLibraryA(library_path.c_str());
@@ -203,15 +215,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
     if(func_name == nullptr){ \
         mQcomError = true; \
     }
-
-#define MNN_LOAD_GL_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(dlsym(handle_, #func_name)); \
-    if(func_name == nullptr && loadOpenCLPointer != nullptr){ \
-        func_name = reinterpret_cast<func_name##Func>(loadOpenCLPointer(#func_name)); \
-    } \
-    if(func_name == nullptr){ \
-        mGlError = true; \
-    }
-
+    
 #endif
 
     MNN_LOAD_FUNCTION_PTR(clGetPlatformIDs);
@@ -261,10 +265,8 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
     MNN_LOAD_FUNCTION_PTR(clEnqueueCopyImage);
     MNN_LOAD_FUNCTION_PTR(clEnqueueReadImage);
     MNN_LOAD_FUNCTION_PTR(clEnqueueWriteImage);
-    MNN_LOAD_GL_PTR(clCreateFromGLBuffer);
-    MNN_LOAD_GL_PTR(clCreateFromGLTexture);
-    MNN_LOAD_GL_PTR(clEnqueueAcquireGLObjects);
-    MNN_LOAD_GL_PTR(clEnqueueReleaseGLObjects);
+    MNN_LOAD_FUNCTION_PTR(clGetExtensionFunctionAddress);
+    MNN_LOAD_FUNCTION_PTR(clGetExtensionFunctionAddressForPlatform);
 
     MNN_LOAD_PROP_PTR(clCreateCommandQueueWithProperties);
     MNN_LOAD_SVM_PTR(clSVMAlloc);
@@ -671,49 +673,6 @@ cl_int CL_API_CALL clEnqueueCopyImage(cl_command_queue queue,
     return func(queue, src_image, dst_image, src_origin, dst_origin, region, num_events_in_wait_list, event_wait_list, event);
 }
 
-cl_mem CL_API_CALL clCreateFromGLBuffer(cl_context context,
-                                        cl_mem_flags flags,
-                                        cl_GLuint bufobj,
-                                        int *errcode_ret){
-    auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clCreateFromGLBuffer;
-    MNN_CHECK_NOTNULL(func);
-    return func(context, flags, bufobj, errcode_ret);
-}
-
-cl_mem CL_API_CALL clCreateFromGLTexture(cl_context context,
-                                         cl_mem_flags flags,
-                                         cl_GLenum target,
-                                         cl_GLint miplevel,
-                                         cl_GLuint texture,
-                                         cl_int *errcode_ret){
-    auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clCreateFromGLTexture;
-    MNN_CHECK_NOTNULL(func);
-    return func(context, flags, target, miplevel, texture, errcode_ret);
-
-}
-
-cl_int CL_API_CALL clEnqueueAcquireGLObjects(cl_command_queue command_queue,
-                                             cl_uint num_objects,
-                                             const cl_mem *mem_objects,
-                                             cl_uint num_events_in_wait_list,
-                                             const cl_event *event_wait_list,
-                                             cl_event *event){
-    auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clEnqueueAcquireGLObjects;
-    MNN_CHECK_NOTNULL(func);
-    return func(command_queue, num_objects, mem_objects, num_events_in_wait_list, event_wait_list, event);
-}
-
-cl_int CL_API_CALL clEnqueueReleaseGLObjects(cl_command_queue command_queue,
-                                             cl_uint num_objects,
-                                             const cl_mem *mem_objects,
-                                             cl_uint num_events_in_wait_list,
-                                             const cl_event *event_wait_list,
-                                             cl_event *event){
-    auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clEnqueueReleaseGLObjects;
-    MNN_CHECK_NOTNULL(func);
-    return func(command_queue, num_objects, mem_objects, num_events_in_wait_list, event_wait_list, event);
-}
-
 // clCreateCommandQueueWithProperties wrapper
 cl_command_queue CL_API_CALL clCreateCommandQueueWithProperties(cl_context context, cl_device_id device, const cl_queue_properties *properties, cl_int *errcode_ret) {
     auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clCreateCommandQueueWithProperties;
@@ -799,5 +758,22 @@ clEnqueueRecordingSVMQCOM(cl_command_queue command_queue, cl_recording_qcom reco
     return func(command_queue, recording, num_args, arg_array, num_svm_args, arg_svm_array, num_global_offsets, global_offset_array, num_global_workgroups, global_workgroup_array, num_local_workgroups, local_workgroups_array, num_non_arg_objs, non_arg_obj_array, num_events_in_wait_list, event_wait_list, event);
 }
 
+void * CL_API_CALL clGetExtensionFunctionAddress(const char *func_name){
+    auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clGetExtensionFunctionAddress;
+    MNN_CHECK_NOTNULL(func);
+    return func(func_name);
+}
+
+void * CL_API_CALL clGetExtensionFunctionAddressForPlatform(cl_platform_id platform, const char *func_name){
+    auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clGetExtensionFunctionAddressForPlatform;
+    MNN_CHECK_NOTNULL(func);
+    return func(platform, func_name);
+}
+
+cl_mem CL_API_CALL clImportMemoryARM(cl_context context, cl_mem_flags flags, const cl_import_properties_arm *properties, void *memory, size_t size, cl_int *errcode_ret){
+    auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clImportMemoryARM;
+    MNN_CHECK_NOTNULL(func);
+    return func(context, flags, properties, memory, size, errcode_ret);
+}
 
 #endif //MNN_USE_LIB_WRAPPER

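A minimal sketch of the resolution order the new getFuncAddress implements: the platform-scoped entry point first, then the legacy global one (deprecated since OpenCL 1.2 but still present on older drivers). The function-pointer type mirrors the wrapper's typedef:

    #include <CL/cl.h>
    #include <CL/cl_ext.h>

    using clImportMemoryARMFunc =
        cl_mem (CL_API_CALL *)(cl_context, cl_mem_flags,
                               const cl_import_properties_arm*, void*, size_t, cl_int*);

    // Returns nullptr when neither loader entry point supplies the symbol.
    clImportMemoryARMFunc resolveImportMemoryARM(cl_platform_id platform) {
        if (void* p = clGetExtensionFunctionAddressForPlatform(platform, "clImportMemoryARM")) {
            return reinterpret_cast<clImportMemoryARMFunc>(p);
        }
        return reinterpret_cast<clImportMemoryARMFunc>(
            clGetExtensionFunctionAddress("clImportMemoryARM"));
    }
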
+ 12 - 10
source/backend/opencl/core/runtime/OpenCLWrapper.hpp

@@ -31,6 +31,10 @@
 #endif
 
 #include "CL/cl_ext_qcom.h"
+#include "CL/cl_ext.h"
+#ifdef __ANDROID__
+#include <android/hardware_buffer.h>
+#endif
 
 #define MNN_CHECK_NOTNULL(X) MNN_ASSERT(X != NULL)
 
@@ -53,7 +57,7 @@ public:
     bool isSvmError();
     bool isPropError();
     bool isQcomError();
-    bool isGlError();
+    bool getFuncAddress(cl_platform_id platform, const char *func_name);
     
     using clGetPlatformIDsFunc        = cl_int (CL_API_CALL *)(cl_uint, cl_platform_id *, cl_uint *);
     using clGetPlatformInfoFunc       = cl_int (CL_API_CALL *)(cl_platform_id, cl_platform_info, size_t, void *, size_t *);
@@ -148,10 +152,6 @@ public:
                                                    size_t param_value_size, void *param_value,
                                                    size_t *param_value_size_ret);
     using clGetImageInfoFunc           = cl_int (CL_API_CALL *)(cl_mem, cl_image_info, size_t, void *, size_t *);
-    using clCreateFromGLBufferFunc     = cl_mem (CL_API_CALL *)(cl_context, cl_mem_flags, cl_GLuint, int *);
-    using clCreateFromGLTextureFunc     = cl_mem (CL_API_CALL *)(cl_context, cl_mem_flags, cl_GLenum, cl_GLint, cl_GLuint, cl_int*);
-    using clEnqueueAcquireGLObjectsFunc = cl_int (CL_API_CALL *)(cl_command_queue, cl_uint, const cl_mem *, cl_uint, const cl_event *, cl_event *);
-    using clEnqueueReleaseGLObjectsFunc = cl_int (CL_API_CALL *)(cl_command_queue, cl_uint, const cl_mem *, cl_uint, const cl_event *, cl_event *);
     using clReleaseDeviceFunc = cl_int (CL_API_CALL *)(cl_device_id);
     using clRetainDeviceFunc = cl_int (CL_API_CALL *)(cl_device_id);
 
@@ -176,6 +176,10 @@ public:
                                                      size_t, const cl_offset_qcom*, size_t, const cl_workgroup_qcom*, size_t, const cl_workgroup_qcom*,
                                                      size_t, const cl_array_kernel_exec_info_qcom*, cl_uint, const cl_event*, cl_event*);
     
+    using clGetExtensionFunctionAddressFunc = void *(CL_API_CALL *)(const char *);
+    using clGetExtensionFunctionAddressForPlatformFunc = void *(CL_API_CALL *)(cl_platform_id, const char *);
+    using clImportMemoryARMFunc = cl_mem (CL_API_CALL *)(cl_context, cl_mem_flags, const cl_import_properties_arm*, void*, size_t, cl_int*);
+    
 #define MNN_CL_DEFINE_FUNC_PTR(func) func##Func func = nullptr
 
     MNN_CL_DEFINE_FUNC_PTR(clGetPlatformIDs);
@@ -225,10 +229,6 @@ public:
     MNN_CL_DEFINE_FUNC_PTR(clGetImageInfo);
     MNN_CL_DEFINE_FUNC_PTR(clEnqueueReadImage);
     MNN_CL_DEFINE_FUNC_PTR(clEnqueueWriteImage);
-    MNN_CL_DEFINE_FUNC_PTR(clCreateFromGLBuffer);
-    MNN_CL_DEFINE_FUNC_PTR(clCreateFromGLTexture);
-    MNN_CL_DEFINE_FUNC_PTR(clEnqueueAcquireGLObjects);
-    MNN_CL_DEFINE_FUNC_PTR(clEnqueueReleaseGLObjects);
     
     MNN_CL_DEFINE_FUNC_PTR(clCreateCommandQueueWithProperties);
     MNN_CL_DEFINE_FUNC_PTR(clSVMAlloc);
@@ -243,6 +243,9 @@ public:
     MNN_CL_DEFINE_FUNC_PTR(clRetainRecordingQCOM);
     MNN_CL_DEFINE_FUNC_PTR(clEnqueueRecordingQCOM);
     MNN_CL_DEFINE_FUNC_PTR(clEnqueueRecordingSVMQCOM);
+    MNN_CL_DEFINE_FUNC_PTR(clGetExtensionFunctionAddress);
+    MNN_CL_DEFINE_FUNC_PTR(clGetExtensionFunctionAddressForPlatform);
+    MNN_CL_DEFINE_FUNC_PTR(clImportMemoryARM);
 
 #undef MNN_CL_DEFINE_FUNC_PTR
 
@@ -258,7 +261,6 @@ private:
     bool mPropError{false};
     bool mQcomError{false};
     bool mCL_12Error{false};
-    bool mGlError{false};
 };
 
 class OpenCLSymbolsOperator {

+ 5 - 6
source/backend/opencl/execution/buffer/ConvBufExecution.cpp

@@ -204,7 +204,7 @@ ConvBufExecution::ConvBufExecution(const std::vector<Tensor *> &inputs, const st
             }
             mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL);
 
-            mResource->mFilter.reset(Tensor::createDevice<float>({1, filterImageShape[1], 1, 4 * filterImageShape[0]}));
+            mResource->mFilter.reset(Tensor::createDevice<float>({filterImageShape[1] * 4 * filterImageShape[0]}));
             mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC);
             MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()};
 
@@ -458,8 +458,8 @@ ErrorCode ConvBufExecution::onResize(const std::vector<Tensor *> &inputs, const
             std::pair<int, int> min_cost(INT_MAX, 0);//(min_time, min_index)
             for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) {
                 std::set<std::string> buildOption = mResource->mBuildOptions;
-                if(outputShape.at(3) % itemC[knl_idx] != 0){
-                    buildOption.emplace("-DCHANNEL_LEAVE");
+                if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){
+                    buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
                 }
                 if((outputShape.at(2) % itemW[knl_idx]) != 0){
                     buildOption.emplace("-DBLOCK_LEAVE");
@@ -496,13 +496,12 @@ ErrorCode ConvBufExecution::onResize(const std::vector<Tensor *> &inputs, const
                 }
             }
 
-            std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
             int min_index  = min_cost.second;
             mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]};
 
             std::set<std::string> buildOption = mResource->mBuildOptions;
-            if(outputShape.at(3) % itemC[min_index] != 0){
-                buildOption.emplace("-DCHANNEL_LEAVE");
+            if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){
+                buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
             }
             if((outputShape.at(2) % itemW[min_index]) != 0){
                 buildOption.emplace("-DBLOCK_LEAVE");

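The new trigger is narrower than the old CHANNEL_LEAVE condition: protection is only needed when a kernel handling two oc4 blocks per work-item (itemC == 8) would have its second block fall past the last valid block, i.e. when the output-channel remainder modulo 8 lies in (0, 4]. A worked check under those assumptions:

    #include <cassert>

    // True when the c8 kernels must guard their second oc4 block:
    // a remainder in (0,4] means the last work-item owns exactly one valid
    // oc4 block, so touching block out_c_idx+1 would run past the end.
    bool needsBoundaryProtect(int outChannel, int itemC) {
        const int rem = outChannel % itemC;
        return itemC == 8 && rem > 0 && rem <= 4;
    }

    int main() {
        assert(needsBoundaryProtect(12, 8));   // 12 % 8 == 4: blocks 0..2, guard block 3
        assert(!needsBoundaryProtect(16, 8));  // exact multiple: both blocks valid
        assert(!needsBoundaryProtect(14, 8));  // 14 % 8 == 6: second block still exists
        return 0;
    }
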
+ 4 - 4
source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp

@@ -265,8 +265,8 @@ void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor
     // MNN_PRINT("Checking kernel %d.\n", knlCheck);
     for (int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) {
         std::set<std::string> buildOption = mResource->mBuildOptions;
-        if(outputShape.at(3) % itemC[knl_idx] != 0){
-            buildOption.emplace("-DCHANNEL_LEAVE");
+        if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){
+            buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
         }
         if((outputShape.at(2) % itemW[knl_idx]) != 0 || (outputShape.at(1) % itemH[knl_idx]) != 0){
             buildOption.emplace("-DBLOCK_LEAVE");
@@ -313,8 +313,8 @@ void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor
     mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]};
 
     std::set<std::string> buildOption = mResource->mBuildOptions;
-    if(outputShape.at(3) % itemC[min_index] != 0){
-        buildOption.emplace("-DCHANNEL_LEAVE");
+    if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){
+        buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
     }
     if((outputShape.at(2) % itemW[min_index]) != 0 || (outputShape.at(1) % itemH[min_index]) != 0){
         buildOption.emplace("-DBLOCK_LEAVE");

+ 10 - 2
source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp

@@ -160,7 +160,11 @@ ErrorCode DepthwiseConvBufExecution::onEncode(const std::vector<Tensor *> &input
         std::vector<uint32_t> localWorkSize[total_kernel];
         std::pair<int, int> min_cost(INT_MAX, 0);//(min_time, min_index)
         for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) {
-            kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("depthwise_conv2d_buf", kernelName[knl_idx], mResource->mBuildOptions);
+            std::set<std::string> buildOption = mResource->mBuildOptions;
+            if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){
+                buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
+            }
+            kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("depthwise_conv2d_buf", kernelName[knl_idx], buildOption);
             uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));
                         
             globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast<uint32_t>(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))};
@@ -196,7 +200,11 @@ ErrorCode DepthwiseConvBufExecution::onEncode(const std::vector<Tensor *> &input
         int min_index  = min_cost.second;
         mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]};
         
-        unit.kernel     = mOpenCLBackend->getOpenCLRuntime()->buildKernel("depthwise_conv2d_buf", kernelName[min_index], mResource->mBuildOptions);
+        std::set<std::string> buildOption = mResource->mBuildOptions;
+        if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){
+            buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
+        }
+        unit.kernel     = mOpenCLBackend->getOpenCLRuntime()->buildKernel("depthwise_conv2d_buf", kernelName[min_index], buildOption);
         
         uint32_t idx = 0;
         cl_int ret = CL_SUCCESS;

+ 1 - 1
source/backend/opencl/execution/cl/buffer_convert_buf.cl

@@ -74,7 +74,7 @@ __kernel void buffer_copy_to_buffer(GLOBAL_SIZE_2_DIMS
 #endif
 }
 
-// convert kernel : from buffer(oihw) to image(oc/4 h w , ic oc4)
+// convert kernel : from buffer(oihw) to buffer(ic, oc/4, h, w, oc4)
 __kernel void conv2d_filter_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS
                                             __global const FLOAT *input_ptr,
                                             __private const int output_channel,

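The corrected layout note matches the indexing the conv kernels in this commit use. A small sketch (C++, names assumed) of the element offset into the [ic, oc/4, kh, kw, oc4] weight buffer, following the weight_offset expression in conv_2d.cl:

    // Linear offset of one oc4 group in the [ic, oc/4, kh, kw, oc4] layout:
    // ((((icIdx) * ocBlocks + ocBlock) * kh + ky) * kw + kx) * 4
    int filterOffset(int icIdx, int ocBlocks, int ocBlock,
                     int kh, int ky, int kw, int kx) {
        return ((((icIdx) * ocBlocks + ocBlock) * kh + ky) * kw + kx) * 4;
    }
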
+ 28 - 0
source/backend/opencl/execution/cl/conv_2d.cl

@@ -459,6 +459,7 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
     for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block; ++in_channel_block_idx) {
 #if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
         int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8;
+        // already packed to 16, no boundary protection needed
         COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx, dequantScaleOffset + kindex));
         COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6);
         COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7);
@@ -476,7 +477,11 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
 
 #if (defined USE_LOW_BIT_WEIGHT_INT8)
         FLOAT16 weightsInt80 = CONVERT_FLOAT16(vload16(0, kernel_ptr + weight_ic_offset + in_channel_block_idx * weight_oc_offset));
+        #ifdef CHANNEL_BOUNDARY_PROTECT
+        FLOAT16 weightsInt81 = output_channel_idx + 1 >= out_channel_blocks ? (FLOAT16)0 : CONVERT_FLOAT16(vload16(0, kernel_ptr + 16 + weight_ic_offset + in_channel_block_idx * weight_oc_offset));
+        #else
         FLOAT16 weightsInt81 = CONVERT_FLOAT16(vload16(0, kernel_ptr + 16 + weight_ic_offset + in_channel_block_idx * weight_oc_offset));
+        #endif
         FLOAT4 weights0 = CONVERT_FLOAT4(weightsInt80.s0123) * scale0 + offset0;
         FLOAT4 weights1 = CONVERT_FLOAT4(weightsInt80.s4567) * scale0 + offset0;
         FLOAT4 weights2 = CONVERT_FLOAT4(weightsInt80.s89ab) * scale0 + offset0;
@@ -541,10 +546,17 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
         weights2 = vload4(weights_width_base + 2, weights + weight_offset);
         weights3 = vload4(weights_width_base + 3, weights + weight_offset);
 
+        #ifdef CHANNEL_BOUNDARY_PROTECT
+        weights4 = output_channel_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base, weights + weight_offset1);
+        weights5 = output_channel_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base + 1, weights + weight_offset1);
+        weights6 = output_channel_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base + 2, weights + weight_offset1);
+        weights7 = output_channel_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base + 3, weights + weight_offset1);
+        #else
         weights4 = vload4(weights_width_base, weights + weight_offset1);
         weights5 = vload4(weights_width_base + 1, weights + weight_offset1);
         weights6 = vload4(weights_width_base + 2, weights + weight_offset1);
         weights7 = vload4(weights_width_base + 3, weights + weight_offset1);
+        #endif
 #else
         weights0 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 0, output_channel_idx));
         weights1 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 1, output_channel_idx));
@@ -1081,10 +1093,18 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
                 weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0);
                 weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0);
                 weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0);
+                #ifdef CHANNEL_BOUNDARY_PROTECT
+                charWeight0 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset);
+                charWeight1 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);
+                charWeight2 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);
+                charWeight3 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);
+                
+                #else
                 charWeight0 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset);
                 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);
                 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);
                 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);
+                #endif
                 weights4 = mad(CONVERT_FLOAT4(charWeight0), scale1, offset1);
                 weights5 = mad(CONVERT_FLOAT4(charWeight1), scale1, offset1);
                 weights6 = mad(CONVERT_FLOAT4(charWeight2), scale1, offset1);
@@ -1153,10 +1173,18 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
                 weights1 = vload4(0, weights+weight_offset+weight_ic_offset);
                 weights2 = vload4(0, weights+weight_offset+weight_ic_offset*2);
                 weights3 = vload4(0, weights+weight_offset+weight_ic_offset*3);
+                #ifdef CHANNEL_BOUNDARY_PROTECT
+                weights4 = out_channel_block_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0, weights+weight_offset + weight_oc_offset);
+                weights5 = out_channel_block_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0, weights+weight_offset+weight_ic_offset + weight_oc_offset);
+                weights6 = out_channel_block_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0, weights+weight_offset+weight_ic_offset*2 + weight_oc_offset);
+                weights7 = out_channel_block_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0, weights+weight_offset+weight_ic_offset*3 + weight_oc_offset);
+                #else
                 weights4 = vload4(0, weights+weight_offset + weight_oc_offset);
                 weights5 = vload4(0, weights+weight_offset+weight_ic_offset + weight_oc_offset);
                 weights6 = vload4(0, weights+weight_offset+weight_ic_offset*2 + weight_oc_offset);
                 weights7 = vload4(0, weights+weight_offset+weight_ic_offset*3 + weight_oc_offset);
+                #endif
                 weight_offset += 4;
 #else
                 weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx));

+ 107 - 60
source/backend/opencl/execution/cl/conv_2d_buf.cl

@@ -200,25 +200,33 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
 
     DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);
 
-    const int out_c_idx = out_c_w_idx / out_w_blocks;
+    const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_1 = out_c_idx_0 + 1;
     const int out_w_idx = out_c_w_idx % out_w_blocks;
     const int out_b_idx = out_b_h_idx / out_h;//equal to in_b_idx
     const int out_h_idx = out_b_h_idx % out_h;//equal to in_h_idx
 
     const int out_w4_idx = mul24(out_w_idx, 4);
-    COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx<<1, bias_ptr));
+    COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias_ptr));
     COMPUTE_FLOAT4 out1 = out0;
     COMPUTE_FLOAT4 out2 = out0;
     COMPUTE_FLOAT4 out3 = out0;
     
-    COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4((out_c_idx<<1)+1, bias_ptr));
+    #ifdef CHANNEL_BOUNDARY_PROTECT
+    COMPUTE_FLOAT4 out4 = out_c_idx_1 >= out_c_block ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias_ptr));
+    COMPUTE_FLOAT4 out5 = out4;
+    COMPUTE_FLOAT4 out6 = out4;
+    COMPUTE_FLOAT4 out7 = out4;
+    #else
+    COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias_ptr));
     COMPUTE_FLOAT4 out5 = out4;
     COMPUTE_FLOAT4 out6 = out4;
     COMPUTE_FLOAT4 out7 = out4;
+    #endif
 
     const int intput_width_idx0 = out_w4_idx;
     int inp_offset = ((out_b_idx * out_h + out_h_idx)* out_w + intput_width_idx0)<<2;
-    int offset = out_c_idx*8;
+    int offset = out_c_idx_0*4;
     const int inp_add = out_b*out_h*out_w*4;
 
     for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) {
@@ -229,6 +237,7 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
         COMPUTE_FLOAT4 in2 = CONVERT_COMPUTE_FLOAT4(vload4(2, input+inp_offset));
         COMPUTE_FLOAT4 in3 = CONVERT_COMPUTE_FLOAT4(vload4(3, input+inp_offset));
         
+        // output_channel is packed to at least 8, no boundary protection needed
         COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset));
         COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(1, kernel_ptr + offset));
         COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack));
@@ -306,7 +315,7 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
     out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6);
 #endif
 
-    const int out_offset = (((out_b_idx + out_c_idx*2*out_b)*out_h + out_h_idx)* out_w + out_w4_idx)*4;
+    const int out_offset = (((out_b_idx + out_c_idx_0*out_b)*out_h + out_h_idx)* out_w + out_w4_idx)*4;
 
     __global FLOAT * _tempoutput = output + out_offset;
     __global FLOAT * _tempoutput1 = _tempoutput + 4*out_h*out_w*out_b;
@@ -323,8 +332,8 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
     } else if (remain == 1) {
         vstore4(CONVERT_FLOAT4(out0), 0, _tempoutput);
     }
-#ifdef CHANNEL_LEAVE
-    if(out_c_idx*2+1 >= out_c_block) {
+#ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_block) {
         return;
     }
 #endif
@@ -340,8 +349,8 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
     }
 #else
     vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0, out1, out2, out3)), 0, _tempoutput);
-#ifdef CHANNEL_LEAVE
-    if(out_c_idx*2+1 >= out_c_block) {
+#ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_block) {
         return;
     }
 #endif
@@ -368,21 +377,26 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
 
     DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);
 
-    const int out_c_idx = out_c_w_idx / out_w_blocks;
+    const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_1 = out_c_idx_0 + 1;
     const int out_w_idx = out_c_w_idx % out_w_blocks;
     const int out_b_idx = out_b_h_idx / out_h;//equal to in_b_idx
     const int out_h_idx = out_b_h_idx % out_h;//equal to in_h_idx
     
     const int out_w2_idx = mul24(out_w_idx, 2);
-    COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx<<1, bias_ptr));
+    COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias_ptr));
     COMPUTE_FLOAT4 out1 = out0;
     
-    COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4((out_c_idx<<1)+1, bias_ptr));
+    #ifdef CHANNEL_BOUNDARY_PROTECT
+    COMPUTE_FLOAT4 out4 = out_c_idx_1 >= out_c_block ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias_ptr));
+    #else
+    COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias_ptr));
+    #endif
     COMPUTE_FLOAT4 out5 = out4;
 
     const int intput_width_idx0 = out_w2_idx;
     int inp_offset = ((out_b_idx * out_h + out_h_idx)* out_w + intput_width_idx0)<<2;
-    int offset = out_c_idx*8;
+    int offset = out_c_idx_0*4;
     const int inp_add = out_b*out_h*out_w*4;
     for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) {
         
@@ -437,7 +451,7 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
     out5 = clamp(out5, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6);
 #endif
 
-    const int out_offset = (((out_b_idx + out_c_idx*2*out_b)*out_h + out_h_idx)* out_w + out_w2_idx)*4;
+    const int out_offset = (((out_b_idx + out_c_idx_0*out_b)*out_h + out_h_idx)* out_w + out_w2_idx)*4;
 
 
     __global FLOAT * _tempoutput = output + out_offset;
@@ -450,8 +464,8 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
     } else if (remain == 1) {
         vstore4(CONVERT_FLOAT4(out0), 0, _tempoutput);
     }
-#ifdef CHANNEL_LEAVE
-    if(out_c_idx*2+1 >= out_c_block) {
+#ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_block) {
         return;
     }
 #endif
@@ -462,8 +476,8 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
     }
 #else
     vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0, out1)), 0, _tempoutput);
-#ifdef CHANNEL_LEAVE
-    if(out_c_idx*2+1 >= out_c_block) {
+#ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_block) {
         return;
     }
 #endif
@@ -1071,16 +1085,21 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS
 
     DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);
 
-    const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_1 = out_c_idx_0 + 1;
     const int out_w_idx = out_c_w_idx % out_w_blocks;
     const int out_b_idx = out_b_h_idx / out_h_blocks;//equal to in_b_idx
     const int out_h_idx = (out_b_h_idx % out_h_blocks) << 2;
     
-    COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias));
+    COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias));
     COMPUTE_FLOAT4 out1 = out0;
     COMPUTE_FLOAT4 out2 = out0;
     COMPUTE_FLOAT4 out3 = out0;
-    COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx + 1, bias));
+    #ifdef CHANNEL_BOUNDARY_PROTECT
+    COMPUTE_FLOAT4 out4 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias));
+    #else
+    COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias));
+    #endif
     COMPUTE_FLOAT4 out5 = out4;
     COMPUTE_FLOAT4 out6 = out4;
     COMPUTE_FLOAT4 out7 = out4;
@@ -1100,12 +1119,12 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS
     const int weight_ic_offset = out_c_blocks * weight_oc_offset;
     const int in_hw_size = in_hw.x * in_hw.y;
     for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) {
-        //weights  NC4HW4  [1,  4*icC4,  ocC4*kh*kw,  1] xic4
-        //index:   [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0]
+        //weights  NC4HW4   [ic/4, ic_4, oc/4, kh*kw, oc_4]
+        //index:   [0, 4*in_c_idx, out_c_idx_0*kh*kw + kh_start*kw + kw_start, 0]
         const int inp_offset_base = (out_b_idx + in_c_idx * batch) * in_hw.x * in_hw.y * 4;
 
         for(int iy = 0; iy < filter_hw.x; iy++) {
-            int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4;
+            int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx_0) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4;
             const int in_h0_idx = (iy * dilate_hw.x + in_h0_idx_base) * in_hw.y;
             const int in_h1_idx = (iy * dilate_hw.x + in_h1_idx_base) * in_hw.y;
             const int in_h2_idx = (iy * dilate_hw.x + in_h2_idx_base) * in_hw.y;
@@ -1142,11 +1161,18 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS
                 out3 = mad(in3.z, weight2, out3);
                 out3 = mad(in3.w, weight3, out3);
 
+                // weight: [ic/4, ic_4, oc/4, kh*kw, oc_4]
+                #ifdef CHANNEL_BOUNDARY_PROTECT
+                weight0 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset));
+                weight1 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset));
+                weight2 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2));
+                weight3 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3));
+                #else
                 weight0 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset));
                 weight1 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset));
                 weight2 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2));
                 weight3 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3));
-
+                #endif
                 out4 = mad(in0.x, weight0, out4);
                 out4 = mad(in0.y, weight1, out4);
                 out4 = mad(in0.z, weight2, out4);
@@ -1193,7 +1219,7 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS
     out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6);
 #endif
 
-    int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    int out_offset = (((out_b_idx + out_c_idx_0*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
 #ifdef BLOCK_LEAVE
     const int remain = out_hw.x - out_h_idx;
     if(remain >= 4){
@@ -1211,12 +1237,12 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS
     }else if(remain == 1){
         vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset);
     }
-    #ifdef CHANNEL_LEAVE
-    if(out_c_idx + 1 >= out_c_blocks){
+    #ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_blocks){
         return;
     }
     #endif
-    out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    out_offset = (((out_b_idx + (out_c_idx_1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
     if(remain >= 4){
         vstore4(CONVERT_FLOAT4(out4), 0, output+out_offset);
         vstore4(CONVERT_FLOAT4(out5), out_hw.y, output+out_offset);
@@ -1237,12 +1263,12 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS
     vstore4(CONVERT_FLOAT4(out1), out_hw.y, output+out_offset);
     vstore4(CONVERT_FLOAT4(out2), 2 * out_hw.y, output+out_offset);
     vstore4(CONVERT_FLOAT4(out3), 3 * out_hw.y, output+out_offset);
-    #ifdef CHANNEL_LEAVE
-    if(out_c_idx + 1 >= out_c_blocks){
+    #ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_blocks){
         return;
     }
     #endif
-    out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    out_offset = (((out_b_idx + (out_c_idx_1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
     vstore4(CONVERT_FLOAT4(out4), 0, output+out_offset);
     vstore4(CONVERT_FLOAT4(out5), out_hw.y, output+out_offset);
     vstore4(CONVERT_FLOAT4(out6), 2 * out_hw.y, output+out_offset);
@@ -1273,16 +1299,21 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS
 
     DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);
 
-    const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_1 = out_c_idx_0 + 1;
     const int out_w_idx = out_c_w_idx % out_w_blocks;
     const int out_b_idx = out_b_h_idx / out_h_blocks;//equal to in_b_idx
     const int out_h_idx = (out_b_h_idx % out_h_blocks) << 1;
     
-    COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias));
+    COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias));
     COMPUTE_FLOAT4 out1 = out0;
-    COMPUTE_FLOAT4 out2 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx + 1, bias));
+    #ifdef CHANNEL_BOUNDARY_PROTECT
+    COMPUTE_FLOAT4 out2 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias));
+    #else
+    COMPUTE_FLOAT4 out2 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias));
+    #endif
     COMPUTE_FLOAT4 out3 = out2;
-
+    
     const int in_w_idx_base = mad24(out_w_idx, stride_hw.y, -pad_hw.y);
 
     const int in_h0_idx_base = mad24(out_h_idx, stride_hw.x, -pad_hw.x);
@@ -1298,11 +1329,11 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS
     // weight: [ic/4, oc, 4], loop: ic/4
     for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) {
         //weights  NC4HW4  [1,  4*icC4,  ocC4*kh*kw,  1] xic4
-        //index:   [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0]
+        //index:   [0, 4*in_c_idx, out_c_idx_0*kh*kw + kh_start*kw + kw_start, 0]
         const int inp_offset_base = (out_b_idx + in_c_idx*batch) * in_hw.x * in_hw.y * 4;
 
         for(int iy = 0; iy < filter_hw.x; iy++) {
-            int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4;
+            int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx_0) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4;
             const int in_h0_idx = (iy * dilate_hw.x + in_h0_idx_base) * in_hw.y;
             const int in_h1_idx = (iy * dilate_hw.x + in_h1_idx_base) * in_hw.y;
 
@@ -1324,11 +1355,17 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS
                 out1 = mad(in1.z, weight2, out1);
                 out1 = mad(in1.w, weight3, out1);
                 
+                #ifdef CHANNEL_BOUNDARY_PROTECT
+                weight0 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset));
+                weight1 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset));
+                weight2 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2));
+                weight3 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3));
+                #else
                 weight0 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset));
                 weight1 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset));
                 weight2 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2));
                 weight3 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3));
-                
+                #endif
                 out2 = mad(in0.x, weight0, out2);
                 out2 = mad(in0.y, weight1, out2);
                 out2 = mad(in0.z, weight2, out2);
@@ -1357,7 +1394,7 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS
     out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6);
 #endif
 
-    int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    int out_offset = (((out_b_idx + out_c_idx_0*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
 #ifdef BLOCK_LEAVE
     const int remain = out_hw.x - out_h_idx;
     if(remain >= 2){
@@ -1366,12 +1403,12 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS
     }else if(remain == 1){
         vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset);
     }
-    #ifdef CHANNEL_LEAVE
-    if(out_c_idx + 1 >= out_c_blocks){
+    #ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_blocks){
         return;
     }
     #endif
-    out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    out_offset = (((out_b_idx + (out_c_idx_1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
     if(remain >= 2){
         vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset);
         vstore4(CONVERT_FLOAT4(out3), out_hw.y, output+out_offset);
@@ -1381,12 +1418,12 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS
 #else
     vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset);
     vstore4(CONVERT_FLOAT4(out1), out_hw.y, output+out_offset);
-    #ifdef CHANNEL_LEAVE
-    if(out_c_idx + 1 >= out_c_blocks){
+    #ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_blocks){
         return;
     }
     #endif
-    out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    out_offset = (((out_b_idx + (out_c_idx_1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
     vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset);
     vstore4(CONVERT_FLOAT4(out3), out_hw.y, output+out_offset);
 #endif
@@ -1415,17 +1452,21 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS
 
     DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);
 
-    const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_1 = out_c_idx_0 + 1;
     const int out_w_idx = (out_c_w_idx % out_w_blocks) << 2;
     const int out_b_idx = out_b_h_idx / out_hw.x;//equal to in_b_idx
     const int out_h_idx = out_b_h_idx % out_hw.x;
     
-    COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias));
+    COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias));
     COMPUTE_FLOAT4 out1 = out0;
     COMPUTE_FLOAT4 out2 = out0;
     COMPUTE_FLOAT4 out3 = out0;
-    
-    COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx + 1, bias));
+    #ifdef CHANNEL_BOUNDARY_PROTECT
+    COMPUTE_FLOAT4 out4 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias));
+    #else
+    COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias));
+    #endif
     COMPUTE_FLOAT4 out5 = out4;
     COMPUTE_FLOAT4 out6 = out4;
     COMPUTE_FLOAT4 out7 = out4;
@@ -1445,8 +1486,8 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS
     const int weight_ic_offset = out_c_blocks * weight_oc_offset;
     for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) {
         //weights  NC4HW4  [1,  4*icC4,  ocC4*kh*kw,  1] xic4
-        //index:   [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0]
-        int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4;
+        //index:   [0, 4*in_c_idx, out_c_idx_0*kh*kw + kh_start*kw + kw_start, 0]
+        int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx_0) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4;
 
         for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) {
             const int inp_offset_base = (((out_b_idx + in_c_idx * batch) * in_hw.x + iy) * in_hw.y + 0) * 4;
@@ -1487,11 +1528,17 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS
                 out3 = mad(in3.z, weight2, out3);
                 out3 = mad(in3.w, weight3, out3);
                 
+                #ifdef CHANNEL_BOUNDARY_PROTECT
+                weight0 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset));
+                weight1 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset));
+                weight2 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2));
+                weight3 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3));
+                #else
                 weight0 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset));
                 weight1 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset));
                 weight2 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2));
                 weight3 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3));
-                
+                #endif
                 out4 = mad(in0.x, weight0, out4);
                 out4 = mad(in0.y, weight1, out4);
                 out4 = mad(in0.z, weight2, out4);
@@ -1538,7 +1585,7 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS
     out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6);
 #endif
 
-    int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    int out_offset = (((out_b_idx + out_c_idx_0*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
 #ifdef BLOCK_LEAVE
     const int remain = out_hw.y - out_w_idx;
     if(remain >= 4){
@@ -1551,10 +1598,10 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS
     }else if(remain == 1){
         vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset);
     }
-    #ifdef CHANNEL_LEAVE
-    if(out_c_idx + 1 >= out_c_blocks)return;
+    #ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_blocks)return;
     #endif
-    out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    out_offset = (((out_b_idx + (out_c_idx_1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
     if(remain >= 4){
         vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4, out5, out6, out7)), 0, output+out_offset);
     }else if(remain == 3){
@@ -1567,10 +1614,10 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS
     }
 #else
     vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0, out1, out2, out3)), 0, output+out_offset);
-    #ifdef CHANNEL_LEAVE
-    if(out_c_idx + 1 >= out_c_blocks)return;
+    #ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_blocks)return;
     #endif
-    out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    out_offset = (((out_b_idx + (out_c_idx_1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
     vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4, out5, out6, out7)), 0, output+out_offset);
 #endif
 }
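
Note: the c8 kernels above compute two consecutive 4-channel output blocks per work-item. When UP_DIV(oc, 4) is odd, the second block index (out_c_idx_1) lands past out_c_blocks, so CHANNEL_BOUNDARY_PROTECT zeroes the second block's weight loads and skips its stores. A minimal C++ sketch of the store-side guard, with illustrative names (not MNN API):

    #include <algorithm>

    // Sketch of the CHANNEL_BOUNDARY_PROTECT store guard.
    void storeTwoOcBlocks(float* out, const float* blk0, const float* blk1,
                          int ocIdx0, int ocBlocks, long blockStride) {
        std::copy(blk0, blk0 + 4, out + ocIdx0 * blockStride); // first block is always valid
        const int ocIdx1 = ocIdx0 + 1;
        if (ocIdx1 >= ocBlocks) {
            return; // second block lies past the last oc block: skip its store
        }
        std::copy(blk1, blk1 + 4, out + ocIdx1 * blockStride);
    }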

+ 82 - 43
source/backend/opencl/execution/cl/conv_2d_int_buf.cl

@@ -10,7 +10,7 @@
     }
 
 #define MOD_NUM 15
-#ifdef INPUT_CHANNEL_LEAVE
+#ifdef INPUT_CHANNEL_BOUNDARY_PROTECT
     #define PADZEROSVEC(k, channel, data0, data1, data2, data3) \
         data0 = (k << 2) < channel ? data0 : 0; \
         data1 = (k << 2) + 1 < channel ? data1 : 0; \
@@ -674,17 +674,19 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS
 
     DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);
 
-    const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_1 = out_c_idx_0 + 1;
     const int out_w_idx = out_c_w_idx % out_w_blocks;
     const int out_b_idx = out_b_h_idx / out_h_blocks;//equal to in_b_idx
     const int out_h_idx = (out_b_h_idx % out_h_blocks) << 2;
     
-    COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias));
+    COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias));
     COMPUTE_FLOAT4 out0 = bias0;
     COMPUTE_FLOAT4 out1 = bias0;
     COMPUTE_FLOAT4 out2 = bias0;
     COMPUTE_FLOAT4 out3 = bias0;
-    COMPUTE_FLOAT4 bias1 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx + 1, bias));
+        // bias is aligned to 8, no boundary protection needed
+    COMPUTE_FLOAT4 bias1 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias));
     COMPUTE_FLOAT4 out4 = bias1;
     COMPUTE_FLOAT4 out5 = bias1;
     COMPUTE_FLOAT4 out6 = bias1;
@@ -706,18 +708,22 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS
     const int in_hw_size = in_hw.x * in_hw.y;
     for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) {
         int kindex = (in_c_idx * 4) / blockDim * out_c_blocks * 8;
-        COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx, dequantScaleOffset + kindex));
-        COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx + 1, dequantScaleOffset + kindex));
+        COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_0, dequantScaleOffset + kindex));
+        #ifdef CHANNEL_BOUNDARY_PROTECT
+        COMPUTE_FLOAT8 ScaleOffset1 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT8)0 : CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1, dequantScaleOffset + kindex));
+        #else
+        COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1, dequantScaleOffset + kindex));
+        #endif
         COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6);
         COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7);
         COMPUTE_FLOAT4 scale1 = (COMPUTE_FLOAT4)(ScaleOffset1.s0, ScaleOffset1.s2, ScaleOffset1.s4, ScaleOffset1.s6);
         COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7);
         //weights  NC4HW4  [1,  4*icC4,  ocC4*kh*kw,  1] xic4
-        //index:   [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0]
+        //index:   [0, 4*in_c_idx, out_c_idx_0*kh*kw + kh_start*kw + kw_start, 0]
         const int inp_offset_base = (out_b_idx + in_c_idx*batch) * in_hw.x * in_hw.y * 4;
 
         for(int iy = 0; iy < filter_hw.x; iy++) {
-            int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4;
+            int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx_0) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4;
             const int in_h0_idx = (iy * dilate_hw.x + in_h0_idx_base) * in_hw.y;
             const int in_h1_idx = (iy * dilate_hw.x + in_h1_idx_base) * in_hw.y;
             const int in_h2_idx = (iy * dilate_hw.x + in_h2_idx_base) * in_hw.y;
@@ -791,10 +797,17 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS
                 out3 = mad(in3.w, weight3, out3);
 
 #if (defined USE_LOW_BIT_WEIGHT_INT8)
+                #ifdef CHANNEL_BOUNDARY_PROTECT
+                charWeight0 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset);
+                charWeight1 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset);
+                charWeight2 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2);
+                charWeight3 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3);
+                #else
                 charWeight0 = vload4(0, weight+weight_offset+weight_oc_offset);
                 charWeight1 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset);
                 charWeight2 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2);
                 charWeight3 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3);
+                #endif
                 weight0 = CONVERT_COMPUTE_FLOAT4(charWeight0) * scale1 + offset1;
                 weight1 = CONVERT_COMPUTE_FLOAT4(charWeight1) * scale1 + offset1;
                 weight2 = CONVERT_COMPUTE_FLOAT4(charWeight2) * scale1 + offset1;
@@ -878,7 +891,7 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS
     out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6);
 #endif
 
-    int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    int out_offset = (((out_b_idx + out_c_idx_0*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
 #ifdef BLOCK_LEAVE
     const int remain = out_hw.x - out_h_idx;
     if(remain >= 4){
@@ -896,12 +909,12 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS
     }else if(remain == 1){
         vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset);
     }
-#ifdef CHANNEL_LEAVE
-    if(out_c_idx + 1 >= out_c_blocks){
+#ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_blocks){
         return;
     }
 #endif
-    out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    out_offset = (((out_b_idx + out_c_idx_1*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
     if(remain >= 4){
         vstore4(CONVERT_FLOAT4(out4), 0, output+out_offset);
         vstore4(CONVERT_FLOAT4(out5), out_hw.y, output+out_offset);
@@ -922,12 +935,12 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS
     vstore4(CONVERT_FLOAT4(out1), out_hw.y, output+out_offset);
     vstore4(CONVERT_FLOAT4(out2), 2 * out_hw.y, output+out_offset);
     vstore4(CONVERT_FLOAT4(out3), 3 * out_hw.y, output+out_offset);
-#ifdef CHANNEL_LEAVE
-    if(out_c_idx + 1 >= out_c_blocks){
+#ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_blocks){
         return;
     }
 #endif
-    out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    out_offset = (((out_b_idx + out_c_idx_1*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
     vstore4(CONVERT_FLOAT4(out4), 0, output+out_offset);
     vstore4(CONVERT_FLOAT4(out5), out_hw.y, output+out_offset);
     vstore4(CONVERT_FLOAT4(out6), 2 * out_hw.y, output+out_offset);
@@ -964,15 +977,17 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS
 
     DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);
 
-    const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_1 = out_c_idx_0 + 1;
     const int out_w_idx = out_c_w_idx % out_w_blocks;
     const int out_b_idx = out_b_h_idx / out_h_blocks;//equal to in_b_idx
     const int out_h_idx = (out_b_h_idx % out_h_blocks) << 1;
 
-    COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias));
+    COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias));
     COMPUTE_FLOAT4 out0 = bias0;
     COMPUTE_FLOAT4 out1 = bias0;
-    COMPUTE_FLOAT4 bias1 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx + 1, bias));
+        // bias is aligned to 8, no boundary protection needed
+    COMPUTE_FLOAT4 bias1 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias));
     COMPUTE_FLOAT4 out2 = bias1;
     COMPUTE_FLOAT4 out3 = bias1;
 
@@ -991,18 +1006,22 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS
     // weight: [ic/4, oc, 4], loop: ic/4
     for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) {
         int kindex = (in_c_idx * 4) / blockDim * out_c_blocks * 8;
-        COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx, dequantScaleOffset + kindex));
-        COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx + 1, dequantScaleOffset + kindex));
+        COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_0, dequantScaleOffset + kindex));
+        #ifdef CHANNEL_BOUNDARY_PROTECT
+        COMPUTE_FLOAT8 ScaleOffset1 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT8)0 : CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1, dequantScaleOffset + kindex));
+        #else
+        COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1, dequantScaleOffset + kindex));
+        #endif
         COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6);
         COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7);
         COMPUTE_FLOAT4 scale1 = (COMPUTE_FLOAT4)(ScaleOffset1.s0, ScaleOffset1.s2, ScaleOffset1.s4, ScaleOffset1.s6);
         COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7);
         //weights  NC4HW4  [1,  4*icC4,  ocC4*kh*kw,  1] xic4
-        //index:   [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0]
+        //index:   [0, 4*in_c_idx, out_c_idx_0*kh*kw + kh_start*kw + kw_start, 0]
         const int inp_offset_base = (out_b_idx + in_c_idx*batch) * in_hw.x * in_hw.y * 4;
 
         for(int iy = 0; iy < filter_hw.x; iy++) {
-            int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4;
+            int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx_0) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4;
             const int in_h0_idx = (iy * dilate_hw.x + in_h0_idx_base) * in_hw.y;
             const int in_h1_idx = (iy * dilate_hw.x + in_h1_idx_base) * in_hw.y;
 
@@ -1060,10 +1079,17 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS
                 out1 = mad(in1.w, weight3, out1);
                 
 #if (defined USE_LOW_BIT_WEIGHT_INT8)
+                #ifdef CHANNEL_BOUNDARY_PROTECT
+                charWeight0 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset);
+                charWeight1 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset);
+                charWeight2 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2);
+                charWeight3 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3);
+                #else
                 charWeight0 = vload4(0, weight+weight_offset+weight_oc_offset);
                 charWeight1 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset);
                 charWeight2 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2);
                 charWeight3 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3);
+                #endif
                 weight0 = CONVERT_COMPUTE_FLOAT4(charWeight0) * scale1 + offset1;
                 weight1 = CONVERT_COMPUTE_FLOAT4(charWeight1) * scale1 + offset1;
                 weight2 = CONVERT_COMPUTE_FLOAT4(charWeight2) * scale1 + offset1;
@@ -1128,7 +1154,7 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS
     out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6);
 #endif
 
-    int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    int out_offset = (((out_b_idx + out_c_idx_0*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
 #ifdef BLOCK_LEAVE
     const int remain = out_hw.x - out_h_idx;
     if(remain >= 2){
@@ -1137,12 +1163,12 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS
     }else if(remain == 1){
         vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset);
     }
-#ifdef CHANNEL_LEAVE
-    if(out_c_idx + 1 >= out_c_blocks){
+#ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_blocks){
         return;
     }
 #endif
-    out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    out_offset = (((out_b_idx + out_c_idx_1*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
     if(remain >= 2){
         vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset);
         vstore4(CONVERT_FLOAT4(out3), out_hw.y, output+out_offset);
@@ -1152,12 +1178,12 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS
 #else
     vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset);
     vstore4(CONVERT_FLOAT4(out1), out_hw.y, output+out_offset);
-#ifdef CHANNEL_LEAVE
-    if(out_c_idx + 1 >= out_c_blocks){
+#ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_blocks){
         return;
     }
 #endif
-    out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    out_offset = (((out_b_idx + out_c_idx_1*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
     vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset);
     vstore4(CONVERT_FLOAT4(out3), out_hw.y, output+out_offset);
 #endif
@@ -1192,17 +1218,19 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS
 
     DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);
 
-    const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1;
+    const int out_c_idx_1 = out_c_idx_0 + 1;
     const int out_w_idx = (out_c_w_idx % out_w_blocks) << 2;
     const int out_b_idx = out_b_h_idx / out_hw.x;//equal to in_b_idx
     const int out_h_idx = out_b_h_idx % out_hw.x;
     
-    COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias));
+    COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias));
     COMPUTE_FLOAT4 out0 = bias0;
     COMPUTE_FLOAT4 out1 = bias0;
     COMPUTE_FLOAT4 out2 = bias0;
     COMPUTE_FLOAT4 out3 = bias0;
-    COMPUTE_FLOAT4 bias1 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx + 1, bias));
+        // bias is aligned to 8, no boundary protection needed
+    COMPUTE_FLOAT4 bias1 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias));
     COMPUTE_FLOAT4 out4 = bias1;
     COMPUTE_FLOAT4 out5 = bias1;
     COMPUTE_FLOAT4 out6 = bias1;
@@ -1223,15 +1251,19 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS
     const int weight_ic_offset = out_c_blocks * weight_oc_offset;
     for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) {
         int kindex = (in_c_idx * 4) / blockDim * out_c_blocks * 8;
-        COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx, dequantScaleOffset + kindex));
-        COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx + 1, dequantScaleOffset + kindex));
+        COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_0, dequantScaleOffset + kindex));
+        #ifdef CHANNEL_BOUNDARY_PROTECT
+        COMPUTE_FLOAT8 ScaleOffset1 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT8)0 : CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1, dequantScaleOffset + kindex));
+        #else
+        COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1, dequantScaleOffset + kindex));
+        #endif
         COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6);
         COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7);
         COMPUTE_FLOAT4 scale1 = (COMPUTE_FLOAT4)(ScaleOffset1.s0, ScaleOffset1.s2, ScaleOffset1.s4, ScaleOffset1.s6);
         COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7);
         //weights  NC4HW4  [1,  4*icC4,  ocC4*kh*kw,  1] xic4
-        //index:   [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0]
-        int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4;
+        //index:   [0, 4*in_c_idx, out_c_idx_0*kh*kw + kh_start*kw + kw_start, 0]
+        int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx_0) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4;
 
         for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) {
             const int inp_offset_base = (((out_b_idx + in_c_idx*batch) * in_hw.x + iy) * in_hw.y + 0) * 4;
@@ -1309,10 +1341,17 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS
                 out3 = mad(in3.w, weight3, out3);
                 
 #if (defined USE_LOW_BIT_WEIGHT_INT8)
+                #ifdef CHANNEL_BOUNDARY_PROTECT
+                charWeight0 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset);
+                charWeight1 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset);
+                charWeight2 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2);
+                charWeight3 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3);
+                #else
                 charWeight0 = vload4(0, weight+weight_offset+weight_oc_offset);
                 charWeight1 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset);
                 charWeight2 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2);
                 charWeight3 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3);
+                #endif
                 weight0 = CONVERT_COMPUTE_FLOAT4(charWeight0) * scale1 + offset1;
                 weight1 = CONVERT_COMPUTE_FLOAT4(charWeight1) * scale1 + offset1;
                 weight2 = CONVERT_COMPUTE_FLOAT4(charWeight2) * scale1 + offset1;
@@ -1396,7 +1435,7 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS
     out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6);
 #endif
 
-    int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    int out_offset = (((out_b_idx + out_c_idx_0*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
 #ifdef BLOCK_LEAVE
     const int remain = out_hw.y - out_w_idx;
     if(remain >= 4){
@@ -1409,10 +1448,10 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS
     }else if(remain == 1){
         vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset);
     }
-#ifdef CHANNEL_LEAVE
-    if(out_c_idx + 1 >= out_c_blocks)return;
+#ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_blocks)return;
 #endif
-    out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    out_offset = (((out_b_idx + out_c_idx_1*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
     if(remain >= 4){
         vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4, out5, out6, out7)), 0, output+out_offset);
     }else if(remain == 3){
@@ -1425,10 +1464,10 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS
     }
 #else
     vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0, out1, out2, out3)), 0, output+out_offset);
-#ifdef CHANNEL_LEAVE
-    if(out_c_idx + 1 >= out_c_blocks)return;
+#ifdef CHANNEL_BOUNDARY_PROTECT
+    if(out_c_idx_1 >= out_c_blocks)return;
 #endif
-    out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
+    out_offset = (((out_b_idx + out_c_idx_1*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
     vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4, out5, out6, out7)), 0, output+out_offset);
 #endif
 }
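
In the int8 buffer kernels above, each 4-lane weight vector is dequantized as w * scale + offset, with four (scale, offset) pairs interleaved in the dequantScaleOffset block loaded per output-channel block (and ScaleOffset1 zeroed under CHANNEL_BOUNDARY_PROTECT). A small C++ sketch of that per-lane dequantization, under assumed names:

    // Dequantize one 4-lane int8 weight vector; scaleOffset interleaves
    // (scale, offset) pairs as s0..s7, matching the kernel's ScaleOffset use.
    void dequantWeight4(const signed char w[4], const float scaleOffset[8],
                        float out[4]) {
        for (int i = 0; i < 4; ++i) {
            const float scale  = scaleOffset[2 * i];     // s0, s2, s4, s6
            const float offset = scaleOffset[2 * i + 1]; // s1, s3, s5, s7
            out[i] = static_cast<float>(w[i]) * scale + offset;
        }
    }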

+ 14 - 6
source/backend/opencl/execution/cl/depthwise_conv2d_buf.cl

@@ -303,14 +303,18 @@ void depthwise_conv2d_s1_c8h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,
             COMPUTE_FLOAT4 inValue2 = (in_w_start_2+kw < 0 || in_w_start_2+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+2, input+inp_offset_c0));
             COMPUTE_FLOAT4 inValue3 = (in_w_start_3+kw < 0 || in_w_start_3+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+3, input+inp_offset_c0));
 
-            COMPUTE_FLOAT4 inValue4 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset_c1));
-            COMPUTE_FLOAT4 inValue5 = (in_w_start_1+kw < 0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1, input+inp_offset_c1));
-            COMPUTE_FLOAT4 inValue6 = (in_w_start_2+kw < 0 || in_w_start_2+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+2, input+inp_offset_c1));
-            COMPUTE_FLOAT4 inValue7 = (in_w_start_3+kw < 0 || in_w_start_3+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+3, input+inp_offset_c1));
+            COMPUTE_FLOAT4 inValue4 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset_c1));
+            COMPUTE_FLOAT4 inValue5 = (in_w_start_1+kw < 0 || in_w_start_1+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1, input+inp_offset_c1));
+            COMPUTE_FLOAT4 inValue6 = (in_w_start_2+kw < 0 || in_w_start_2+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+2, input+inp_offset_c1));
+            COMPUTE_FLOAT4 inValue7 = (in_w_start_3+kw < 0 || in_w_start_3+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+3, input+inp_offset_c1));
             
             //NC4HW4 [1, filterShape.x*filterShape.y, 1, channelBlocks] x oc4
             //index: [0, filterIdx,                   0, inChannelBlockIdx]
             COMPUTE_FLOAT4 weights_0 = CONVERT_COMPUTE_FLOAT4(vload4(0, filter+(filter_idx*c_blocks+c_idx+0)*4));
+            /*
+              weight layout: [kh*kw, oc/4, oc_4], memory aligned to 8,
+              so no boundary protection is needed
+              */
             COMPUTE_FLOAT4 weights_1 = CONVERT_COMPUTE_FLOAT4(vload4(0, filter+(filter_idx*c_blocks+c_idx+1)*4));
 
             outValue0 = mad(inValue0, weights_0, outValue0);
@@ -435,12 +439,16 @@ void depthwise_conv2d_s1_c8h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,
             COMPUTE_FLOAT4 inValue0 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset_c0));
             COMPUTE_FLOAT4 inValue1 = (in_w_start_1+kw < 0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1, input+inp_offset_c0));
 
-            COMPUTE_FLOAT4 inValue4 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset_c1));
-            COMPUTE_FLOAT4 inValue5 = (in_w_start_1+kw < 0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1, input+inp_offset_c1));
+            COMPUTE_FLOAT4 inValue4 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset_c1));
+            COMPUTE_FLOAT4 inValue5 = (in_w_start_1+kw < 0 || in_w_start_1+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1, input+inp_offset_c1));
 
             //NC4HW4 [1, filterShape.x*filterShape.y, 1, channelBlocks] x oc4
             //index: [0, filterIdx,                   0, inChannelBlockIdx]
             COMPUTE_FLOAT4 weights_0 = CONVERT_COMPUTE_FLOAT4(vload4(0, filter+(filter_idx*c_blocks+c_idx+0)*4));
+            /*
+              weight layout: [kh*kw, oc/4, oc_4], memory aligned to 8,
+              so no boundary protection is needed
+              */
             COMPUTE_FLOAT4 weights_1 = CONVERT_COMPUTE_FLOAT4(vload4(0, filter+(filter_idx*c_blocks+c_idx+1)*4));
 
             outValue0 = mad(inValue0, weights_0, outValue0);
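
Unlike the conv_2d kernels, the depthwise kernels guard the second channel block with an inline ternary on c_idx+1 >= c_blocks rather than a compile-time macro: only the input can run out of range here, since (per the comment above) weight and bias storage is padded to a multiple of 8 channels. A sketch of the guarded-load pattern (illustrative helper, not MNN API):

    #include <array>
    using Float4 = std::array<float, 4>;

    // Out-of-range channel-block loads yield zero, mirroring the kernel's ternary.
    Float4 guardedLoad4(const float* base, long idx, bool inRange) {
        if (!inRange) return Float4{0.f, 0.f, 0.f, 0.f};
        return Float4{base[4 * idx], base[4 * idx + 1],
                      base[4 * idx + 2], base[4 * idx + 3]};
    }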

+ 211 - 0
source/backend/opencl/execution/cl/glmem_convert.cl

@@ -0,0 +1,211 @@
+#ifdef MNN_SUPPORT_FP16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
+
+#define GLOBAL_SIZE_3_DIMS __private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2,
+#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3)                       \
+    if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \
+        return;                                                     \
+    }
+
+#define MNN_DATA_FORMAT_NCHW 0
+#define MNN_DATA_FORMAT_NHWC 1
+#define MNN_DATA_FORMAT_NC4HW4 2
+#define MNN_DATA_FORMAT_C4NHW4 3
+
+#define __CAT(x, y) x##y
+#define CAT(x, y) __CAT(x, y)
+#define OUTPUT_TYPE2 CAT(OUTPUT_TYPE, 2)
+#define OUTPUT_TYPE3 CAT(OUTPUT_TYPE, 3)
+__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+
+#ifdef SHARED_TO_CL
+__kernel void gl_to_cl(GLOBAL_SIZE_3_DIMS
+                                    __global uchar *input_ptr,
+                                    #ifdef USE_IMAGE
+                                    __write_only image2d_t output_ptr,
+                                    #else
+                                    __global OUTPUT_TYPE *output_ptr,
+                                    #endif
+                                    __private const int4 shape // N C H W
+) {
+
+    int wblock  = get_global_id(0);
+    int cblock = get_global_id(1);
+    int nh = get_global_id(2);
+
+    DEAL_NON_UNIFORM_DIM3(wblock, cblock, nh);
+    const int w = wblock << 2;
+    const int h = nh % shape.z;
+    const int c = cblock << 2;
+    const int n = nh / shape.z;
+    
+    int idx = c * shape.w + w;    // c/4*w
+    int idy = nh;    // n*h
+    const int offset = idy * shape.w * 4;
+    OUTPUT_TYPE4 in0 = CONVERT_OUTPUT4(vload4(idx, input_ptr + offset));
+    OUTPUT_TYPE4 in1 = CONVERT_OUTPUT4(vload4(idx + 1, input_ptr + offset));
+    OUTPUT_TYPE4 in2 = CONVERT_OUTPUT4(vload4(idx + 2, input_ptr + offset));
+    OUTPUT_TYPE4 in3 = CONVERT_OUTPUT4(vload4(idx + 3, input_ptr + offset));
+
+#ifdef USE_IMAGE
+    WI_DATA(output_ptr, (int2)(idx, idy), in0);
+    if(w + 1 >= shape.w) return;
+    WI_DATA(output_ptr, (int2)(idx+1, idy), in1);
+    if(w + 2 >= shape.w) return;
+    WI_DATA(output_ptr, (int2)(idx+2, idy), in2);
+    if(w + 3 >= shape.w) return;
+    WI_DATA(output_ptr, (int2)(idx+3, idy), in3);
+#else
+    #if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW
+    int output_offset = ((n * shape.y + c) * shape.z + h) * shape.w + w;
+    int stride = shape.z * shape.w;
+    int remain = shape.w - w;
+    if(remain >= 4){
+        vstore4((OUTPUT_TYPE4)(in0.x, in1.x, in2.x, in3.x), 0, output_ptr + output_offset);
+        if(c + 1 >= shape.y) return;
+        vstore4((OUTPUT_TYPE4)(in0.y, in1.y, in2.y, in3.y), 0, output_ptr + output_offset + stride);
+        if(c + 2 >= shape.y) return;
+        vstore4((OUTPUT_TYPE4)(in0.z, in1.z, in2.z, in3.z), 0, output_ptr + output_offset + stride + stride);
+        if(c + 3 >= shape.y) return;
+        vstore4((OUTPUT_TYPE4)(in0.w, in1.w, in2.w, in3.w), 0, output_ptr + output_offset + stride + stride + stride);
+    } else if(remain == 3){
+        vstore3((OUTPUT_TYPE3)(in0.x, in1.x, in2.x), 0, output_ptr + output_offset);
+        if(c + 1 >= shape.y) return;
+        vstore3((OUTPUT_TYPE3)(in0.y, in1.y, in2.y), 0, output_ptr + output_offset + stride);
+        if(c + 2 >= shape.y) return;
+        vstore3((OUTPUT_TYPE3)(in0.z, in1.z, in2.z), 0, output_ptr + output_offset + stride + stride);
+        if(c + 3 >= shape.y) return;
+        vstore3((OUTPUT_TYPE3)(in0.w, in1.w, in2.w), 0, output_ptr + output_offset + stride + stride + stride);
+    } else if(remain == 2){
+        vstore2((OUTPUT_TYPE2)(in0.x, in1.x), 0, output_ptr + output_offset);
+        if(c + 1 >= shape.y) return;
+        vstore2((OUTPUT_TYPE2)(in0.y, in1.y), 0, output_ptr + output_offset + stride);
+        if(c + 2 >= shape.y) return;
+        vstore2((OUTPUT_TYPE2)(in0.z, in1.z), 0, output_ptr + output_offset + stride + stride);
+        if(c + 3 >= shape.y) return;
+        vstore2((OUTPUT_TYPE2)(in0.w, in1.w), 0, output_ptr + output_offset + stride + stride + stride);
+    }else if(remain == 1){
+        output_ptr[output_offset] = in0.x;
+        if(c + 1 >= shape.y) return;
+        output_ptr[output_offset + stride] = in0.y;
+        if(c + 2 >= shape.y) return;
+        output_ptr[output_offset + stride + stride] = in0.z;
+        if(c + 3 >= shape.y) return;
+        output_ptr[output_offset + stride + stride + stride] = in0.w;
+    }
+    #elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC
+    int output_offset = ((n * shape.z + h) * shape.w + w) * shape.y + c;
+    int remain = shape.y - c;
+    if(remain >= 4){
+        vstore4(CONVERT_OUTPUT4(in0), 0, output_ptr + output_offset);
+        if(w + 1 >= shape.w) return;
+        vstore4(CONVERT_OUTPUT4(in1), 0, output_ptr + output_offset + shape.y);
+        if(w + 2 >= shape.w) return;
+        vstore4(CONVERT_OUTPUT4(in2), 0, output_ptr + output_offset + shape.y + shape.y);
+        if(w + 3 >= shape.w) return;
+        vstore4(CONVERT_OUTPUT4(in3), 0, output_ptr + output_offset + shape.y + shape.y + shape.y);
+    } else if(remain == 3){
+        vstore3((OUTPUT_TYPE3)(in0.x, in0.y, in0.z), 0, output_ptr + output_offset);
+        if(w + 1 >= shape.w) return;
+        vstore3((OUTPUT_TYPE3)(in1.x, in1.y, in1.z), 0, output_ptr + output_offset + shape.y);
+        if(w + 2 >= shape.w) return;
+        vstore3((OUTPUT_TYPE3)(in2.x, in2.y, in2.z), 0, output_ptr + output_offset + shape.y + shape.y);
+        if(w + 3 >= shape.w) return;
+        vstore3((OUTPUT_TYPE3)(in3.x, in3.y, in3.z), 0, output_ptr + output_offset + shape.y + shape.y + shape.y);
+    } else if(remain == 2){
+        vstore2((OUTPUT_TYPE2)(in0.x, in0.y), 0, output_ptr + output_offset);
+        if(w + 1 >= shape.w) return;
+        vstore2((OUTPUT_TYPE2)(in1.x, in1.y), 0, output_ptr + output_offset + shape.y);
+        if(w + 2 >= shape.w) return;
+        vstore2((OUTPUT_TYPE2)(in2.x, in2.y), 0, output_ptr + output_offset + shape.y + shape.y);
+        if(w + 3 >= shape.w) return;
+        vstore2((OUTPUT_TYPE2)(in3.x, in3.y), 0, output_ptr + output_offset + shape.y + shape.y + shape.y);
+    }else if(remain == 1){
+        output_ptr[output_offset] = in0.x;
+        if(w + 1 >= shape.w) return;
+        output_ptr[output_offset + shape.y] = in1.x;
+        if(w + 2 >= shape.w) return;
+        output_ptr[output_offset + shape.y + shape.y] = in2.x;
+        if(w + 3 >= shape.w) return;
+        output_ptr[output_offset + shape.y + shape.y + shape.y] = in3.x;
+    }
+    #elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4
+    int output_offset = (((cblock * shape.x + n) * shape.z + h) * shape.w + w) * 4;
+    vstore4(in0, 0, output_ptr + output_offset);
+    if(w + 1 >= shape.w) return;
+    vstore4(in1, 0, output_ptr + output_offset + 4);
+    if(w + 2 >= shape.w) return;
+    vstore4(in2, 0, output_ptr + output_offset + 8);
+    if(w + 3 >= shape.w) return;
+    vstore4(in3, 0, output_ptr + output_offset + 12);
+    #endif
+#endif
+}
+#endif
+
+#ifdef CL_TO_SHARED
+__kernel void cl_to_gl(GLOBAL_SIZE_3_DIMS
+                                    #ifdef USE_IMAGE
+                                    __read_only image2d_t input_ptr,
+                                    #else
+                                    __global INPUT_TYPE *input_ptr,
+                                    #endif
+                                    __global uchar *output_ptr,
+                                    __private const int4 shape // N C H W
+) {
+
+    int wblock  = get_global_id(0);
+    int cblock = get_global_id(1);
+    int nh = get_global_id(2);
+
+    DEAL_NON_UNIFORM_DIM3(wblock, cblock, nh);
+    const int w = wblock << 2;
+    const int h = nh % shape.z;
+    const int c = cblock << 2;
+    const int n = nh / shape.z;
+    
+    int idx = c * shape.w + w;    // c/4*w
+    int idy = nh;    // n*h
+#ifdef USE_IMAGE
+    INPUT_TYPE4 in0 = RI_DATA(input_ptr, SAMPLER, (int2)(idx, idy));
+    INPUT_TYPE4 in1 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+1, idy));
+    INPUT_TYPE4 in2 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+2, idy));
+    INPUT_TYPE4 in3 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+3, idy));
+#else
+    #if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW
+    int input_offset = ((n * shape.y + c) * shape.z + h) * shape.w + w;
+    int stride = shape.z * shape.w;
+    INPUT_TYPE4 tmp0, tmp1, tmp2, tmp3;
+    tmp0 = vload4(0, input_ptr + input_offset);
+    tmp1 = vload4(0, input_ptr + input_offset + stride);
+    tmp2 = vload4(0, input_ptr + input_offset + stride + stride);
+    tmp3 = vload4(0, input_ptr + input_offset + stride + stride + stride);
+    INPUT_TYPE4 in0 = (INPUT_TYPE4)(tmp0.x, tmp1.x, tmp2.x, tmp3.x);
+    INPUT_TYPE4 in1 = (INPUT_TYPE4)(tmp0.y, tmp1.y, tmp2.y, tmp3.y);
+    INPUT_TYPE4 in2 = (INPUT_TYPE4)(tmp0.z, tmp1.z, tmp2.z, tmp3.z);
+    INPUT_TYPE4 in3 = (INPUT_TYPE4)(tmp0.w, tmp1.w, tmp2.w, tmp3.w);
+    #elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC
+    int input_offset = ((n * shape.z + h) * shape.w + w) * shape.y + c;
+    INPUT_TYPE4 in0 = vload4(0, input_ptr + input_offset);
+    INPUT_TYPE4 in1 = vload4(0, input_ptr + input_offset + shape.y);
+    INPUT_TYPE4 in2 = vload4(0, input_ptr + input_offset + shape.y + shape.y);
+    INPUT_TYPE4 in3 = vload4(0, input_ptr + input_offset + shape.y + shape.y + shape.y);
+    #elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4
+    int input_offset = (((cblock * shape.x + n) * shape.z + h) * shape.w + w) * 4;
+    INPUT_TYPE4 in0 = vload4(0, input_ptr + input_offset);
+    INPUT_TYPE4 in1 = vload4(0, input_ptr + input_offset + 4);
+    INPUT_TYPE4 in2 = vload4(0, input_ptr + input_offset + 8);
+    INPUT_TYPE4 in3 = vload4(0, input_ptr + input_offset + 12);
+    #endif
+#endif
+    const int offset = idy * shape.w * 4;
+    vstore4(convert_uchar4(in0), idx, output_ptr + offset);
+    if(w + 1 >= shape.w) return;
+    vstore4(convert_uchar4(in1), idx+1, output_ptr + offset);
+    if(w + 2 >= shape.w) return;
+    vstore4(convert_uchar4(in2), idx+2, output_ptr + offset);
+    if(w + 3 >= shape.w) return;
+    vstore4(convert_uchar4(in3), idx+3, output_ptr + offset);
+}
+#endif
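
glmem_convert.cl moves data between a GL-shared RGBA byte buffer and CL tensors in NCHW, NHWC, or NC4HW4 layout. The element offsets it computes for a logical (n, c, h, w) coordinate can be summarized as follows (a C++ sketch derived from the kernel; names are illustrative):

    struct Shape4 { int n, c, h, w; }; // packed N C H W, like the kernel's int4 shape

    long offsetNCHW(Shape4 s, int n, int c, int h, int w) {
        return (((long)n * s.c + c) * s.h + h) * s.w + w;
    }
    long offsetNHWC(Shape4 s, int n, int c, int h, int w) {
        return (((long)n * s.h + h) * s.w + w) * s.c + c;
    }
    // NC4HW4 packs channels in groups of 4: cBlock = c / 4, lane = c % 4.
    long offsetNC4HW4(Shape4 s, int n, int c, int h, int w) {
        return ((((long)(c / 4) * s.n + n) * s.h + h) * s.w + w) * 4 + (c % 4);
    }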

File diff suppressed because it is too large
+ 432 - 108
source/backend/opencl/execution/cl/opencl_program.cc


+ 2 - 0
source/backend/opencl/execution/cl/opencl_source_map.hpp

@@ -71,6 +71,7 @@ extern const char* unary_buf;
 #ifndef MNN_OPENCL_BUFFER_CLOSED
 extern const char* depthwise_conv2d_buf;
 #endif
+extern const char* glmem_convert;
 #ifndef MNN_OPENCL_BUFFER_CLOSED
 extern const char* winogradTransform_buf;
 #endif
@@ -242,6 +243,7 @@ const std::map<std::string, const char*> OpenCLProgramMap =
 #ifndef MNN_OPENCL_BUFFER_CLOSED
   { "depthwise_conv2d_buf", depthwise_conv2d_buf },
 #endif
+  { "glmem_convert", glmem_convert },
 #ifndef MNN_OPENCL_BUFFER_CLOSED
   { "winogradTransform_buf", winogradTransform_buf },
 #endif

+ 22 - 6
source/backend/opencl/execution/image/ConvExecution.cpp

@@ -25,12 +25,12 @@ ConvCommonExecution::ConvCommonExecution(const Convolution2D *conv2dParams, Back
     int biasSize             = conv2dParams->bias()->size();
     const float *biasDataPtr = conv2dParams->bias()->data();
     
-    int buffer_size = ALIGN_UP4(biasSize) * sizeof(float);
+    int buffer_size = ALIGN_UP8(biasSize) * sizeof(float);
     cl::Buffer biasBuffer(runtime->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
     cl_int error;
     auto biasPtrCL = runtime->commandQueue().enqueueMapBuffer(biasBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error);
     if(biasPtrCL != nullptr && error == CL_SUCCESS){
-        ::memset(biasPtrCL, 0, ALIGN_UP4(biasSize) * sizeof(float));
+        ::memset(biasPtrCL, 0, ALIGN_UP8(biasSize) * sizeof(float));
         ::memcpy(biasPtrCL, biasDataPtr, biasSize * sizeof(float));
     }else{
         MNN_ERROR("Map error biasPtrCL == nullptr \n");
@@ -328,7 +328,11 @@ ErrorCode ConvExecution::onEncode(const std::vector<Tensor *> &inputs, const std
             std::pair<int, int> min_cost(INT_MAX, 0);//(min_time, min_index)
             
             for(int knl_idx = 0; knl_idx < 1; knl_idx++) {
-                kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], mResource->mBuildOptions);
+                std::set<std::string> buildOption = mResource->mBuildOptions;
+                if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){
+                    buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
+                }
+                kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], buildOption);
                 uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));
                 
                 globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast<uint32_t>(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))};
@@ -363,7 +367,11 @@ ErrorCode ConvExecution::onEncode(const std::vector<Tensor *> &inputs, const std
             int min_index  = min_cost.second;
             //printf("min_index = %d  %d\n", min_index, min_cost.first);
             mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]};
-            unit.kernel        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], mResource->mBuildOptions);
+            std::set<std::string> buildOption = mResource->mBuildOptions;
+            if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){
+                buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
+            }
+            unit.kernel        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], buildOption);
             
             uint32_t idx = 0;
             unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]);
@@ -407,7 +415,11 @@ ErrorCode ConvExecution::onEncode(const std::vector<Tensor *> &inputs, const std
         std::pair<int, int> min_cost(INT_MAX, 0);//(min_time, min_index)
         
         for(int knl_idx = 0; knl_idx < total_kernel; knl_idx++) {
-            kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], mResource->mBuildOptions);
+            std::set<std::string> buildOption = mResource->mBuildOptions;
+            if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){
+                buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
+            }
+            kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], buildOption);
             uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));
             
             globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast<uint32_t>(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))};
@@ -446,7 +458,11 @@ ErrorCode ConvExecution::onEncode(const std::vector<Tensor *> &inputs, const std
         }
         int min_index  = min_cost.second;
         mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]};
-        unit.kernel        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], mResource->mBuildOptions);
+        std::set<std::string> buildOption = mResource->mBuildOptions;
+        if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){
+            buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
+        }
+        unit.kernel        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], buildOption);
         
         uint32_t idx            = 0;
         cl_int ret = CL_SUCCESS;
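
The host requests -DCHANNEL_BOUNDARY_PROTECT only for the c8 kernels, and only when oc % 8 is in (0, 4]: that is exactly the case where UP_DIV(oc, 4) is odd, so the work-item's second 4-channel block does not exist. For oc % 8 in (4, 8) the second block exists (zero-padded to 4 channels) and needs no guard, and the bias buffer is now padded with ALIGN_UP8 so its loads stay in bounds either way. A sketch of the predicate:

    // True exactly when UP_DIV(oc, 4) is odd, i.e. oc % 8 in (0, 4].
    bool needsChannelBoundaryProtect(int itemC, int outputChannels) {
        const int rem = outputChannels % itemC;
        return itemC == 8 && rem > 0 && rem <= 4;
    }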

+ 12 - 0
source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp

@@ -239,6 +239,9 @@ void ConvLowMemoryExecution::tune1x1CaseLowMemory(Tensor * input, Tensor * outpu
         if(inputChannels % 4 != 0){
             buildOption.emplace("-DINPUT_CHANNEL_LEAVE");
         }
+        if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){
+            buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
+        }
         kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], buildOption);
         uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));
         
@@ -277,6 +280,9 @@ void ConvLowMemoryExecution::tune1x1CaseLowMemory(Tensor * input, Tensor * outpu
     if(inputChannels % 4 != 0){
         buildOption.emplace("-DINPUT_CHANNEL_LEAVE");
     }
+    if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){
+        buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
+    }
     unit.kernel        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], buildOption);
     uint32_t idx = 0;
     ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]);
@@ -338,6 +344,9 @@ void ConvLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor * o
         if(inputChannels % 4 != 0){
             buildOption.emplace("-DINPUT_CHANNEL_LEAVE");
         }
+        if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){
+            buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
+        }
         kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], buildOption);
         uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));
 
@@ -379,6 +388,9 @@ void ConvLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor * o
     if(inputChannels % 4 != 0){
         buildOption.emplace("-DINPUT_CHANNEL_LEAVE");
     }
+    if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){
+        buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
+    }
     unit.kernel        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], buildOption);
 
     uint32_t idx            = 0;

+ 1 - 1
source/backend/vulkan/buffer/execution/VulkanPRelu.cpp

@@ -80,7 +80,7 @@ class VulkanReluCreator : public VulkanBackend::Creator {
 public:
     virtual VulkanBasicExecution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor*>& outputs, const MNN::Op *op, Backend *bn) const override {
         if (1 == op->main_as_PRelu()->slopeCount()) {
-            return new VulkanUnary("RELU", bn, op->main_as_PRelu()->slope()->data()[0]);
+            return new VulkanUnary("RELU", bn, false, op->main_as_PRelu()->slope()->data()[0]);
         }
         return new VulkanPrelu(bn, op);
     }

+ 3 - 0
source/core/Interpreter.cpp

@@ -193,6 +193,9 @@ void Interpreter::setExternalFile(const char* file, size_t flag) {
 }
 
 ErrorCode Interpreter::updateCacheFile(Session *session, int flag) {
+    if (mNet->cacheFile.empty()) {
+        return NOT_SUPPORT;
+    }
     std::lock_guard<std::mutex> _l(mNet->lock);
 
     // Backend_Auto and no Async work, then don't need updateCache
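
With this guard, updateCacheFile returns NOT_SUPPORT when no cache file has been registered. A hedged usage sketch following MNN's public Interpreter API (path, flag value, and cleanup calls are illustrative; treat exact signatures as assumptions):

    #include <MNN/Interpreter.hpp>

    void warmAndPersistCache() {
        auto* net = MNN::Interpreter::createFromFile("model.mnn"); // placeholder path
        net->setCacheFile("model.cache");     // required, or updateCacheFile bails out
        MNN::ScheduleConfig config;
        auto* session = net->createSession(config);
        net->updateCacheFile(session, 0);     // NOT_SUPPORT if no cache file was set
        net->releaseSession(session);
        MNN::Interpreter::destroy(net);
    }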

+ 1 - 1
source/core/Pipeline.cpp

@@ -27,7 +27,7 @@ static bool _supportQuant(const Op* op, const std::vector<Tensor*>& inputs, cons
     switch (otype) {
         case OpType_Convolution:
         case OpType_ConvolutionDepthwise:
-        case OpType_Deconvolution:
+//        case OpType_Deconvolution:
             if (inputs.size() > 1) {
                 return false;
             }

+ 11 - 1
source/core/TensorUtils.cpp

@@ -487,7 +487,7 @@ static bool _ClipDst(int* stride, int srcOffset, int dstOffset, const int* srcSi
      dx=sx-xo -> [max(0, -xo), max(0, min(sxr-xo, dxr))]
      dy,dz compute the same
      **/
-    
+
     int offsetBias = dstOffset - srcOffset;
     if (sizeNum == 0) {
         // All stride is zero, then size will be all one
@@ -903,4 +903,14 @@ void TensorUtils::setTensorPad(const Tensor* tensor, int left, int right, int bo
     srcDes->mPads.top = std::max(srcDes->mPads.top, top);
 }
 
+void TensorUtils::setSharedMem(const Tensor *tensor, Backend::MemObj *mem){
+    auto srcDes = TensorUtils::getDescribe(tensor);
+    srcDes->mSharedMem = mem;
+}
+
+Backend::MemObj* TensorUtils::getSharedMem(const Tensor* tensor){
+    auto srcDes = TensorUtils::getDescribe(tensor);
+    return srcDes->mSharedMem.get();
+}
+
 } // namespace MNN

+ 6 - 0
source/core/TensorUtils.hpp

@@ -124,6 +124,8 @@ struct Tensor::InsideDescribe {
         pad mPads;
         // For isMutable = false Tensor , determine whether the content can be convert to main backend
         uint32_t stageMask = 0;
+        // Use for shared memory
+        SharedPtr<Backend::MemObj> mSharedMem;
     };
     std::shared_ptr<NativeInsideDescribe> mContent;
     SharedPtr<Backend::MemObj> mem;
@@ -224,6 +226,10 @@ public:
     static void setTensorSupportPack(const Tensor* tensor, bool flag);
 
     static void setTensorPad(const Tensor* tensor, int left, int right, int bottom, int top);
+    
+    static void setSharedMem(const Tensor* tensor, Backend::MemObj *mem);
+    
+    static Backend::MemObj* getSharedMem(const Tensor* tensor);
 };
 } // namespace MNN
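
A minimal sketch of the new accessors (internal API; assumes the MemObj is backend-allocated and outlives the tensor's use of it):

    #include "core/TensorUtils.hpp"

    void attachSharedMem(const MNN::Tensor* tensor, MNN::Backend::MemObj* mem) {
        MNN::TensorUtils::setSharedMem(tensor, mem);
        auto* fetched = MNN::TensorUtils::getSharedMem(tensor); // returns `mem`
        (void)fetched;
    }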
 

+ 1 - 1
source/shape/ShapeConcat.cpp

@@ -14,7 +14,7 @@ class ConcatSizeComputer : public SizeComputer {
     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                                const std::vector<Tensor*>& outputs) const override {
         MNN_ASSERT(1 == outputs.size());
-        MNN_ASSERT(inputs.size() >= 2);
+        // MNN_ASSERT(inputs.size() >= 2);
         auto& ob      = outputs[0]->buffer();
         int basicAxis = 0;
         if (op->type() == OpType_Concat) {

+ 6 - 0
source/shape/ShapeRegister.cpp

@@ -122,6 +122,9 @@ extern void ___FmhaV2SizeComputer__OpType_FmhaV2__();
 extern void ___FmhcaSizeComputer__OpType_Fmhca__();
 extern void ___AttentionSizeComputer__OpType_Attention__();
 #endif
+#ifdef MNN_BUILD_AUDIO
+extern void ___StftOpComputer__OpType_Stft__();
+#endif
 void registerShapeOps() {
 ___ShapeSizeComputer__OpType_Shape__();
 ___ShapeRasterComputer__OpType_Raster__();
@@ -244,5 +247,8 @@ ___FmhaV2SizeComputer__OpType_FmhaV2__();
 ___FmhcaSizeComputer__OpType_Fmhca__();
 ___AttentionSizeComputer__OpType_Attention__();
 #endif
+#ifdef MNN_BUILD_AUDIO
+___StftOpComputer__OpType_Stft__();
+#endif
 }
 }

+ 38 - 0
source/shape/ShapeStft.cpp

@@ -0,0 +1,38 @@
+//
+//  ShapeStft.cpp
+//  MNN
+//
+//  Created by MNN on 2024/11/26.
+//  Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef MNN_BUILD_AUDIO
+
+#include "shape/SizeComputer.hpp"
+#include "core/Macro.h"
+#include "core/TensorUtils.hpp"
+
+namespace MNN {
+
+class StftOpComputer : public SizeComputer {
+    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
+                               const std::vector<Tensor*>& outputs) const override {
+        int sample_length = inputs[0]->elementSize();
+        auto stft = op->main_as_StftParam();
+        bool abs = stft->abs();
+        int n_fft = stft->n_fft();
+        int hop_length = stft->hop_length();
+        int frames = (sample_length - n_fft) / hop_length + 1;
+        // Scalar
+        outputs[0]->buffer().dimensions = 2;
+        outputs[0]->setLength(0, frames);
+        outputs[0]->setLength(1, n_fft / 2 + 1);
+        outputs[0]->buffer().type = inputs[0]->getType();
+        TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat;
+        return true;
+    }
+};
+
+REGISTER_SHAPE_AUDIO(StftOpComputer, OpType_Stft);
+} // namespace MNN
+#endif // MNN_BUILD_AUDIO
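
The output shape follows the usual un-padded STFT rule: frames = (sample_length - n_fft) / hop_length + 1, with n_fft / 2 + 1 frequency bins. A worked example with illustrative parameter values:

    #include <cstdio>

    int main() {
        int sampleLength = 16000, nFft = 400, hopLength = 160; // illustrative values
        int frames = (sampleLength - nFft) / hopLength + 1;    // (16000-400)/160+1 = 98
        int bins = nFft / 2 + 1;                               // 400/2+1 = 201
        std::printf("output shape: [%d, %d]\n", frames, bins); // [98, 201]
        return 0;
    }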

+ 9 - 0
source/shape/SizeComputer.hpp

@@ -186,4 +186,13 @@ public:
 
 #endif
 
+#ifdef MNN_BUILD_AUDIO
+#define REGISTER_SHAPE_AUDIO(name, op)            \
+    void ___##name##__##op##__() {                        \
+        name* _temp = new name;                            \
+        SizeComputerSuite* ts = SizeComputerSuite::get(); \
+        ts->insert(_temp, op);                           \
+    }
+#endif
+
 #endif
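
For reference, REGISTER_SHAPE_AUDIO(StftOpComputer, OpType_Stft) expands to the registration function declared extern in ShapeRegister.cpp above:

    void ___StftOpComputer__OpType_Stft__() {
        StftOpComputer* _temp = new StftOpComputer;
        SizeComputerSuite* ts = SizeComputerSuite::get();
        ts->insert(_temp, OpType_Stft);
    }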

+ 2 - 2
test.sh

@@ -167,7 +167,7 @@ android_static_build() {
     -DMNN_INTERNAL=ON \
     -DMNN_USE_LOGCAT=false \
     -DMNN_BUILD_BENCHMARK=ON \
-    -DANDROID_NATIVE_API_LEVEL=android-21  \
+    -DANDROID_NATIVE_API_LEVEL=android-26  \
     -DMNN_BUILD_FOR_ANDROID_COMMAND=true \
     -DMNN_OPENGL=true \
     -DMNN_BUILD_TRAIN=true \
@@ -198,7 +198,7 @@ android_static_build() {
     -DMNN_USE_LOGCAT=false \
     -DMNN_BUILD_BENCHMARK=ON \
     -DMNN_INTERNAL=ON \
-    -DANDROID_NATIVE_API_LEVEL=android-21  \
+    -DANDROID_NATIVE_API_LEVEL=android-26  \
     -DMNN_BUILD_FOR_ANDROID_COMMAND=true \
     -DMNN_OPENGL=true \
     -DMNN_BUILD_TRAIN=true \

+ 0 - 0
test/CMakeLists.txt


Some files were not shown because too many files changed in this diff