bmcv_faiss_indexPQ_SDC
该接口通过检索向量编码和底库编码在 sdc_table 中查表并累加,输出前 K (sort_cnt) 个最匹配的向量索引及其对应的距离。
处理器型号支持:
该接口仅支持BM1684X。
接口形式:
bm_status_t bmcv_faiss_indexPQ_SDC( bm_handle_t handle, bm_device_mem_t sdc_table_input_dev, bm_device_mem_t nxcodes_input_dev, bm_device_mem_t nycodes_input_dev, bm_device_mem_t distance_output_dev, bm_device_mem_t index_output_dev, int slice_num, int centroids_num, int database_num, int query_num, int sort_cnt, int IP_metric);
输入参数说明:
bm_handle_t handle
输入参数。bm_handle 句柄。
bm_device_mem_t sdc_table_input_dev
输入参数。存放对称距离表的 device 空间。
bm_device_mem_t nxcodes_input_dev
输入参数。存放检索向量编码的 device 空间。
bm_device_mem_t nycodes_input_dev
输入参数。存放底库编码的 device 空间。
bm_device_mem_t distance_output_dev
输出参数。存放输出距离的 device 空间。
bm_device_mem_t index_output_dev
输出参数。存放输出排序的 device 空间。
int slice_num
输入参数。原始维度切分数量。
int centroids_num
输入参数。聚类中心的数量。
int database_num
输入参数。数据底库的数量。
int query_num
输入参数。检索向量的数量。
int sort_cnt
输入参数。输出的前 sort_cnt 个最匹配底库向量。
int IP_metric
输入参数。0 表示L2距离计算; 1 表示IP距离计算。
返回值说明:
BM_SUCCESS: 成功
其他:失败
注意事项:
1、输入数据 (查询向量) 和对称距离表的数据类型为 float,底库数据 (底库编码)的数据类型为uint8,存储在设备内存上。
2、输出的排序后的相似度结果的数据类型为 float, 相对应的索引的数据类型为 int,存储在设备内存上。
3、SDC检索过程中 metric 的选择没有区别,因为距离在于输入的sdc table,主要区别在于其 topk 结果是降序,L2的结果为升序。
4、查询向量和数据库向量 L2 距离越小, 表示两者的相似度越高。输出 L2 topk距离按升序排序。
5、查询向量和数据库向量 IP 距离越大, 表示两者的相似度越高。输出 IP topk距离按降序排序。
6、faiss系列算子有多个输入参数,每个参数都有一个使用范围限制,超过该范围的入参对应tpu输出会出错,我们选择了三个主要参数做了测试,固定其中两个维度,测试了第三个维度的最大值,测试结果如下表格所示:
query_num |
vec_dims |
max_database_num |
---|---|---|
1 |
128 |
6500万 |
1 |
256 |
6500万 |
1 |
512 |
6500万 |
64 |
128 |
2500万 |
64 |
256 |
2500万 |
64 |
512 |
1500万 |
256 |
128 |
600万 |
256 |
256 |
600万 |
256 |
512 |
600万 |
database_num |
vec_dims |
max_query_num |
---|---|---|
1000 |
128 |
1000 |
1000 |
256 |
1000 |
1000 |
512 |
1000 |
1万 |
128 |
1000 |
1万 |
256 |
1000 |
1万 |
512 |
1000 |
10万 |
128 |
100 |
10万 |
256 |
50 |
10万 |
512 |
50 |
database_num |
query_num |
max_vec_dims |
---|---|---|
1万 |
1 |
2048 |
1万 |
64 |
512 |
1万 |
128 |
512 |
1万 |
256 |
512 |
10万 |
1 |
2048 |
10万 |
32 |
512 |
10万 |
64 |
512 |
100万 |
1 |
128 |
示例代码
#include "bmcv_api_ext.h" #include "test_misc.h" #include <stdio.h> #include <stdlib.h> #include <time.h> #include <assert.h> #include <sys/time.h> #define BMLIB_SAFE_CALL(cmd) assert(cmd == BM_SUCCESS) int main() { int sort_cnt = 100; int query_num = 1; int slice_m = 32; int ksub = 256; int database_num = 2000000; int input_dtype = 5; // 5: float int output_dtype = 5; int IP_metric = 0; int show_result = 1; struct timespec tp; clock_gettime(CLOCK_REALTIME, &tp); unsigned int seed = tp.tv_nsec; bm_handle_t handle; bm_status_t ret = bm_dev_request(&handle, 0); if (ret != BM_SUCCESS) { printf("request dev failed\n"); return BM_ERR_FAILURE; } srand(seed); int round = 1; fp16 *sdc_table_input_sys_fp16 = (fp16*)malloc(slice_m * ksub * ksub * sizeof(fp16)); float *sdc_table_input_sys_fp32 = (float*)malloc(slice_m * ksub * ksub * sizeof(float)); unsigned char *nxcodes_input_sys = (unsigned char*)malloc(query_num * slice_m); unsigned char *nycodes_input_sys = (unsigned char*)malloc(database_num * slice_m); unsigned char *distance_output_sys = (unsigned char*)malloc(query_num * database_num * dtype_size((data_type_t)output_dtype)); int *index_output_sys = (int*)malloc(query_num * database_num * sizeof(int)); for (int i = 0; i < slice_m; i++) { for (int j = 0; j < ksub; j++) { for (int n = 0; n < ksub; n++) { float value = (n > j) ? (float)rand() / RAND_MAX * 20.0 : 0.0; sdc_table_input_sys_fp32[i * ksub * ksub + j * ksub + n] = value; sdc_table_input_sys_fp16[i * ksub * ksub + j * ksub + n] = fp32tofp16(value, round); } } } for (int i = 0; i < query_num; i++) { for (int j = 0; j < slice_m; j++) { nxcodes_input_sys[i * slice_m + j] = rand() % 256; } } for (int i = 0; i < database_num; i++) { for (int j = 0; j < slice_m; j++) { nycodes_input_sys[i * slice_m + j] = rand() % 256; } } int sdc_table_size = slice_m * ksub * ksub * dtype_size((data_type_t)input_dtype); int nxcodes_size = query_num * slice_m; int nycodes_size = database_num * slice_m; int output_distance_size = query_num * database_num * dtype_size((data_type_t)output_dtype); int output_index_size = query_num * database_num * sizeof(int); bm_device_mem_t sdc_table_input_dev, nxcodes_input_dev, nycodes_input_dev, distance_output_dev, index_output_dev; BMLIB_SAFE_CALL(bm_malloc_device_byte(handle, &sdc_table_input_dev, sdc_table_size)); BMLIB_SAFE_CALL(bm_malloc_device_byte(handle, &nxcodes_input_dev, nxcodes_size)); BMLIB_SAFE_CALL(bm_malloc_device_byte(handle, &nycodes_input_dev, nycodes_size)); BMLIB_SAFE_CALL(bm_malloc_device_byte(handle, &distance_output_dev, output_distance_size)); BMLIB_SAFE_CALL(bm_malloc_device_byte(handle, &index_output_dev, output_index_size)); if (input_dtype == DT_FP16) { BMLIB_SAFE_CALL(bm_memcpy_s2d(handle, sdc_table_input_dev, sdc_table_input_sys_fp16)); } else { BMLIB_SAFE_CALL(bm_memcpy_s2d(handle, sdc_table_input_dev, sdc_table_input_sys_fp32)); } BMLIB_SAFE_CALL(bm_memcpy_s2d(handle, nxcodes_input_dev, nxcodes_input_sys)); BMLIB_SAFE_CALL(bm_memcpy_s2d(handle, nycodes_input_dev, nycodes_input_sys)); struct timeval t1, t2; gettimeofday(&t1, NULL); ret = bmcv_faiss_indexPQ_SDC_ext(handle, sdc_table_input_dev, nxcodes_input_dev, nycodes_input_dev, distance_output_dev, index_output_dev, slice_m, ksub, database_num, query_num, sort_cnt, IP_metric, input_dtype, output_dtype); gettimeofday(&t2, NULL); printf("TPU using time(us): %ld(us)\n", (t2.tv_sec - t1.tv_sec) * 1000000 + t2.tv_usec - t1.tv_usec); printf("TPU using time(ms): %ld(ms)\n", ((t2.tv_sec - t1.tv_sec) * 1000000 + t2.tv_usec - t1.tv_usec) / 1000); if(ret != BM_SUCCESS){ bm_free_device(handle, sdc_table_input_dev); bm_free_device(handle, nxcodes_input_dev); bm_free_device(handle, nycodes_input_dev); bm_free_device(handle, distance_output_dev); bm_free_device(handle, index_output_dev); free(sdc_table_input_sys_fp32); free(sdc_table_input_sys_fp16); free(nxcodes_input_sys); free(nycodes_input_sys); free(distance_output_sys); free(index_output_sys); bm_dev_free(handle); return BM_ERR_FAILURE; } BMLIB_SAFE_CALL(bm_memcpy_d2s(handle, distance_output_sys, distance_output_dev)); BMLIB_SAFE_CALL(bm_memcpy_d2s(handle, index_output_sys, index_output_dev)); if (show_result) { printf("SDCsearch result:\n"); for (int i = 0; i < sort_cnt; i++) { printf("top: %d\n", i + 1); printf("index: %d\t", index_output_sys[i]); if (output_dtype == DT_FP16) { printf("distance: %f\n", fp16tofp32(((fp16*)distance_output_sys)[i])); } else { printf("distance: %f\n", ((float*)distance_output_sys)[i]); } } } bm_free_device(handle, sdc_table_input_dev); bm_free_device(handle, nxcodes_input_dev); bm_free_device(handle, nycodes_input_dev); bm_free_device(handle, distance_output_dev); bm_free_device(handle, index_output_dev); free(sdc_table_input_sys_fp32); free(sdc_table_input_sys_fp16); free(nxcodes_input_sys); free(nycodes_input_sys); free(distance_output_sys); free(index_output_sys); bm_dev_free(handle); return 0; }