MNNPackedSparseQuantMatMulEpx4.S 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. //
  2. // MNNPackedSparseQuantMatMulEpx4.S
  3. // MNN
  4. //
  5. // Created by MNN on 2021/06/23.
  6. // Copyright © 2018-2021 Alibaba Group Holding Limited
  7. //
  8. //
  9. #ifdef __arm__
  10. #ifndef __aarch64__
  11. #include "MNNAsmGlobal.h"
  12. #define sizeof_value 4
  13. #define sizeof_value_lg2 2
  14. #define sparse_blockoc 4
  15. .macro TYPE_CVT op, z0, z1, z2, z3
  16. \op \z0, \z0
  17. \op \z1, \z1
  18. \op \z2, \z2
  19. \op \z3, \z3
  20. .endm
  21. .macro CLAMP op, z0, z1, z2, z3, m0
  22. \op \z0, \z0, \m0
  23. \op \z1, \z1, \m0
  24. \op \z2, \z2, \m0
  25. \op \z3, \z3, \m0
  26. .endm
  27. .macro SCALE z0, z1, z2, z3, scale
  28. vmul.f32 \z0, \z0, \scale
  29. vmul.f32 \z1, \z1, \scale
  30. vmul.f32 \z2, \z2, \scale
  31. vmul.f32 \z3, \z3, \scale
  32. .endm
  33. .macro ROUND_MODE z0, z1, z2, z3
  34. vcgt.f32 q0, \z0, #0
  35. vcgt.f32 q1, \z1, #0
  36. vcgt.f32 q2, \z2, #0
  37. vcgt.f32 q3, \z3, #0
  38. vbsl.f32 q0, q4, q5
  39. vbsl.f32 q1, q4, q5
  40. vbsl.f32 q2, q4, q5
  41. vbsl.f32 q3, q4, q5
  42. vadd.f32 \z0, \z0, q0
  43. vadd.f32 \z1, \z1, q1
  44. vadd.f32 \z2, \z2, q2
  45. vadd.f32 \z3, \z3, q3
  46. .endm
  47. .text
  48. .align 5
  49. // caution!!! this is 8 * 4 Sparse MatMul
  50. asm_function MNNPackedSparseQuantMatMulEpx4
  51. // void MNNPackedSparseQuantMatMulEpx4(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam,
  52. // const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap) {
  53. //Auto load: r0: C, r1:A, r2:B, r3:sparseQuantParam,
  54. //load from stack r4:QuanPostTreatParameters, r5:NNZMap, r6:dataOffsetMap
  55. // var not defined: bias,
  56. push {r4-r8, r10, r11, lr}
  57. vpush {q4-q7}
  58. #define push_registers_bytes (8 * 4 + 4 * 16)
  59. ldr r4, [sp, #push_registers_bytes]
  60. ldr r7, [r4, #8]
  61. ldr r8, [r4, #12]
  62. vmov.f32 q4, #0.5
  63. vmov.f32 q5, #-0.5
  64. vdup.32 q6, r7 // max
  65. vdup.32 q7, r8 // min
  66. // r0: C
  67. // r1: A
  68. // r2: B
  69. // r3: sparseQuantParam mem(6*4byte) [eSize, eP, aStride, l, h, cStride]
  70. // r4: QuanPostTreatParameters mem(4*4byte) [scale, bias, max, min]
  71. // r5: NNZMap
  72. // r6: dataOffsetMap
  73. // r7: scale
  74. // r8: bias
  75. // r10: loop_counter (loop_e8 / loop_e4 / loop_e2 / loop_e1), cStride
  76. // r11: loop_counter (loop_e8h4 / loop_e4h4 / loop_e2h4 / loop_e1h4)
  77. // r12: loop_counter (loop_e8h4l1 / loop_e4h4l1 / loop_e2h4l1 / loop_e1h4l1)
  78. // lr: temp var
  79. ldr r10, [r3]
  80. loop_e8:
  81. cmp r10, #8
  82. blt loop_e4
  83. sub r10, r10, #8
  84. ldr r5, [sp, #(push_registers_bytes + 4)]
  85. ldr r6, [sp, #(push_registers_bytes + 8)]
  86. ldr r7, [r4]
  87. ldr r8, [r4, #4]
  88. push {r0-r2, r10}
  89. ldr r10, [r3, #20] // cStride
  90. ldr lr, [r6], #4 // dataOffset
  91. add r1, r1, lr
  92. ldr r11, [r3, #16] // h
  93. lsr r11, r11, #2 // hDiv4 (C4)
  94. loop_e8h4:
  95. vld1.32 q8, [r8]!
  96. vmov q9, q8
  97. vmov q10, q8
  98. vmov q11, q8
  99. vmov q12, q8
  100. vmov q13, q8
  101. vmov q14, q8
  102. vmov q15, q8
  103. ldr r12, [r5], #4
  104. cmp r12, #0
  105. beq loop_e8h4_end
  106. loop_e8h4l1:
  107. vld1.32 d0[0], [r2]!
  108. vld1.8 d2, [r1]
  109. vmovl.s8 q0, d0
  110. vmovl.s8 q1, d2
  111. ldr lr, [r6], #4
  112. add r1, r1, lr
  113. subs r12, r12, #1
  114. vmlal.s16 q8, d0, d2[0]
  115. vmlal.s16 q9, d0, d2[1]
  116. vmlal.s16 q10, d0, d2[2]
  117. vmlal.s16 q11, d0, d2[3]
  118. vmlal.s16 q12, d0, d3[0]
  119. vmlal.s16 q13, d0, d3[1]
  120. vmlal.s16 q14, d0, d3[2]
  121. vmlal.s16 q15, d0, d3[3]
  122. bne loop_e8h4l1
  123. loop_e8h4_end:
  124. vld1.32 q0, [r7]!
  125. TYPE_CVT vcvt.f32.s32, q8, q9, q10, q11
  126. TYPE_CVT vcvt.f32.s32, q12, q13, q14, q15
  127. SCALE q8, q9, q10, q11, q0
  128. SCALE q12, q13, q14, q15, q0
  129. ROUND_MODE q8, q9, q10, q11
  130. ROUND_MODE q12, q13, q14, q15
  131. TYPE_CVT vcvt.s32.f32, q8, q9, q10, q11
  132. TYPE_CVT vcvt.s32.f32, q12, q13, q14, q15
  133. CLAMP vmin.s32, q8, q9, q10, q11, q6
  134. CLAMP vmin.s32, q12, q13, q14, q15, q6
  135. CLAMP vmax.s32, q8, q9, q10, q11, q7
  136. CLAMP vmax.s32, q12, q13, q14, q15, q7
  137. vqmovn.s32 d0, q8
  138. vqmovn.s32 d1, q9
  139. vqmovn.s32 d2, q10
  140. vqmovn.s32 d3, q11
  141. vqmovn.s32 d4, q12
  142. vqmovn.s32 d5, q13
  143. vqmovn.s32 d6, q14
  144. vqmovn.s32 d7, q15
  145. vqmovn.s16 d0, q0
  146. vqmovn.s16 d1, q1
  147. vqmovn.s16 d2, q2
  148. vqmovn.s16 d3, q3
  149. vst1.8 {q0, q1}, [r0], r10
  150. subs r11, r11, #1
  151. bne loop_e8h4
  152. pop {r0-r2, r10}
  153. add r0, r0, #32
  154. add r1, r1, #8
  155. b loop_e8
  156. loop_e4:
  157. cmp r10, #4
  158. blt loop_e2
  159. sub r10, r10, #4
  160. ldr r5, [sp, #(push_registers_bytes + 4)]
  161. ldr r6, [sp, #(push_registers_bytes + 8)]
  162. ldr r7, [r4]
  163. ldr r8, [r4, #4]
  164. push {r0-r2, r10}
  165. ldr r10, [r3, #20] // cStride
  166. ldr lr, [r6], #4 // dataOffset
  167. add r1, r1, lr
  168. ldr r11, [r3, #16] // h
  169. lsr r11, r11, #2 // hDiv4 (C4)
  170. loop_e4h4:
  171. vld1.32 q8, [r8]!
  172. vmov q9, q8
  173. vmov q10, q8
  174. vmov q11, q8
  175. ldr r12, [r5], #4
  176. cmp r12, #0
  177. beq loop_e4h4_end
  178. loop_e4h4l1:
  179. vld1.32 d0[0], [r2]!
  180. vld1.32 d2[0], [r1]
  181. vmovl.s8 q0, d0
  182. vmovl.s8 q1, d2
  183. ldr lr, [r6], #4
  184. add r1, r1, lr
  185. subs r12, r12, #1
  186. vmlal.s16 q8, d0, d2[0]
  187. vmlal.s16 q9, d0, d2[1]
  188. vmlal.s16 q10, d0, d2[2]
  189. vmlal.s16 q11, d0, d2[3]
  190. bne loop_e4h4l1
  191. loop_e4h4_end:
  192. vld1.32 q0, [r7]!
  193. TYPE_CVT vcvt.f32.s32, q8, q9, q10, q11
  194. SCALE q8, q9, q10, q11, q0
  195. ROUND_MODE q8, q9, q10, q11
  196. TYPE_CVT vcvt.s32.f32, q8, q9, q10, q11
  197. CLAMP vmin.s32, q8, q9, q10, q11, q6
  198. CLAMP vmax.s32, q8, q9, q10, q11, q7
  199. vqmovn.s32 d0, q8
  200. vqmovn.s32 d1, q9
  201. vqmovn.s32 d2, q10
  202. vqmovn.s32 d3, q11
  203. vqmovn.s16 d0, q0
  204. vqmovn.s16 d1, q1
  205. vst1.8 {q0}, [r0], r10
  206. subs r11, r11, #1
  207. bne loop_e4h4
  208. pop {r0-r2, r10}
  209. add r0, r0, #16
  210. add r1, r1, #4
  211. b loop_e4
  212. loop_e2:
  213. cmp r10, #2
  214. blt loop_e1
  215. sub r10, r10, #2
  216. ldr r5, [sp, #(push_registers_bytes + 4)]
  217. ldr r6, [sp, #(push_registers_bytes + 8)]
  218. ldr r7, [r4]
  219. ldr r8, [r4, #4]
  220. push {r0-r2, r10}
  221. ldr r10, [r3, #20] // cStride
  222. ldr lr, [r6], #4 // dataOffset
  223. add r1, r1, lr
  224. ldr r11, [r3, #16] // h
  225. lsr r11, r11, #2 // hDiv4 (C4)
  226. loop_e2h4:
  227. vld1.32 q8, [r8]!
  228. vmov q9, q8
  229. ldr r12, [r5], #4
  230. cmp r12, #0
  231. beq loop_e2h4_end
  232. loop_e2h4l1:
  233. vld1.32 d0[0], [r2]!
  234. vld1.16 d2[0], [r1]
  235. vmovl.s8 q0, d0
  236. vmovl.s8 q1, d2
  237. ldr lr, [r6], #4
  238. add r1, r1, lr
  239. subs r12, r12, #1
  240. vmlal.s16 q8, d0, d2[0]
  241. vmlal.s16 q9, d0, d2[1]
  242. bne loop_e2h4l1
  243. loop_e2h4_end:
  244. vld1.32 q0, [r7]!
  245. vcvt.f32.s32 q8, q8
  246. vcvt.f32.s32 q9, q9
  247. vmul.f32 q8, q8, q0
  248. vmul.f32 q9, q9, q0
  249. vcgt.f32 q1, q8, #0
  250. vcgt.f32 q2, q9, #0
  251. vbsl.f32 q1, q4, q5
  252. vbsl.f32 q2, q4, q5
  253. vadd.f32 q8, q8, q1
  254. vadd.f32 q9, q9, q2
  255. vcvt.s32.f32 q8, q8
  256. vcvt.s32.f32 q9, q9
  257. vmin.s32 q8, q8, q6
  258. vmin.s32 q9, q9, q6
  259. vmax.s32 q8, q8, q7
  260. vmax.s32 q9, q9, q7
  261. vqmovn.s32 d0, q8
  262. vqmovn.s32 d1, q9
  263. vqmovn.s16 d0, q0
  264. vst1.8 {d0}, [r0], r10
  265. subs r11, r11, #1
  266. bne loop_e2h4
  267. pop {r0-r2, r10}
  268. add r0, r0, #8
  269. add r1, r1, #2
  270. b loop_e2
  271. loop_e1:
  272. cmp r10, #1
  273. blt End
  274. sub r10, r10, #1
  275. ldr r5, [sp, #(push_registers_bytes + 4)]
  276. ldr r6, [sp, #(push_registers_bytes + 8)]
  277. ldr r7, [r4]
  278. ldr r8, [r4, #4]
  279. push {r0-r2, r10}
  280. ldr r10, [r3, #20] // cStride
  281. ldr lr, [r6], #4 // dataOffset
  282. add r1, r1, lr
  283. ldr r11, [r3, #16] // h
  284. lsr r11, r11, #2 // hDiv4 (C4)
  285. loop_e1h4:
  286. vld1.32 q8, [r8]!
  287. ldr r12, [r5], #4
  288. cmp r12, #0
  289. beq loop_e1h4_end
  290. loop_e1h4l1:
  291. vld1.32 d0[0], [r2]!
  292. vld1.8 d2[0], [r1]
  293. vmovl.s8 q0, d0
  294. vmovl.s8 q1, d2
  295. ldr lr, [r6], #4
  296. add r1, r1, lr
  297. subs r12, r12, #1
  298. vmlal.s16 q8, d0, d2[0]
  299. bne loop_e1h4l1
  300. loop_e1h4_end:
  301. vld1.32 q0, [r7]!
  302. vcvt.f32.s32 q8, q8
  303. vmul.f32 q8, q8, q0
  304. vcgt.f32 q1, q8, #0
  305. vbsl.f32 q1, q4, q5
  306. vadd.f32 q8, q8, q1
  307. vcvt.s32.f32 q8, q8
  308. vmin.s32 q8, q8, q6
  309. vmax.s32 q8, q8, q7
  310. vqmovn.s32 d0, q8
  311. vqmovn.s16 d0, q0
  312. vst1.32 {d0[0]}, [r0], r10
  313. subs r11, r11, #1
  314. bne loop_e1h4
  315. pop {r0-r2, r10}
  316. add r0, r0, #4
  317. add r1, r1, #1
  318. b loop_e1
  319. End:
  320. vpop {q4-q7}
  321. pop {r4-r8, r10, r11, pc}
  322. #undef push_registers_bytes
  323. #undef sizeof_value
  324. #undef sizeof_value_lg2
  325. #undef sparse_blockoc
  326. #endif
  327. #endif