bmcv_matmul
该接口可以实现 8-bit 数据类型矩阵的乘法计算,如下公式:
(1)\[C = (A\times B) >> rshift\_bit\]
或者
(2)\[C = alpha \times (A\times B) + beta\]
其中,
A 是输入的左矩阵,其数据类型可以是 unsigned char 或者 signed char 类型的 8-bit 数据,大小为(M,K);
B 是输入的右矩阵,其数据类型可以是 unsigned char 或者 signed char 类型的 8-bit 数据,大小为(K,N);
C 是输出的结果矩阵, 其数据类型长度可以是 int8、int16 或者 float32,用户配置决定。
rshift_bit 是矩阵乘积的右移数,当 C 是 int8 或者 int16 时才有效,由于矩阵的乘积有可能会超出 8-bit 或者 16-bit 的范围,所以用户可以配置一定的右移数,通过舍弃部分精度来防止溢出。
alpha和beta 是 float32 的常系数,当 C 是 float32 时才有效。
接口的格式如下:
bm_status_t bmcv_matmul(bm_handle_t handle, int M, int N, int K, bm_device_mem_t A, bm_device_mem_t B, bm_device_mem_t C, int A_sign, int B_sign, int rshift_bit, int result_type, bool is_B_trans, float alpha = 1, float beta = 0);
处理器型号支持:
该接口支持BM1684/BM1684X。
输入参数说明:
bm_handle_t handle
输入参数。bm_handle 句柄
int M
输入参数。矩阵 A 和矩阵 C 的行数
int N
输入参数。矩阵 B 和矩阵 C 的列数
int K
输入参数。矩阵 A 的列数和矩阵 B 的行数
bm_device_mem_t A
输入参数。根据左矩阵 A 数据存放位置保存其 device 地址或者 host 地址。如果数据存放于 host 空间则内部会自动完成 s2d 的搬运
bm_device_mem_t B
输入参数。根据右矩阵 B 数据存放位置保存其 device 地址或者 host 地址。如果数据存放于 host 空间则内部会自动完成 s2d 的搬运。
bm_device_mem_t C
输出参数。根据矩阵 C 数据存放位置保存其 device 地址或者 host 地址。如果是 host 地址,则当beta不为0时,计算前内部会自动完成 s2d 的搬运,计算后再自动完成 d2s 的搬运。
int A_sign
输入参数。左矩阵A的符号,1 表示有符号,0 表示无符号。
int B_sign
输入参数。右矩阵B的符号,1 表示有符号,0 表示无符号。
int rshift_bit
输入参数。矩阵乘积的右移数,为非负数。只有当 result_type 等于 0 或者 1 时才有效。
int result_type
输入参数。输出的结果矩阵数据类型,0表示是 int8,1表示int16, 2表示 float32。
bool is_B_trans
输入参数。输入右矩阵B是否需要计算前做转置。
float alpha
常系数,输入矩阵 A 和 B 相乘之后再乘上该系数,只有当 result_type 等于2时才有效,默认值为1。
float beta
常系数,在输出结果矩阵 C 之前,加上该偏移量,只有当 result_type 等于2时才有效,默认值为0。
返回值说明:
BM_SUCCESS: 成功
其他:失败
示例代码
int M = 3, N = 4, K = 5; int result_type = 1; bool is_B_trans = false; int rshift_bit = 0; char *A = new char[M * K]; char *B = new char[N * K]; short *C = new short[M * N]; memset(A, 0x11, M * K * sizeof(char)); memset(B, 0x22, N * K * sizeof(char)); bmcv_matmul(handle, M, N, K, bm_mem_from_system((void *)A), bm_mem_from_system((void *)B), bm_mem_from_system((void *)C), 1, 1, rshift_bit, result_type, is_B_trans); delete A; delete B; delete C;