bmcv_faiss_indexflatIP

计算查询向量与数据库向量的内积距离, 输出前 K (sort_cnt) 个最匹配的内积距离值及其对应的索引。

接口形式:

bm_status_t bmcv_faiss_indexflatIP(
        bm_handle_t     handle,
        bm_device_mem_t input_data_global_addr,
        bm_device_mem_t db_data_global_addr,
        bm_device_mem_t buffer_global_addr,
        bm_device_mem_t output_sorted_similarity_global_addr,
        bm_device_mem_t output_sorted_index_global_addr,
        int             vec_dims,
        int             query_vecs_num,
        int             database_vecs_num,
        int             sort_cnt,
        int             is_transpose,
        int             input_dtype,
        int             output_dtype);

输入参数说明:

  • bm_handle_t handle

    输入参数。bm_handle 句柄。

  • bm_device_mem_t input_data_global_addr

    输入参数。存放查询向量组成的矩阵的 device 空间。

  • bm_device_mem_t db_data_global_addr

    输入参数。存放底库向量组成的矩阵的 device 空间。

  • bm_device_mem_t buffer_global_addr

    输入参数。存放计算出的内积值的缓存空间。

  • bm_device_mem_t output_sorted_similarity_global_addr

    输出参数。存放排序后的最匹配的内积值的 device 空间。

  • bm_device_mem_t output_sorted_index_global_add

    输出参数。存储输出内积值对应索引的 device 空间。

  • int vec_dims

    输入参数。向量维数。

  • int query_vecs_num

    输入参数。查询向量的个数。

  • int database_vecs_num

    输入参数。底库向量的个数。

  • int sort_cnt

    输入参数。输出的前 sort_cnt 个最匹配的内积值。

  • int is_transpose

    输入参数。0 表示底库矩阵不转置; 1 表示底库矩阵转置。

  • int input_dtype

    输入参数。输入数据类型,支持 float 和 char, 5 表示float, 1 表示char。

  • int output_dtype

    输出参数。输出数据类型,支持 float 和 int, 5 表示float, 9 表示int。

返回值说明:

  • BM_SUCCESS: 成功

  • 其他:失败

注意事项:

1、输入数据(查询向量)和底库数据(底库向量)的数据类型为 float 或 char。

2、输出的排序后的相似度的数据类型为 float 或 int, 相对应的索引的数据类型为 int。

3、底库数据通常以 database_vecs_num * vec_dims 的形式排布在内存中。此时, 参数 is_transpose 需要设置为 1。

4、查询向量和数据库向量内积距离值越大, 表示两者的相似度越高。因此, 在 TopK 过程中对内积距离值按降序排序。

5、该接口用于 Faiss::IndexFlatIP.search(), 在 BM1684X 上实现。考虑 BM1684X 上 TPU 的连续内存, 针对 100W 底库, 可以在单芯片上一次查询最多约 512 个 256 维的输入。

示例代码

int sort_cnt = 100;
int vec_dims = 256;
int query_vecs_num = 1;
int database_vecs_num = 2000000;
int is_transpose = 1;
int input_dtype = 5; // 5: float
int output_dtype = 5;

float *input_data = new float[query_vecs_num * vec_dims];
float *db_data = new float[database_vecs_num * vec_dims];

void matrix_gen_data(float* data, u32 len) {
    for (u32 i = 0; i < len; i++) {
        data[i] = ((float)rand() / (float)RAND_MAX) * 3.3;
    }
}

matrix_gen_data(input_data, query_vecs_num * vec_dims);
matrix_gen_data(db_data, vec_dims * database_vecs_num);

bm_handle_t handle = nullptr;
bm_dev_request(&handle, 0);
bm_device_mem_t query_data_dev_mem;
bm_device_mem_t db_data_dev_mem;
bm_malloc_device_byte(handle, &query_data_dev_mem,
        query_vecs_num * vec_dims * sizeof(float));
bm_malloc_device_byte(handle, &db_data_dev_mem,
        database_vecs_num * vec_dims * sizeof(float));
bm_memcpy_s2d(handle, query_data_dev_mem, input_data);
bm_memcpy_s2d(handle, db_data_dev_mem, db_data);

float *output_dis = new float[query_vecs_num * sort_cnt];
int *output_inx = new int[query_vecs_num * sort_cnt];
bm_device_mem_t buffer_dev_mem;
bm_device_mem_t sorted_similarity_dev_mem;
bm_device_mem_t sorted_index_dev_mem;
bm_malloc_device_byte(handle, &buffer_dev_mem,
        query_vecs_num * database_vecs_num * sizeof(float));
bm_malloc_device_byte(handle, &sorted_similarity_dev_mem,
        query_vecs_num * sort_cnt * sizeof(float));
bm_malloc_device_byte(handle, &sorted_index_dev_mem,
        query_vecs_num * sort_cnt * sizeof(int));

bmcv_faiss_indexflatIP(handle,
                       query_data_dev_mem,
                       db_data_dev_mem,
                       buffer_dev_mem,
                       sorted_similarity_dev_mem,
                       sorted_index_dev_mem,
                       vec_dims,
                       query_vecs_num,
                       database_vecs_num,
                       sort_cnt,
                       is_transpose,
                       input_dtype,
                       output_dtype);
bm_memcpy_d2s(handle, output_dis, sorted_similarity_dev_mem);
bm_memcpy_d2s(handle, output_inx, sorted_index_dev_mem);
delete[] input_data;
delete[] db_data;
delete[] output_similarity;
delete[] output_index;
bm_free_device(handle, query_data_dev_mem);
bm_free_device(handle, db_data_dev_mem);
bm_free_device(handle, buffer_dev_mem);
bm_free_device(handle, sorted_similarity_dev_mem);
bm_free_device(handle, sorted_index_dev_mem);
bm_dev_free(handle);