29 #include "../TensorFlowHelper.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/lib/core/errors.h"
37 namespace radius_search_opkernel {
39 class RadiusSearchOpKernel :
public tensorflow::OpKernel {
// NOTE(review): this text is a Doxygen source-listing extraction. The leading
// integers are original-file line numbers, not code, and some original lines
// are absent. Edit the real header, not this listing.
/// Constructs the kernel by reading the op attributes "metric",
/// "ignore_query_point", "return_distances" and "normalize_distances".
/// Every GetAttr call is wrapped in OP_REQUIRES_OK. On a failed lookup,
/// that macro records the error on `construction` and returns early,
/// which leaves the remaining attributes unread.
41 explicit RadiusSearchOpKernel(
42 tensorflow::OpKernelConstruction* construction)
43 : OpKernel(construction) {
45 using namespace tensorflow;
// The metric attribute arrives as a string and is compared against "L1" below.
46 std::string metric_str;
47 OP_REQUIRES_OK(construction,
48 construction->GetAttr(
"metric", &metric_str));
49 if (metric_str ==
"L1")
// NOTE(review): original lines 50-53 are missing here. They presumably
// assign an L1 vs L2 metric enum (the tooltips reference Metric::L1/L2 in
// NeighborSearchCommon.h). Confirm against the source header.
54 OP_REQUIRES_OK(construction,
55 construction->GetAttr(
"ignore_query_point",
56 &ignore_query_point));
// NOTE(review): the output argument for "return_distances" (original
// line 59) is missing from this listing. It is presumably &return_distances.
58 OP_REQUIRES_OK(construction, construction->GetAttr(
"return_distances",
60 OP_REQUIRES_OK(construction,
61 construction->GetAttr(
"normalize_distances",
62 &normalize_distances));
/// TensorFlow entry point. It performs the following steps:
/// 1. Gathers the input tensors.
/// 2. Declares the symbolic dims used for shape checking.
/// 3. Allocates output 1 (the per-query neighbor row splits).
/// 4. Delegates the actual search to the virtual Kernel() hook.
/// NOTE(review): several original lines are absent from this listing:
/// - 70: input 0, presumably `points`.
/// - 81-87: presumably the CHECK_SHAPE calls using the Dims below.
/// - 91: the OP_REQUIRES_OK/allocate_output head.
/// - 95: the `Kernel(` call head.
65 void Compute(tensorflow::OpKernelContext*
context)
override {
66 using namespace tensorflow;
// The kernel presumably reinterprets TF int64 buffers as int64_t. This
// check fails the build if the two types differ in size.
67 static_assert(
sizeof(int64) ==
sizeof(int64_t),
68 "int64 type is not compatible");
// Inputs 1..4. Input 0 is read on a line missing from this listing.
71 const Tensor& queries =
context->input(1);
72 const Tensor& radii =
context->input(2);
73 const Tensor& points_row_splits =
context->input(3);
74 const Tensor& queries_row_splits =
context->input(4);
// Named symbolic dimensions used for shape checking.
// NOTE(review): the checks themselves are not visible here.
78 Dim num_points(
"num_points");
79 Dim num_queries(
"num_queries");
80 Dim batch_size(
"batch_size");
// Output 1 has one more entry than there are queries (dim 0 of `queries`).
// This is presumably an offsets/row-splits layout in which entries i and
// i+1 bracket query i's neighbors. Confirm with the Kernel implementations.
88 Tensor* query_neighbors_row_splits = 0;
89 TensorShape query_neighbors_row_splits_shape(
90 {queries.shape().dim_size(0) + 1});
92 1, query_neighbors_row_splits_shape,
93 &query_neighbors_row_splits));
// Tail of the Kernel(...) call. The freshly allocated row-splits tensor
// is passed by reference for the implementation to fill.
96 queries_row_splits, *query_neighbors_row_splits);
/// Pure virtual hook that performs the radius search. Subclasses must
/// implement it; presumably there is one per device (e.g. CPU/GPU), which
/// should be confirmed against the derived classes.
/// @param context                    op context, for allocating further outputs
///                                   and reporting errors
/// @param points                     point set to search in (presumably input 0)
/// @param queries                    query positions (input 1)
/// @param radius                     search radii (input 2)
/// @param points_row_splits          batch boundaries for `points` (input 3)
/// @param queries_row_splits         batch boundaries for `queries` (input 4)
/// @param query_neighbors_row_splits output tensor already allocated by Compute,
///                                   with dim_size(0) of `queries` + 1 entries.
///                                   The implementation writes into it.
99 virtual void Kernel(tensorflow::OpKernelContext*
context,
100 const tensorflow::Tensor&
points,
101 const tensorflow::Tensor& queries,
102 const tensorflow::Tensor& radius,
103 const tensorflow::Tensor& points_row_splits,
104 const tensorflow::Tensor& queries_row_splits,
105 tensorflow::Tensor& query_neighbors_row_splits) = 0;
// Op attribute flags read in the constructor via GetAttr.
// The semantics below are presumed from the attribute names (TODO confirm):
// - ignore_query_point: drop a query point from its own neighbor list.
// - return_distances: also emit neighbor distances.
// - normalize_distances: scale distances (e.g. by the radius).
// NOTE(review): these have no in-class initializer. They hold indeterminate
// values if an earlier OP_REQUIRES_OK in the constructor returns before
// they are read.
// NOTE(review): original lines 106-108 (presumably an access specifier and
// the `metric` member) are missing from this listing.
109 bool ignore_query_point;
110 bool return_distances;
111 bool normalize_distances;
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:205
ImGuiContext * context
Definition: Window.cpp:95
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:69
Definition: FixedRadiusIndex.cpp:35
Metric
Supported metrics.
Definition: NeighborSearchCommon.h:38
@ L1
Definition: NeighborSearchCommon.h:38
@ L2
Definition: NeighborSearchCommon.h:38
Definition: ShapeChecking.h:35