HAU 操作

tpu_hau_sort

按升序或降序排序前 K 个最小或最大的数。

tpu_hau_sort(system_addr_t output_addr, system_addr_t input_addr, int len, int K, bool descended, data_type_t dtype)
\[\mathsf{output(k) = input(i_k)}\]

如果升序,则

\[\mathsf{input(i_0)\leq input(i_1)\leq\cdots\leq input(i_{K - 1})\leq\cdots\leq input(i_{len - 1})}\]

如果降序,则

\[\mathsf{input(i_0)\geq input(i_1)\geq\cdots\geq input(i_{K - 1})\geq\cdots\geq input(i_{len - 1})}\]

其中,\(\mathsf{i_0, i_1, \ldots, i_{len - 1}}\) 互不相同,是 \(\mathsf{0, 1, \ldots, len - 1}\) 的重排。

参数
  • output_addr – output 的地址

  • input_addr – input 的地址

  • len – input 的长度

  • K – 排序长度

  • descended – 降序的标志

  • dtype – output 和 input 的元素的数据类型

注意事项

  • output_addrinput_addr 都被 4 整除。

  • dtype 的有效取值是 DT_FP32DT_INT32DT_UINT32

  • output 的长度是 len,前 K 个数是排序后的结果,K 小于等于 len

tpu_hau_sort_natural_index

按升序或降序稳定排序前 K 个最小或最大的数,并输出排序后的索引,排序前的索引是自然索引。

tpu_hau_sort_natural_index(system_addr_t output_data_addr, system_addr_t output_idx_addr, system_addr_t input_addr, int len, int K, bool descended, data_type_t dtype)
\[\mathsf{output\_data(k) = input(i_k)~~~~output\_idx(k) = i_k}\]
\[\mathsf{\text{如果}~input(i_{k}) = input(i_{k + 1})\text{,则}~i_{k}<i_{k + 1}}\]

如果升序,则

\[\mathsf{input(i_0)\leq input(i_1)\leq\cdots\leq input(i_{K - 1})\leq\cdots\leq input(i_{len - 1})}\]

如果降序,则

\[\mathsf{input(i_0)\geq input(i_1)\geq\cdots\geq input(i_{K - 1})\geq\cdots\geq input(i_{len - 1})}\]

其中,\(\mathsf{i_0, i_1, \ldots, i_{len - 1}}\) 互不相同,是 \(\mathsf{0, 1, \ldots, len - 1}\) 的重排。

参数
  • output_data_addr – output_data 的地址

  • output_idx_addr – output_idx 的地址

  • input_addr – input 的地址

  • len – input 的长度

  • K – 排序长度

  • descended – 降序的标志

  • dtype – output_data 和 input 的元素的数据类型

注意事项

  • output_data_addroutput_idx_addrinput_addr 都被 4 整除。

  • dtype 的有效取值是 DT_FP32DT_INT32DT_UINT32, output_idx 的元素的数据类型是 INT32。

  • output_data 和 output_idx 的长度是 len,前 K 个数是排序后的结果和对应的索引,K 小于等于 len

tpu_hau_sort_specific_index

按升序或降序稳定排序前 K 个最小或最大的数,并输出排序后的索引,排序前的索引是指定索引。

tpu_hau_sort_specific_index(system_addr_t output_data_addr, system_addr_t output_idx_addr, system_addr_t input_data_addr, system_addr_t input_idx_addr, int len, int K, bool descended, data_type_t dtype)
\[\mathsf{output\_data(k) = input\_data(i_k)~~~~output\_idx(k) = input\_idx(i_k)}\]
\[\mathsf{\text{如果}~input\_data(i_{k}) = input\_data(i_{k + 1})\text{,则}~input\_idx(i_{k})\leq input\_idx(i_{k + 1})}\]

如果升序,则

\[\mathsf{input\_data(i_0)\leq input\_data(i_1)\leq\cdots\leq input\_data(i_{K - 1})\leq\cdots\leq input\_data(i_{len - 1})}\]

如果降序,则

\[\mathsf{input\_data(i_0)\geq input\_data(i_1)\geq\cdots\geq input\_data(i_{K - 1})\geq\cdots\geq input\_data(i_{len - 1})}\]

其中,\(\mathsf{i_0, i_1, \ldots, i_{len - 1}}\) 互不相同,是 \(\mathsf{0, 1, \ldots, len - 1}\) 的重排。

参数
  • output_data_addr – output_data 的地址

  • output_idx_addr – output_idx 的地址

  • input_data_addr – input_data 的地址

  • input_idx_addr – input_idx 的地址

  • len – input_data 和 input_idx 的长度

  • K – 排序长度

  • descended – 降序的标志

  • dtype – output_data 和 input 的元素的数据类型

注意事项

  • output_data_addroutput_idx_addrinput_data_addrinput_idx_addr 都被 4 整除。

  • dtype 的有效取值是 DT_FP32DT_INT32DT_UINT32, output_idx 和 input_idx 的元素的数据类型是 INT32。

  • output_data 和 output_idx 的长度是 len,前 K 个数是排序后的结果和对应的索引,K 小于等于 len

tpu_hau_line_gather

通过 line 的索引取值得到输出张量,即 output = param[index]。

tpu_hau_line_gather(system_addr_t output_addr, system_addr_t param_addr, system_addr_t index_addr, scalar_t C, int line_num, int line_len, int index_len, int start, int end, data_type_t dtype, bool fill_const)
\[\begin{split}\mathsf{output(h, w)} = {\begin{cases} \mathsf{param(index(h) - start, w)}&\mathsf{\text{如果}~index(h)~\text{是有效索引}}\\ \mathsf{C}&\mathsf{\text{如果}~index(h)~\text{是无效索引,}fill\_const~\text{是}~true}\end{cases}}\end{split}\]
参数
  • output_addr – output 的地址

  • param_addr – param 的地址

  • index_addr – index 的地址

  • C – 常数

  • line_num – param 的 line 的数量

  • line_len – param 的 line 的长度

  • index_len – index 的长度

  • start – 有效索引的起始值

  • end – 有效索引的结束值

  • dtype – output 和 param 的元素的数据类型

  • fill_const – output 在无效索引处填 C 的标志

注意事项

  • output_addrparam_addrindex_addr 都被 64 整除。

  • output 的 shape 是 [index_len, line_len], param 的 shape 是 [line_num, line_len], index 的 shape 是 [index_len],都是 continuous layout

  • index 的元素的数据类型是 UINT32,有效索引的范围是 [start, end], start 大于等于 0,start 小于等于 endend 小于 line_num

  • 如果索引无效,fill_const 是 false,则 output 的对应元素不会被填。