/*
 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_fully_connected_q15_opt.c
 * Description:  Q15 opt fully-connected layer function
 *
 * $Date:        17. January 2018
 * $Revision:    V.1.0.0
 *
 * Target Processor:  Cortex-M cores
 *
 * -------------------------------------------------------------------- */

#include "arm_math.h"
#include "arm_nnfunctions.h"
/**
 * @ingroup groupNN
 */

/**
 * @addtogroup FC
 * @{
 */

/**
 * @brief Q15 opt fully-connected layer function
 * @param[in]       pV          pointer to input vector
 * @param[in]       pM          pointer to matrix weights
 * @param[in]       dim_vec     length of the vector
 * @param[in]       num_of_rows number of rows in weight matrix
 * @param[in]       bias_shift  amount of left-shift for bias
 * @param[in]       out_shift   amount of right-shift for output
 * @param[in]       bias        pointer to bias
 * @param[in,out]   pOut        pointer to output vector
 * @param[in,out]   vec_buffer  pointer to buffer space for input
 * @return          The function returns <code>ARM_MATH_SUCCESS</code>
 *
 * @details
 *
 * <b>Buffer size:</b>
 *
 * vec_buffer size: 0
 *
 * This function uses a single pointer to read four rows of the weight
 * matrix at a time. So if the original matrix looks like this:
 *
 * | a11 | a12 | a13 |
 *
 * | a21 | a22 | a23 |
 *
 * | a31 | a32 | a33 |
 *
 * | a41 | a42 | a43 |
 *
 * | a51 | a52 | a53 |
 *
 * | a61 | a62 | a63 |
 *
 * we operate on multiples of four rows, so the first four rows become
 *
 * | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
 *
 * | a13 | a23 | a33 | a43 |
 *
 * The remaining rows are kept in their original order.
 *
 * So the stored weight matrix looks like this:
 *
 * | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
 *
 * | a13 | a23 | a33 | a43 | a51 | a52 | a53 | a61 |
 *
 * | a62 | a63 |
 *
 * A sketch of how this layout can be produced offline is given after
 * the function body.
 */
arm_status
arm_fully_connected_q15_opt(const q15_t * pV,
                            const q15_t * pM,
                            const uint16_t dim_vec,
                            const uint16_t num_of_rows,
                            const uint16_t bias_shift,
                            const uint16_t out_shift,
                            const q15_t * bias,
                            q15_t * pOut,
                            q15_t * vec_buffer)
{
#if defined (ARM_MATH_DSP)
    /* Run the following code for Cortex-M4 and Cortex-M7 */

    const q15_t *pB = pM;
    q15_t    *pO = pOut;
    const q15_t *pBias = bias;
    const q15_t *pA = pV;

    uint16_t  rowCnt = num_of_rows >> 2;

    while (rowCnt)
    {
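        /* initialize the four accumulators with the bias value,
         * left-shifted and offset by the rounding constant for the
         * final right shift by out_shift */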
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t  colCnt = dim_vec >> 1;

        pA = pV;

#ifdef USE_INTRINSIC
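        /* each __SIMD32 read fetches two q15 values in one 32-bit load;
         * __SMLAD then performs a dual 16-bit multiply-accumulate */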
        while (colCnt)
        {
            q31_t     inM11, inM12, inM13, inM14;
            q31_t     inV;

            inV = *__SIMD32(pA)++;
            inM11 = *__SIMD32(pB)++;
            sum = __SMLAD(inV, inM11, sum);
            inM12 = *__SIMD32(pB)++;
            sum2 = __SMLAD(inV, inM12, sum2);
            inM13 = *__SIMD32(pB)++;
            sum3 = __SMLAD(inV, inM13, sum3);
            inM14 = *__SIMD32(pB)++;
            sum4 = __SMLAD(inV, inM14, sum4);
            colCnt--;
        }
#else
        /*
         * registers needed:
         * loop counter: colCnt
         * accumulators: sum, sum2, sum3, sum4
         * pointers: pB, pA
         * weight data: inM11, inM12, inM13, inM14
         * activation data: inV
         */
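        /* the first pB load post-increments the pointer by 16 bytes
         * (one 32-bit word per row); the next three loads reach back
         * with negative offsets to pick up rows 2-4 */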
        asm volatile ("COL_LOOP_%=:\n"
                      "ldr.w r4, [%[pA]], #4\n"
                      "ldr.w r0, [%[pB]], #16\n"
                      "smlad %[sum], r4, r0, %[sum]\n"
                      "ldr.w r1, [%[pB], #-12]\n"
                      "smlad %[sum2], r4, r1, %[sum2]\n"
                      "ldr.w r2, [%[pB], #-8]\n"
                      "smlad %[sum3], r4, r2, %[sum3]\n"
                      "ldr.w r3, [%[pB], #-4]\n"
                      "smlad %[sum4], r4, r3, %[sum4]\n"
                      "subs %[colCnt], #1\n"
                      "bne COL_LOOP_%=\n"
                      : [sum] "+r"(sum), [sum2] "+r"(sum2),
                        [sum3] "+r"(sum3), [sum4] "+r"(sum4),
                        [pB] "+r"(pB), [pA] "+r"(pA),
                        [colCnt] "+r"(colCnt)
                      :
                      : "r0", "r1", "r2", "r3", "r4", "cc", "memory");
#endif                          /* USE_INTRINSIC */
        colCnt = dim_vec & 0x1;
        while (colCnt)
        {
            q15_t     inV = *pA++;
            q15_t     inM = *pB++;
            q15_t     inM2 = *pB++;
            q15_t     inM3 = *pB++;
            q15_t     inM4 = *pB++;

            sum += inV * inM;
            sum2 += inV * inM2;
            sum3 += inV * inM3;
            sum4 += inV * inM4;
            colCnt--;
        }                       /* while over colCnt */

        *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
        *pO++ = (q15_t) (__SSAT((sum2 >> out_shift), 16));
        *pO++ = (q15_t) (__SSAT((sum3 >> out_shift), 16));
        *pO++ = (q15_t) (__SSAT((sum4 >> out_shift), 16));

        /* adjust the pointers and counters */
        rowCnt--;
    }
    /* left-over part of the rows */
    rowCnt = num_of_rows & 0x3;

    while (rowCnt)
    {
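        /* leftover rows are stored in plain row-major order,
         * so a single accumulator suffices */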
        q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t  colCnt = dim_vec >> 2;

        pA = pV;

        while (colCnt)
        {
            q31_t     inV1, inV2, inM1, inM2;

            inM1 = *__SIMD32(pB)++;
            inV1 = *__SIMD32(pA)++;
            sum = __SMLAD(inV1, inM1, sum);

            inM2 = *__SIMD32(pB)++;
            inV2 = *__SIMD32(pA)++;
            sum = __SMLAD(inV2, inM2, sum);

            colCnt--;
        }

        /* left-over of the vector */
        colCnt = dim_vec & 0x3;
        while (colCnt)
        {
            q15_t     inV = *pA++;
            q15_t     inM = *pB++;

            sum += inV * inM;
            colCnt--;
        }

        *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));

        rowCnt--;
    }
#else
    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */

    uint16_t  rowCnt = num_of_rows >> 2;
    const q15_t *pB = pM;
    const q15_t *pA;
    q15_t    *pO = pOut;
    const q15_t *pBias = bias;
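    /* process the weight rows in groups of four, matching the
     * interleaved storage order described above */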
    while (rowCnt)
    {
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

        uint16_t  colCnt = dim_vec >> 1;

        pA = pV;

        while (colCnt)
        {
            q15_t     inA1 = *pA++;
            q15_t     inA2 = *pA++;

            q15_t     inB1 = *pB++;
            q15_t     inB2 = *pB++;
            sum += inA1 * inB1 + inA2 * inB2;

            inB1 = *pB++;
            inB2 = *pB++;
            sum2 += inA1 * inB1 + inA2 * inB2;

            inB1 = *pB++;
            inB2 = *pB++;
            sum3 += inA1 * inB1 + inA2 * inB2;

            inB1 = *pB++;
            inB2 = *pB++;
            sum4 += inA1 * inB1 + inA2 * inB2;

            colCnt--;
        }

        colCnt = dim_vec & 0x1;
        while (colCnt)
        {
            q15_t     inA = *pA++;
            q15_t     inB = *pB++;
            sum += inA * inB;
            inB = *pB++;
            sum2 += inA * inB;
            inB = *pB++;
            sum3 += inA * inB;
            inB = *pB++;
            sum4 += inA * inB;
            colCnt--;
        }

        *pO++ = (q15_t) __SSAT((sum >> out_shift), 16);
        *pO++ = (q15_t) __SSAT((sum2 >> out_shift), 16);
        *pO++ = (q15_t) __SSAT((sum3 >> out_shift), 16);
        *pO++ = (q15_t) __SSAT((sum4 >> out_shift), 16);

        rowCnt--;
    }
    rowCnt = num_of_rows & 0x3;

    while (rowCnt)
    {
        int       ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
        int       j;

        pA = pV;
        for (j = 0; j < dim_vec; j++)
        {
            q15_t     inA = *pA++;
            q15_t     inB = *pB++;
            ip_out += inA * inB;
        }
        *pO++ = (q15_t) __SSAT((ip_out >> out_shift), 16);

        rowCnt--;
    }

#endif                          /* ARM_MATH_DSP */
    /* Return ARM_MATH_SUCCESS */
    return (ARM_MATH_SUCCESS);
}
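
/*
 * Note: the interleaved weight layout documented above must be produced
 * offline, when the network parameters are prepared. The block below is a
 * minimal sketch of such a reordering step for a row-major source matrix.
 * It is an illustration only (the helper name is hypothetical, not part of
 * the CMSIS-NN API), so it is compiled out.
 */
#if 0
static void reorder_weights_q15_opt(const q15_t * src, q15_t * dst,
                                    uint16_t num_of_rows, uint16_t dim_vec)
{
    uint16_t  r = 0, c, i;
    q15_t    *p = dst;

    /* interleave groups of four rows, two columns at a time */
    for (; r + 4 <= num_of_rows; r += 4)
    {
        for (c = 0; c + 2 <= dim_vec; c += 2)
        {
            for (i = 0; i < 4; i++)
            {
                *p++ = src[(r + i) * dim_vec + c];
                *p++ = src[(r + i) * dim_vec + c + 1];
            }
        }
        /* odd vector length: one trailing column per row */
        if (dim_vec & 0x1)
        {
            for (i = 0; i < 4; i++)
            {
                *p++ = src[(r + i) * dim_vec + dim_vec - 1];
            }
        }
    }
    /* remaining rows keep their original row-major order */
    for (; r < num_of_rows; r++)
    {
        for (c = 0; c < dim_vec; c++)
        {
            *p++ = src[r * dim_vec + c];
        }
    }
}
#endif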

/**
 * @} end of FC group
 */