bmcv_gemm_ext
该接口可以实现 fp32/fp16 类型矩阵的通用乘法计算,如下公式:
\[Y = \alpha\times A\times B + \beta\times C\]
其中,A、B、C、Y均为矩阵,\(\alpha\) 和 \(\beta\) 均为常系数
接口的格式如下:
bm_status_t bmcv_gemm_ext(bm_handle_t handle, bool is_A_trans, bool is_B_trans, int M, int N, int K, float alpha, bm_device_mem_t A, bm_device_mem_t B, float beta, bm_device_mem_t C, bm_device_mem_t Y, bm_image_data_format_ext input_dtype, bm_image_data_format_ext output_dtype);
输入参数说明:
bm_handle_t handle
输入参数。bm_handle 句柄
bool is_A_trans
输入参数。设定矩阵 A 是否转置
bool is_B_trans
输入参数。设定矩阵 B 是否转置
int M
输入参数。矩阵 A、C、Y 的行数
int N
输入参数。矩阵 B、C、Y 的列数
int K
输入参数。矩阵 A 的列数和矩阵 B 的行数
float alpha
输入参数。数乘系数
bm_device_mem_t A
输入参数。根据数据存放位置保存左矩阵 A 数据的 device 地址,需在使用前完成数据s2d搬运。
bm_device_mem_t B
输入参数。根据数据存放位置保存右矩阵 B 数据的 device 地址,需在使用前完成数据s2d搬运。
float beta
输入参数。数乘系数。
bm_device_mem_t C
输入参数。根据数据存放位置保存矩阵 C 数据的 device 地址,需在使用前完成数据s2d搬运。
bm_device_mem_t Y
输出参数。矩阵 Y 数据的 device 地址,保存输出结果。
bm_image_data_format_ext input_dtype
输入参数。输入矩阵A、B、C的数据类型。支持输入FP16-输出FP16或FP32,输入FP32-输出FP32。
bm_image_data_format_ext output_dtype
输入参数。输出矩阵Y的数据类型。
返回值说明:
BM_SUCCESS: 成功
其他:失败
注意:
该接口仅支持BM1684X。
该接口在FP16输入、A矩阵转置的情况下,M仅支持小于等于64的取值。
该接口不支持FP32输入且FP16输出。
示例代码
int M = 3, N = 4, K = 5; float alpha = 0.4, beta = 0.6; bool is_A_trans = false; bool is_B_trans = false; float *A = new float[M * K]; float *B = new float[N * K]; float *C = new float[M * N]; memset(A, 0x11, M * K * sizeof(float)); memset(B, 0x22, N * K * sizeof(float)); memset(C, 0x33, M * N * sizeof(float)); bm_device_mem_t input_dev_buffer[3]; bm_device_mem_t output_dev_buffer[1]; bm_malloc_device_byte(handle, &input_dev_buffer[0], M * K * sizeof(float)); bm_malloc_device_byte(handle, &input_dev_buffer[1], N * K * sizeof(float)); bm_malloc_device_byte(handle, &input_dev_buffer[2], M * N * sizeof(float)); bm_memcpy_s2d(handle, input_dev_buffer[0], (void *)A); bm_memcpy_s2d(handle, input_dev_buffer[1], (void *)B); bm_memcpy_s2d(handle, input_dev_buffer[2], (void *)C); bm_malloc_device_byte(handle, &output_dev_buffer[0], M * N * sizeof(float)); bm_image_data_format_ext in_dtype = DATA_TYPE_EXT_FLOAT32; bm_image_data_format_ext out_dtype = DATA_TYPE_EXT_FLOAT32; bmcv_gemm_ext(handle, is_A_trans, is_B_trans, M, N, K, alpha, input_dev_buffer[0], input_dev_buffer[1], beta, input_dev_buffer[2], output_dev_buffer[0], in_dtype, out_dtype); delete A; delete B; delete C; delete Y; for (int i = 0; i < 3; i++) { bm_free_device(handle, input_dev_buffer[i]); } bm_free_device(handle, output_dev_buffer[0]);