29 #include "torch/script.h"
34 template <
class T,
class TIndex =
int32_t>
38 : device_type(device_type), device_idx(device_idx) {}
41 neighbors_index = torch::empty(
42 {int64_t(num)}, torch::dtype(ToTorchDtype<TIndex>())
43 .device(device_type, device_idx));
44 *ptr = neighbors_index.data_ptr<TIndex>();
48 neighbors_distance = torch::empty(
49 {int64_t(num)}, torch::dtype(ToTorchDtype<T>())
50 .device(device_type, device_idx));
51 *ptr = neighbors_distance.data_ptr<T>();
55 return neighbors_index.data_ptr<TIndex>();
58 const T*
DistancesPtr()
const {
return neighbors_distance.data_ptr<T>(); }
62 return neighbors_distance;
66 torch::Tensor neighbors_index;
67 torch::Tensor neighbors_distance;
68 torch::DeviceType device_type;
Definition: NeighborSearchAllocator.h:35
const TIndex * IndicesPtr() const
Definition: NeighborSearchAllocator.h:54
const torch::Tensor & NeighborsDistance() const
Definition: NeighborSearchAllocator.h:61
void AllocIndices(TIndex **ptr, size_t num)
Definition: NeighborSearchAllocator.h:40
NeighborSearchAllocator(torch::DeviceType device_type, int device_idx)
Definition: NeighborSearchAllocator.h:37
const T * DistancesPtr() const
Definition: NeighborSearchAllocator.h:58
void AllocDistances(T **ptr, size_t num)
Definition: NeighborSearchAllocator.h:47
const torch::Tensor & NeighborsIndex() const
Definition: NeighborSearchAllocator.h:60