arm_fully_connected_q7_opt.c 15 KB


  1. /*
  2. * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_fully_connected_q7_opt.c
  21. * Description: Q7 basic fully-connected layer function
  22. *
  23. * $Date: 17. January 2018
  24. * $Revision: V.1.0.0
  25. *
  26. * Target Processor: Cortex-M cores
  27. *
  28. * -------------------------------------------------------------------- */
  29. #include "arm_math.h"
  30. #include "arm_nnfunctions.h"
  31. /**
  32. * @ingroup groupNN
  33. */
  34. /**
  35. * @addtogroup FC
  36. * @{
  37. */
  38. /**
  39. * @brief Q7 opt fully-connected layer function
  40. * @param[in] pV pointer to input vector
  41. * @param[in] pM pointer to matrix weights
  42. * @param[in] dim_vec length of the vector
  43. * @param[in] num_of_rows number of rows in weight matrix
  44. * @param[in] bias_shift amount of left-shift for bias
  45. * @param[in] out_shift amount of right-shift for output
  46. * @param[in] bias pointer to bias
  47. * @param[out] pOut pointer to output vector (num_of_rows q7_t values)
  48. * @param[in,out] vec_buffer scratch buffer for the q15_t copy of the input (only used when ARM_MATH_DSP is defined)
  49. * @return The function returns <code>ARM_MATH_SUCCESS</code>
  50. *
  51. * @details
  52. *
  53. * <b>Buffer size:</b>
  54. *
  55. * vec_buffer size: dim_vec q15_t elements
  56. *
  57. * This opt function is designed to work with interleaved weight
  58. * matrix. The vector input is assumed to be in q7_t format; we call the
  59. * arm_q7_to_q15_reordered_no_shift function to expand it into
  60. * q15_t format with a matching element re-ordering (refer to that function's
  61. * comments for more details).
  62. * Here we use only one pointer to read 4 rows in the weight
  63. * matrix. So if the original q7_t matrix looks like this:
  64. *
  65. * | a11 | a12 | a13 | a14 | a15 | a16 | a17 |
  66. *
  67. * | a21 | a22 | a23 | a24 | a25 | a26 | a27 |
  68. *
  69. * | a31 | a32 | a33 | a34 | a35 | a36 | a37 |
  70. *
  71. * | a41 | a42 | a43 | a44 | a45 | a46 | a47 |
  72. *
  73. * | a51 | a52 | a53 | a54 | a55 | a56 | a57 |
  74. *
  75. * | a61 | a62 | a63 | a64 | a65 | a66 | a67 |
  76. *
  77. *
  78. * We operate on rows in groups of 4, so the first four rows become
  79. *
  80. * | a11 | a21 | a13 | a23 | a31 | a41 | a33 | a43 |
  81. *
  82. * | a12 | a22 | a14 | a24 | a32 | a42 | a34 | a44 |
  83. *
  84. * | a15 | a25 | a35 | a45 | a16 | a26 | a36 | a46 |
  85. *
  86. * So within the kernel, we first read the re-ordered vector in as:
  87. *
  88. * | b1 | b3 | and | b2 | b4 |
  89. *
  90. * the four q31_t weights will look like
  91. *
  92. * | a11 | a13 |, | a21 | a23 |, | a31 | a33 |, | a41 | a43 |
  93. *
  94. * | a12 | a14 |, | a22 | a24 |, | a32 | a34 |, | a42 | a44 |
  95. *
  96. * The left-over columns (dim_vec % 4 of them) stay in order, stored one
  97. * column at a time with the 4 row values consecutive, e.g.:
  98. *
  99. * | a17 | a27 | a37 | a47 |
  100. *
  101. * For the left-over rows, we do 1x1 computation, so the data remains
  102. * in its original (row-major) order.
  103. *
  104. * So the stored weight matrix looks like this:
  105. *
  106. * | a11 | a21 | a13 | a23 | a31 | a41 |
  107. *
  108. * | a33 | a43 | a12 | a22 | a14 | a24 |
  109. *
  110. * | a32 | a42 | a34 | a44 | a15 | a25 |
  111. *
  112. * | a35 | a45 | a16 | a26 | a36 | a46 |
  113. *
  114. * | a17 | a27 | a37 | a47 | a51 | a52 |
  115. *
  116. * | a53 | a54 | a55 | a56 | a57 | a61 |
  117. *
  118. * | a62 | a63 | a64 | a65 | a66 | a67 |
  119. *
  120. *
  121. */
  122. arm_status
  123. arm_fully_connected_q7_opt(const q7_t * pV,
  124. const q7_t * pM,
  125. const uint16_t dim_vec,
  126. const uint16_t num_of_rows,
  127. const uint16_t bias_shift,
  128. const uint16_t out_shift,
  129. const q7_t * bias,
  130. q7_t * pOut,
  131. q15_t * vec_buffer)
  132. {
  133. #if defined (ARM_MATH_DSP)
  134. /* Run the following code for Cortex-M4 and Cortex-M7 */
  135. const q7_t *pB = pM;
  136. q7_t *pO = pOut;
  137. const q7_t *pBias = bias;
  138. q15_t *pA;
  139. uint16_t rowCnt = num_of_rows >> 2;
  140. arm_q7_to_q15_reordered_no_shift(pV, vec_buffer, dim_vec);
  141. while (rowCnt)
  142. {
  143. q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  144. q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  145. q31_t sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  146. q31_t sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  147. uint16_t colCnt = dim_vec >> 2;
  148. pA = vec_buffer;
  149. #ifdef USE_INTRINSIC
  150. #ifndef ARM_MATH_BIG_ENDIAN
  151. while (colCnt)
  152. {
  153. q31_t inM11, inM12, inM13, inM14;
  154. q31_t inV;
  155. inV = *__SIMD32(pA)++;
  156. inM11 = *__SIMD32(pB)++;
  157. inM12 = __SXTB16(__ROR(inM11, 8));
  158. inM11 = __SXTB16(inM11);
  159. sum = __SMLAD(inM11, inV, sum);
  160. sum2 = __SMLAD(inM12, inV, sum2);
  161. inM13 = *__SIMD32(pB)++;
  162. inM14 = __SXTB16(__ROR(inM13, 8));
  163. inM13 = __SXTB16(inM13);
  164. sum3 = __SMLAD(inM13, inV, sum3);
  165. sum4 = __SMLAD(inM14, inV, sum4);
  166. inV = *__SIMD32(pA)++;
  167. inM11 = *__SIMD32(pB)++;
  168. inM12 = __SXTB16(__ROR(inM11, 8));
  169. inM11 = __SXTB16(inM11);
  170. sum = __SMLAD(inM11, inV, sum);
  171. sum2 = __SMLAD(inM12, inV, sum2);
  172. inM13 = *__SIMD32(pB)++;
  173. inM14 = __SXTB16(__ROR(inM13, 8));
  174. inM13 = __SXTB16(inM13);
  175. sum3 = __SMLAD(inM13, inV, sum3);
  176. sum4 = __SMLAD(inM14, inV, sum4);
  177. colCnt--;
  178. }
  179. #else
  180. while (colCnt)
  181. {
  182. q31_t inM11, inM12, inM13, inM14;
  183. q31_t inV;
  184. inV = *__SIMD32(pA)++;
  185. inM11 = *__SIMD32(pB)++;
  186. inM12 = __SXTB16(__ROR(inM11, 8));
  187. inM11 = __SXTB16(inM11);
  188. sum = __SMLAD(inM12, inV, sum);
  189. sum2 = __SMLAD(inM11, inV, sum2);
  190. inM13 = *__SIMD32(pB)++;
  191. inM14 = __SXTB16(__ROR(inM13, 8));
  192. inM13 = __SXTB16(inM13);
  193. sum3 = __SMLAD(inM14, inV, sum3);
  194. sum4 = __SMLAD(inM13, inV, sum4);
  195. inV = *__SIMD32(pA)++;
  196. inM11 = *__SIMD32(pB)++;
  197. inM12 = __SXTB16(__ROR(inM11, 8));
  198. inM11 = __SXTB16(inM11);
  199. sum = __SMLAD(inM12, inV, sum);
  200. sum2 = __SMLAD(inM11, inV, sum2);
  201. inM13 = *__SIMD32(pB)++;
  202. inM14 = __SXTB16(__ROR(inM13, 8));
  203. inM13 = __SXTB16(inM13);
  204. sum3 = __SMLAD(inM14, inV, sum3);
  205. sum4 = __SMLAD(inM13, inV, sum4);
  206. colCnt--;
  207. }
  208. #endif /* ARM_MATH_BIG_ENDIAN */
  209. #else
  210. /*
  211. * register needed:
  212. * loop counter: colCnt
  213. * accumulators: sum, sum2, sum3, sum4
  214. * pointers: pB, pA
  215. * weight data: inM11, inM12, inM13, inM14
  216. * activation data: inV
  217. */
  218. #ifndef ARM_MATH_BIG_ENDIAN
  219. asm volatile ("COL_LOOP_%=:\n"
  220. "ldr.w r4, [%[pA]], #8\n"
  221. "ldr.w r1, [%[pB]], #16\n"
  222. "mov.w r0, r1, ror #8\n"
  223. "sxtb16 r0, r0\n"
  224. "sxtb16 r1, r1\n"
  225. "smlad %[sum], r4, r1, %[sum]\n"
  226. "smlad %[sum2], r4, r0, %[sum2]\n"
  227. "ldr.w r3, [%[pB], #-12]\n"
  228. "mov.w r2, r3, ror #8\n"
  229. "sxtb16 r2, r2\n"
  230. "sxtb16 r3, r3\n"
  231. "smlad %[sum3], r4, r3, %[sum3]\n"
  232. "smlad %[sum4], r4, r2, %[sum4]\n"
  233. "ldr.w r4, [%[pA], #-4]\n"
  234. "ldr.w r1, [%[pB], #-8]\n"
  235. "mov.w r0, r1, ror #8\n"
  236. "sxtb16 r0, r0\n"
  237. "sxtb16 r1, r1\n"
  238. "smlad %[sum], r4, r1, %[sum]\n"
  239. "smlad %[sum2], r4, r0, %[sum2]\n"
  240. "ldr.w r3, [%[pB], #-4]\n"
  241. "mov.w r2, r3, ror #8\n"
  242. "sxtb16 r2, r2\n"
  243. "sxtb16 r3, r3\n"
  244. "smlad %[sum3], r4, r3, %[sum3]\n"
  245. "smlad %[sum4], r4, r2, %[sum4]\n"
  246. "subs %[colCnt], #1\n"
  247. "bne COL_LOOP_%=\n":[sum] "+r"(sum),
  248. [sum2] "+r"(sum2),[sum3] "+r"(sum3),
  249. [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
  250. #else
  251. asm volatile ("COL_LOOP_%=:\n"
  252. "ldr.w r4, [%[pA]], #8\n"
  253. "ldr.w r1, [%[pB]], #16\n"
  254. "mov.w r0, r1, ror #8\n"
  255. "sxtb16 r0, r0\n"
  256. "sxtb16 r1, r1\n"
  257. "smlad %[sum], r4, r0, %[sum]\n"
  258. "smlad %[sum2], r4, r1, %[sum2]\n"
  259. "ldr.w r3, [%[pB], #-12]\n"
  260. "mov.w r2, r3, ror #8\n"
  261. "sxtb16 r2, r2\n"
  262. "sxtb16 r3, r3\n"
  263. "smlad %[sum3], r4, r2, %[sum3]\n"
  264. "smlad %[sum4], r4, r3, %[sum4]\n"
  265. "ldr.w r4, [%[pA], #-4]\n"
  266. "ldr.w r1, [%[pB], #-8]\n"
  267. "mov.w r0, r1, ror #8\n"
  268. "sxtb16 r0, r0\n"
  269. "sxtb16 r1, r1\n"
  270. "smlad %[sum], r4, r0, %[sum]\n"
  271. "smlad %[sum2], r4, r1, %[sum2]\n"
  272. "ldr.w r3, [%[pB], #-4]\n"
  273. "mov.w r2, r3, ror #8\n"
  274. "sxtb16 r2, r2\n"
  275. "sxtb16 r3, r3\n"
  276. "smlad %[sum3], r4, r2, %[sum3]\n"
  277. "smlad %[sum4], r4, r3, %[sum4]\n"
  278. "subs %[colCnt], #1\n"
  279. "bne COL_LOOP_%=\n":[sum] "+r"(sum),
  280. [sum2] "+r"(sum2),[sum3] "+r"(sum3),
  281. [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
  282. #endif /* ARM_MATH_BIG_ENDIAN */
  283. #endif /* USE_INTRINSIC */
  284. colCnt = dim_vec & 0x3;
  285. while (colCnt)
  286. {
  287. q15_t inV = *pA++;
  288. q7_t inM = *pB++;
  289. q7_t inM2 = *pB++;
  290. q7_t inM3 = *pB++;
  291. q7_t inM4 = *pB++;
  292. sum += inV * inM;
  293. sum2 += inV * inM2;
  294. sum3 += inV * inM3;
  295. sum4 += inV * inM4;
  296. colCnt--;
  297. } /* while over colCnt */
  298. *pO++ = (q7_t) (__SSAT((sum >> out_shift), 8));
  299. *pO++ = (q7_t) (__SSAT((sum2 >> out_shift), 8));
  300. *pO++ = (q7_t) (__SSAT((sum3 >> out_shift), 8));
  301. *pO++ = (q7_t) (__SSAT((sum4 >> out_shift), 8));
  302. /* adjust the pointers and counters */
  303. rowCnt--;
  304. }
  305. /* left-over part of the rows */
  306. rowCnt = num_of_rows & 0x3;
  307. while (rowCnt)
  308. {
  309. q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  310. uint16_t colCnt = dim_vec >> 2;
  311. pA = vec_buffer;
  312. while (colCnt)
  313. {
  314. q31_t inV1, inV2, inM11, inM12;
  315. pB = (q7_t *) read_and_pad_reordered((void *)pB, &inM11, &inM12);
  316. inV1 = *__SIMD32(pA)++;
  317. sum = __SMLAD(inV1, inM11, sum);
  318. inV2 = *__SIMD32(pA)++;
  319. sum = __SMLAD(inV2, inM12, sum);
  320. colCnt--;
  321. }
  322. /* left-over of the vector */
  323. colCnt = dim_vec & 0x3;
  324. while (colCnt)
  325. {
  326. q15_t inV = *pA++;
  327. q7_t inM = *pB++;
  328. sum += inV * inM;
  329. colCnt--;
  330. }
  331. *pO++ = (q7_t) (__SSAT((sum >> out_shift), 8));
  332. rowCnt--;
  333. }
  334. #else
  335. /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
  336. uint16_t rowCnt = num_of_rows >> 2;
  337. const q7_t *pB = pM;
  338. const q7_t *pA;
  339. q7_t *pO = pOut;
  340. const q7_t *pBias = bias;
  341. while (rowCnt)
  342. {
  343. q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  344. q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  345. q31_t sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  346. q31_t sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  347. uint16_t colCnt = dim_vec >> 2;
  348. pA = pV;
  349. while (colCnt)
  350. {
  351. q7_t inA1 = *pA++;
  352. q7_t inA3 = *pA++;
  353. q7_t inA2 = *pA++;
  354. q7_t inA4 = *pA++;
  355. q7_t inB1 = *pB++;
  356. q7_t inB3 = *pB++;
  357. q7_t inB2 = *pB++;
  358. q7_t inB4 = *pB++;
  359. sum += inA1 * inB1 + inA2 * inB2;
  360. sum2 += inA1 * inB3 + inA2 * inB4;
  361. inB1 = *pB++;
  362. inB3 = *pB++;
  363. inB2 = *pB++;
  364. inB4 = *pB++;
  365. sum3 += inA1 * inB1 + inA2 * inB2;
  366. sum4 += inA1 * inB3 + inA2 * inB4;
  367. inB1 = *pB++;
  368. inB3 = *pB++;
  369. inB2 = *pB++;
  370. inB4 = *pB++;
  371. sum += inA3 * inB1 + inA4 * inB2;
  372. sum2 += inA3 * inB3 + inA4 * inB4;
  373. inB1 = *pB++;
  374. inB3 = *pB++;
  375. inB2 = *pB++;
  376. inB4 = *pB++;
  377. sum3 += inA3 * inB1 + inA4 * inB2;
  378. sum4 += inA3 * inB3 + inA4 * inB4;
  379. colCnt--;
  380. }
  381. colCnt = dim_vec & 0x3;
  382. while (colCnt)
  383. {
  384. q7_t inA = *pA++;
  385. q7_t inB = *pB++;
  386. sum += inA * inB;
  387. inB = *pB++;
  388. sum2 += inA * inB;
  389. inB = *pB++;
  390. sum3 += inA * inB;
  391. inB = *pB++;
  392. sum4 += inA * inB;
  393. colCnt--;
  394. }
  395. *pO++ = (q7_t) __SSAT((sum >> out_shift), 8);
  396. *pO++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
  397. *pO++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
  398. *pO++ = (q7_t) __SSAT((sum4 >> out_shift), 8);
  399. rowCnt--;
  400. }
  401. rowCnt = num_of_rows & 0x3;
  402. while (rowCnt)
  403. {
  404. int ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  405. int j;
  406. pA = pV;
  407. for (j = 0; j < dim_vec; j++)
  408. {
  409. q7_t inA = *pA++;
  410. q7_t inB = *pB++;
  411. ip_out += inA * inB;
  412. }
  413. *pO++ = (q7_t) __SSAT((ip_out >> out_shift), 8);
  414. rowCnt--;
  415. }
  416. #endif /* ARM_MATH_DSP */
  417. /* Return to ARM_MATH_SUCCESS */
  418. return (ARM_MATH_SUCCESS);
  419. }
  420. /**
  421. * @} end of FC group
  422. */