bmcv_faiss_indexPQ_encode
该接口输入 vectors 和 centroids 计算距离表并排序,输出 vectors 的量化编码。
处理器型号支持:
该接口仅支持BM1684X。
接口形式:
bm_status_t bmcv_faiss_indexPQ_encode( bm_handle_t handle, bm_device_mem_t vector_input_dev, bm_device_mem_t centroids_input_dev, bm_device_mem_t buffer_table_dev, bm_device_mem_t codes_output_dev, int encode_vec_num, int vec_dims, int slice_num, int centroids_num, int IP_metric);
输入参数说明:
bm_handle_t handle
输入参数。bm_handle 句柄。
bm_device_mem_t vector_input_dev
输入参数。存放待编码向量的 device 空间。
bm_device_mem_t centroids_input_dev
输入参数。存储聚类中心数据的 device 空间。
bm_device_mem_t buffer_table_dev
输入参数。存放计算出的距离表的缓存空间。
bm_device_mem_t codes_output_dev
输出参数。存放向量编码结果的 device 空间。
int encode_vec_num
输入参数。待编码向量的个数。
int vec_dims
输入参数。原始向量的维度。
int slice_num
输入参数。原始维度切分数量。
int centroids_num
输入参数。聚类中心的数量。
int IP_metric
输入参数。0 表示L2距离计算; 1 表示IP距离计算。
返回值说明:
BM_SUCCESS: 成功
其他:失败
注意事项:
1、输入数据 (查询向量) 和聚类中心的数据类型为 float,输出向量编码的数据类型为 uint8,存储在设备内存上。
2、buffer_table 的大小为 slice_num * centroids_num 个元素,数据类型为 float,即占用 slice_num * centroids_num * sizeof(float) 字节的设备内存。
示例代码:
#include "bmcv_api_ext.h"
#include "test_misc.h"
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <assert.h>
#include <sys/time.h>

/*
 * Run a bmlib call and abort with a diagnostic unless it returns BM_SUCCESS.
 * Deliberately not built on assert(): an assert-wrapped call is compiled out
 * under NDEBUG, which would silently skip the allocation or copy itself.
 */
#define BMLIB_SAFE_CALL(cmd)                                              \
    do {                                                                  \
        bm_status_t safe_call_status_ = (cmd);                            \
        if (safe_call_status_ != BM_SUCCESS) {                            \
            fprintf(stderr, "%s failed (status %d)\n", #cmd,              \
                    (int)safe_call_status_);                              \
            abort();                                                      \
        }                                                                 \
    } while (0)

/* Return a pseudo-random float in the range [-10, 10]. */
static float random_value(void)
{
    return (float)rand() / RAND_MAX * 20.0 - 10.0;
}

/*
 * Example: encode `encode_vec_num` random float vectors against random
 * centroids with bmcv_faiss_indexPQ_encode and copy the resulting codes back
 * to the host.
 *
 * Returns 0 on success and BM_ERR_FAILURE if the device cannot be opened,
 * a host allocation fails, or the encode call fails.
 */
int main()
{
    int vec_dims = 256;
    int encode_vec_num = 1;
    int slice_m = 32;               /* number of sub-vectors per vector */
    int ksub = 256;                 /* number of centroids per sub-space */
    int dsub = vec_dims / slice_m;  /* dimension of each sub-vector */
    int IP_metric = 0;              /* 0: L2 distance, 1: inner product */

    /* Element counts; host and device sizes are both derived from these. */
    size_t centroids_count = (size_t)slice_m * ksub * dsub;
    size_t vectors_count = (size_t)encode_vec_num * vec_dims;
    size_t table_count = (size_t)slice_m * ksub;
    size_t codes_count = (size_t)encode_vec_num * slice_m;

    struct timespec tp;
    clock_gettime(CLOCK_REALTIME, &tp);
    unsigned int seed = (unsigned int)tp.tv_nsec;

    bm_handle_t handle;
    bm_status_t ret = bm_dev_request(&handle, 0);
    if (ret != BM_SUCCESS) {
        printf("request dev failed\n");
        return BM_ERR_FAILURE;
    }
    srand(seed);

    /*
     * Inputs are float on both sides. The host vector buffer must be as large
     * as its device buffer because the host-to-device copy below is sized by
     * the device allocation.
     */
    float *centroids_input_sys_fp32 = malloc(centroids_count * sizeof(float));
    float *vector_input_sys_fp32 = malloc(vectors_count * sizeof(float));
    unsigned char *output_codes_sys = malloc(codes_count);
    if (centroids_input_sys_fp32 == NULL || vector_input_sys_fp32 == NULL ||
        output_codes_sys == NULL) {
        fprintf(stderr, "host memory allocation failed\n");
        free(centroids_input_sys_fp32);
        free(vector_input_sys_fp32);
        free(output_codes_sys);
        bm_dev_free(handle);
        return BM_ERR_FAILURE;
    }

    /* Centroids are laid out as [slice_m][ksub][dsub]. */
    for (int i = 0; i < slice_m; i++) {
        for (int j = 0; j < ksub; j++) {
            for (int n = 0; n < dsub; n++) {
                centroids_input_sys_fp32[i * dsub * ksub + j * dsub + n] = random_value();
            }
        }
    }

    /* Fill every component of every vector to be encoded. */
    for (size_t i = 0; i < vectors_count; i++) {
        vector_input_sys_fp32[i] = random_value();
    }

    unsigned int centroids_size = (unsigned int)(centroids_count * sizeof(float));
    unsigned int vectors_size = (unsigned int)(vectors_count * sizeof(float));
    unsigned int buffer_table_size = (unsigned int)(table_count * sizeof(float));
    unsigned int output_codes_size = (unsigned int)codes_count; /* one uint8 code per slice */

    bm_device_mem_t centroids_input_dev, vector_input_dev, buffer_table_dev, codes_output_dev;
    BMLIB_SAFE_CALL(bm_malloc_device_byte(handle, &centroids_input_dev, centroids_size));
    BMLIB_SAFE_CALL(bm_malloc_device_byte(handle, &vector_input_dev, vectors_size));
    BMLIB_SAFE_CALL(bm_malloc_device_byte(handle, &buffer_table_dev, buffer_table_size));
    BMLIB_SAFE_CALL(bm_malloc_device_byte(handle, &codes_output_dev, output_codes_size));

    BMLIB_SAFE_CALL(bm_memcpy_s2d(handle, centroids_input_dev, centroids_input_sys_fp32));
    BMLIB_SAFE_CALL(bm_memcpy_s2d(handle, vector_input_dev, vector_input_sys_fp32));

    struct timeval t1, t2;
    gettimeofday(&t1, NULL);
    ret = bmcv_faiss_indexPQ_encode(handle,
                                    vector_input_dev,
                                    centroids_input_dev,
                                    buffer_table_dev,
                                    codes_output_dev,
                                    encode_vec_num,
                                    vec_dims,
                                    slice_m,
                                    ksub,
                                    IP_metric);
    gettimeofday(&t2, NULL);

    long elapsed_us = (long)(t2.tv_sec - t1.tv_sec) * 1000000L + (long)(t2.tv_usec - t1.tv_usec);
    printf("TPU using time(us): %ld(us)\n", elapsed_us);
    printf("TPU using time(ms): %ld(ms)\n", elapsed_us / 1000);

    int exit_code = 0;
    if (ret != BM_SUCCESS) {
        fprintf(stderr, "bmcv_faiss_indexPQ_encode failed (status %d)\n", (int)ret);
        exit_code = BM_ERR_FAILURE;
    } else {
        BMLIB_SAFE_CALL(bm_memcpy_d2s(handle, output_codes_sys, codes_output_dev));
        printf("finish encode\n");
    }

    /* Single cleanup path shared by the success and failure cases. */
    bm_free_device(handle, centroids_input_dev);
    bm_free_device(handle, vector_input_dev);
    bm_free_device(handle, buffer_table_dev);
    bm_free_device(handle, codes_output_dev);
    free(centroids_input_sys_fp32);
    free(vector_input_sys_fp32);
    free(output_codes_sys);
    bm_dev_free(handle);
    return exit_code;
}