MNNGemmInt8AddBiasScale_16x4_Unit.S 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. //
  2. // MNNGemmInt8AddBiasScale_16x4_Unit.S
  3. // MNN
  4. //
  5. // Created by MNN on 2019/06/11.
  6. // Copyright © 2018, Alibaba Group Holding Limited
  7. //
  8. #ifdef __arm__
  9. #ifndef __aarch64__
  10. #include "MNNAsmGlobal.h"
  11. .text
  12. .align 5
  13. asm_function MNNGemmInt8AddBiasScale_16x4_Unit
  14. /*
  15. struct QuanPostTreatParameters {
  16. const float* scale;
  17. const float* biasFloat;
  18. int32_t maxValue;
  19. int32_t minValue;
  20. int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32.
  21. float roundValuePos = 0.5f;
  22. float roundValueNeg = -0.5f;
  23. float* srcKernelSum;
  24. float* weightQuanBias;
  25. float* fp32minmax;
  26. ssize_t blockNum = 1;
  27. const int32_t* bias;
  28. const float* extraScale = nullptr;
  29. const float* extraBias = nullptr;
  30. };
  31. */
  32. //void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
  33. // size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t real) {
  34. //Auto: r0: dst*, r1: src*, r2:weight*, r3: src_depth_quad
  35. // Load from sp: r4: dst_step, r5: dst_depth_quad, r6: post, r10: real
  36. // Load from post: r8: scale, lr: bias, r7: maxValue, r6: minValue
  37. push {r4-r8, r10, lr} // avoid to touch platform-register r-9
  38. ldr r4, [sp, #28]
  39. ldr r5, [sp, #32]
  40. ldr r6, [sp, #36]
  41. ldr r10, [sp, #40]
  42. ldr r8, [r6, #0]
  43. ldr lr, [r6, #4]
  44. vpush {q4-q7}
  45. sub sp, sp, #36
  46. ldr r7, [r6, #16] // r7: useInt8
  47. ldr r12, [r6, #28] // srcKernelSum
  48. str r12, [sp, #4]
  49. ldr r12, [r6, #32] // weightBias
  50. str r12, [sp, #8]
  51. ldr r12, [r6, #36] // f32minmax
  52. str r12, [sp, #12]
  53. ldr r12, [r6, #8] // int8 max
  54. str r12, [sp, #16]
  55. ldr r12, [r6, #12] // int8 min
  56. str r12, [sp, #20]
  57. ldr r12, [r6, #40] // blockNum
  58. mul r12, r12, r3 // src_depth_quad=src_depth_quad*blockNum
  59. lsl r12, r12, #6 // weight_stride = src_depth_quad*LP*HP
  60. str r12, [sp, #24]
  61. ldr r12, [r6, #48] // extraScale
  62. str r12, [sp, #28]
  63. Start:
  64. cmp r10, #2
  65. blt L1LoopDz
  66. L2LoopDz:
  67. mov r10, r1
  68. str r2, [sp, #32] // store weight ptr
  69. subs r12, r3, #1
  70. // first four output
  71. vld1.8 {q2}, [r1]!
  72. vld1.8 {q4,q5}, [r2]!
  73. vmull.s8 q0, d4, d8
  74. vmull.s8 q1, d4, d10
  75. vmlal.s8 q0, d5, d9
  76. vmlal.s8 q1, d5, d11
  77. vpaddl.s16 q8, q0
  78. vpaddl.s16 q9, q1
  79. vld1.8 {q6,q7}, [r2]!
  80. vmull.s8 q0, d4, d12
  81. vmull.s8 q1, d4, d14
  82. vmlal.s8 q0, d5, d13
  83. vmlal.s8 q1, d5, d15
  84. vpaddl.s16 q10, q0
  85. vld1.8 {q3}, [r1]!
  86. vpaddl.s16 q11, q1
  87. // second four output
  88. vmull.s8 q0, d6, d8
  89. vmull.s8 q1, d6, d10
  90. vmlal.s8 q0, d7, d9
  91. vmlal.s8 q1, d7, d11
  92. vpaddl.s16 q12, q0
  93. vpaddl.s16 q13, q1
  94. vmull.s8 q0, d6, d12
  95. vmull.s8 q1, d6, d14
  96. vmlal.s8 q0, d7, d13
  97. vmlal.s8 q1, d7, d15
  98. vpaddl.s16 q14, q0
  99. vpaddl.s16 q15, q1
  100. beq L2LoopSzEnd
  101. L2LoopSz:
  102. // first four output
  103. vld1.8 {q2}, [r1]!
  104. vld1.8 {q4,q5}, [r2]!
  105. vmull.s8 q0, d4, d8
  106. vmull.s8 q1, d4, d10
  107. vmlal.s8 q0, d5, d9
  108. vmlal.s8 q1, d5, d11
  109. vld1.8 {q6,q7}, [r2]!
  110. vpadal.s16 q8, q0
  111. vpadal.s16 q9, q1
  112. vmull.s8 q0, d4, d12
  113. vmull.s8 q1, d4, d14
  114. vmlal.s8 q0, d5, d13
  115. vmlal.s8 q1, d5, d15
  116. vld1.8 {q3}, [r1]!
  117. vpadal.s16 q10, q0
  118. vpadal.s16 q11, q1
  119. // second four output
  120. vmull.s8 q0, d6, d8
  121. vmull.s8 q1, d6, d10
  122. vmlal.s8 q0, d7, d9
  123. vmlal.s8 q1, d7, d11
  124. vpadal.s16 q12, q0
  125. vpadal.s16 q13, q1
  126. vmull.s8 q0, d6, d12
  127. vmull.s8 q1, d6, d14
  128. vmlal.s8 q0, d7, d13
  129. vmlal.s8 q1, d7, d15
  130. vpadal.s16 q14, q0
  131. vpadal.s16 q15, q1
  132. subs r12, r12, #1
  133. bne L2LoopSz
  134. L2LoopSzEnd:
  135. L2Quan:
  136. vld1.f32 {q5}, [r8]! // scale
  137. vpadd.s32 d16, d16, d17
  138. vpadd.s32 d20, d20, d21
  139. vpadd.s32 d18, d18, d19
  140. vpadd.s32 d22, d22, d23
  141. vpadd.s32 d24, d24, d25
  142. vpadd.s32 d28, d28, d29
  143. vpadd.s32 d26, d26, d27
  144. vpadd.s32 d30, d30, d31
  145. // q8,q9
  146. vpadd.s32 d16, d16, d18
  147. vpadd.s32 d17, d20, d22
  148. vpadd.s32 d18, d24, d26
  149. vpadd.s32 d19, d28, d30
  150. // vaddq.s32 q0, q8, q4 // add bias
  151. // vaddq.s32 q1, q9, q4
  152. vcvt.f32.s32 q0, q8
  153. vcvt.f32.s32 q1, q9
  154. vmulq.f32 q0, q0, q5 // mul scale
  155. vmulq.f32 q1, q1, q5
  156. // extra scale if has
  157. ldr r6, [sp, #28]
  158. cmp r6, #0
  159. beq L2_MLA
  160. vld1.f32 {d10[0]}, [r6]! // tile0
  161. vld1.f32 {d10[1]}, [r6] // tile1
  162. vmulq.f32 q0, q0, d10[0]
  163. vmulq.f32 q1, q1, d10[1]
  164. L2_MLA:
  165. ldr r6, [sp, #4] // srcKernelSum
  166. vld1.f32 {d12[0]}, [r6]! // tile 0
  167. vld1.f32 {d12[1]}, [r6] // tile 1
  168. ldr r6, [sp, #8] // weightBias
  169. vld1.f32 {q7}, [r6]!
  170. str r6, [sp, #8] // update next 4 weightBias
  171. vmla.f32 q0, q7, d12[0]
  172. vmla.f32 q1, q7, d12[1]
  173. cmp r7, #0
  174. bne L2QuanUseInt8
  175. L2_ADD_BIAS:
  176. cmp lr, #0
  177. beq L2_ADD_DSTV
  178. vld1.f32 {q4}, [lr]! // bias
  179. vadd.f32 q0, q0, q4 // bias
  180. vadd.f32 q1, q1, q4
  181. b L2_POST
  182. L2_ADD_DSTV:
  183. vld1.f32 {q4, q5}, [r0]
  184. vadd.f32 q0, q0, q4
  185. vadd.f32 q1, q1, q5
  186. L2_POST:
  187. ldr r6, [sp, #12] // fp32 minmax
  188. cmp r6, #0
  189. beq L2_STORE
  190. vld1.f32 {d20[0]}, [r6]!
  191. vld1.f32 {d22[0]}, [r6]
  192. vdup.f32 q10, d20[0]
  193. vdup.f32 q11, d22[0]
  194. vmax.f32 q0, q0, q10
  195. vmax.f32 q1, q1, q10
  196. vmin.f32 q0, q0, q11
  197. vmin.f32 q1, q1, q11
  198. L2_STORE:
  199. vst1.f32 {q0, q1}, [r0], r4
  200. b L2LoopCheck
  201. L2QuanUseInt8:
  202. vld1.f32 {q4}, [lr]! // bias
  203. vadd.f32 q0, q0, q4 // bias
  204. vadd.f32 q1, q1, q4
  205. vmov.f32 q10, #0.5
  206. vmov.f32 q11, #-0.5
  207. ldr r6, [sp, #16]
  208. vdup.32 q3, r6 // max
  209. ldr r6, [sp, #20]
  210. vdup.32 q2, r6 // min
  211. vcgt.f32 q12, q0, #0
  212. vcgt.f32 q13, q1, #0
  213. vbsl.f32 q12, q10, q11
  214. vbsl.f32 q13, q10, q11
  215. vadd.f32 q0, q12, q0
  216. vadd.f32 q1, q13, q1
  217. vcvt.s32.f32 q0, q0
  218. vcvt.s32.f32 q1, q1
  219. vmax.s32 q0, q2, q0
  220. vmax.s32 q1, q2, q1
  221. vmin.s32 q0, q3, q0
  222. vmin.s32 q1, q3, q1
  223. vqmovn.s32 d4, q0
  224. vqmovn.s32 d5, q1
  225. vqmovn.s16 d6, q2
  226. vst1.s8 {d6}, [r0], r4
  227. L2LoopCheck:
  228. subs r5, r5, #1
  229. mov r1, r10
  230. ldr r2, [sp, #32] // origin weight ptr
  231. ldr r6, [sp, #24] // weight stride
  232. add r2, r2, r6 // next oc4 weight ptr
  233. bne L2LoopDz
  234. b End
  235. L1LoopDz:
  236. mov r10, r1
  237. str r2, [sp, #32] // store weight ptr
  238. subs r12, r3, #1
  239. // first four output
  240. vld1.8 {q2}, [r1]!
  241. vld1.8 {q4,q5}, [r2]!
  242. vmull.s8 q0, d4, d8
  243. vmull.s8 q1, d4, d10
  244. vmlal.s8 q0, d5, d9
  245. vmlal.s8 q1, d5, d11
  246. vpaddl.s16 q8, q0
  247. vpaddl.s16 q9, q1
  248. vld1.8 {q6,q7}, [r2]!
  249. vmull.s8 q0, d4, d12
  250. vmull.s8 q1, d4, d14
  251. vmlal.s8 q0, d5, d13
  252. vmlal.s8 q1, d5, d15
  253. vpaddl.s16 q10, q0
  254. add r1, r1, #16
  255. vpaddl.s16 q11, q1
  256. beq L1LoopSzEnd
  257. L1LoopSz:
  258. // first four output
  259. vld1.8 {q2}, [r1]!
  260. vld1.8 {q4,q5}, [r2]!
  261. vmull.s8 q0, d4, d8
  262. vmull.s8 q1, d4, d10
  263. vmlal.s8 q0, d5, d9
  264. vmlal.s8 q1, d5, d11
  265. vld1.8 {q6,q7}, [r2]!
  266. vpadal.s16 q8, q0
  267. vpadal.s16 q9, q1
  268. vmull.s8 q0, d4, d12
  269. vmull.s8 q1, d4, d14
  270. vmlal.s8 q0, d5, d13
  271. vmlal.s8 q1, d5, d15
  272. add r1, r1, #16
  273. vpadal.s16 q10, q0
  274. vpadal.s16 q11, q1
  275. subs r12, r12, #1
  276. bne L1LoopSz
  277. L1LoopSzEnd:
  278. L1Quan:
  279. //vld1.f32 {q4}, [lr]! // bias
  280. vld1.f32 {q5}, [r8]! // scale
  281. vpadd.s32 d16, d16, d17
  282. vpadd.s32 d20, d20, d21
  283. vpadd.s32 d18, d18, d19
  284. vpadd.s32 d22, d22, d23
  285. // q8
  286. vpadd.s32 d16, d16, d18
  287. vpadd.s32 d17, d20, d22
  288. // vaddq.s32 q0, q8, q4
  289. vcvt.f32.s32 q0, q8
  290. vmulq.f32 q0, q0, q5
  291. // extra scale if has
  292. ldr r6, [sp, #28]
  293. cmp r6, #0
  294. beq L1_MLA
  295. vld1.f32 {d10[0]}, [r6] // tile0
  296. vmulq.f32 q0, q0, d10[0]
  297. L1_MLA:
  298. ldr r6, [sp, #4] // srcKernelSum
  299. vld1.f32 {d12[0]}, [r6] // tile 0
  300. ldr r6, [sp, #8] // weightBias
  301. vld1.f32 {q7}, [r6]!
  302. str r6, [sp, #8] // update next 4 weightBias
  303. vmla.f32 q0, q7, d12[0]
  304. //vadd.f32 q0, q0, q4
  305. cmp r7, #0
  306. bne L1QuanUseInt8
  307. cmp lr, #0
  308. beq L1_ADD_DSTV
  309. vld1.f32 {q4}, [lr]! // bias
  310. vadd.f32 q0, q0, q4
  311. b L1_POST
  312. L1_ADD_DSTV:
  313. vld1.f32 {q4}, [r0]
  314. vadd.f32 q0, q0, q4
  315. L1_POST:
  316. ldr r6, [sp, #12] // fp32 minmax
  317. cmp r6, #0
  318. beq L1_STORE
  319. vld1.f32 {d20[0]}, [r6]!
  320. vld1.f32 {d22[0]}, [r6]
  321. vdup.f32 q10, d20[0]
  322. vdup.f32 q11, d22[0]
  323. vmax.f32 q0, q0, q10
  324. vmin.f32 q0, q0, q11
  325. L1_STORE:
  326. vst1.f32 {q0}, [r0], r4
  327. b L1LoopCheck
  328. L1QuanUseInt8:
  329. vld1.f32 {q4}, [lr]! // bias
  330. vadd.f32 q0, q0, q4
  331. vmov.f32 q10, #0.5
  332. vmov.f32 q11, #-0.5
  333. ldr r6, [sp, #16]
  334. vdup.32 q3, r6 // max
  335. ldr r6, [sp, #20]
  336. vdup.32 q2, r6 // min
  337. vcgt.f32 q12, q0, #0
  338. vbsl.f32 q12, q10, q11
  339. vbsl.f32 q13, q10, q11
  340. vadd.f32 q0, q12, q0
  341. vcvt.s32.f32 q0, q0
  342. vmax.s32 q0, q2, q0
  343. vmin.s32 q0, q3, q0
  344. vqmovn.s32 d4, q0
  345. vqmovn.s16 d6, q2
  346. vst1.s32 {d6[0]}, [r0], r4
  347. L1LoopCheck:
  348. subs r5, r5, #1
  349. mov r1, r10
  350. ldr r2, [sp, #32] // origin weight ptr
  351. ldr r6, [sp, #24] // weight stride
  352. add r2, r2, r6 // next oc4 weight ptr
  353. bne L1LoopDz
  354. End:
  355. add sp, sp, #36
  356. vpop {q4-q7}
  357. pop {r4-r8, r10, pc}
  358. #endif
  359. #endif