bmcv_gemm
该接口可以实现 float32 类型矩阵的通用乘法计算,如下公式:
\[C = \alpha\times A\times B + \beta\times C\]
其中,A、B、C均为矩阵,\(\alpha\) 和 \(\beta\) 均为常系数
接口的格式如下:
bm_status_t bmcv_gemm(bm_handle_t handle, bool is_A_trans, bool is_B_trans, int M, int N, int K, float alpha, bm_device_mem_t A, int lda, bm_device_mem_t B, int ldb, float beta, bm_device_mem_t C, int ldc);
输入参数说明:
bm_handle_t handle
输入参数。bm_handle 句柄
bool is_A_trans
输入参数。设定矩阵 A 是否转置
bool is_B_trans
输入参数。设定矩阵 B 是否转置
int M
输入参数。矩阵 A 和矩阵 C 的行数
int N
输入参数。矩阵 B 和矩阵 C 的列数
int K
输入参数。矩阵 A 的列数和矩阵 B 的行数
float alpha
输入参数。数乘系数
bm_device_mem_t A
输入参数。根据数据存放位置保存左矩阵 A 数据的 device 地址或者 host 地址。如果数据存放于 host 空间则内部会自动完成 s2d 的搬运
int lda
输入参数。矩阵 A 的 leading dimension, 即第一维度的大小,在行与行之间没有stride的情况下即为 A 的列数(不做转置)或行数(做转置)
bm_device_mem_t B
输入参数。根据数据存放位置保存右矩阵 B 数据的 device 地址或者 host 地址。如果数据存放于 host 空间则内部会自动完成 s2d 的搬运。
int ldb
输入参数。矩阵 C 的 leading dimension, 即第一维度的大小,在行与行之间没有stride的情况下即为 B 的列数(不做转置)或行数(做转置。
float beta
输入参数。数乘系数。
bm_device_mem_t C
输出参数。根据数据存放位置保存矩阵 C 数据的 device 地址或者 host 地址。如果是 host 地址,则当beta不为0时,计算前内部会自动完成 s2d 的搬运,计算后再自动完成 d2s 的搬运。
int ldc
输入参数。矩阵 C 的 leading dimension, 即第一维度的大小,在行与行之间没有stride的情况下即为 C 的列数。
返回值说明:
BM_SUCCESS: 成功
其他:失败
示例代码
int M = 3, N = 4, K = 5; float alpha = 0.4, beta = 0.6; bool is_A_trans = false; bool is_B_trans = false; float *A = new float[M * K]; float *B = new float[N * K]; float *C = new float[M * N]; memset(A, 0x11, M * K * sizeof(float)); memset(B, 0x22, N * K * sizeof(float)); memset(C, 0x33, M * N * sizeof(float)); bmcv_gemm(handle, is_A_trans, is_B_trans, M, N, K, alpha, bm_mem_from_system((void *)A), is_A_trans ? M : K, bm_mem_from_system((void *)B), is_B_trans ? K : N, beta, bm_mem_from_system((void *)C), N); delete A; delete B; delete C;