29 #include "../TensorFlowHelper.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/lib/core/errors.h"
36 namespace invert_neighbors_list_opkernel {
39 class InvertNeighborsListOpKernel :
public tensorflow::OpKernel {
// Constructs the kernel. The construction context is forwarded unchanged to
// the tensorflow::OpKernel base class; no op attributes are read here (the
// constructor body is empty).
41 explicit InvertNeighborsListOpKernel(
42 tensorflow::OpKernelConstruction* construction)
43 : OpKernel(construction) {}
// Entry point called by TensorFlow. Reads and validates the op inputs,
// allocates the three outputs and then delegates the actual computation to
// the pure virtual Kernel() supplied by a subclass.
//
// Inputs (by index):
//   0 - num_points: must be a scalar, read as int64.
//   1 - inp_neighbors_index
//   2 - inp_neighbors_row_splits
//   3 - inp_neighbors_attributes
// Outputs (by index):
//   0 - neighbors_index:      same shape as input 1.
//   1 - neighbors_row_splits: shape [num_points + 1].
//   2 - neighbors_attributes: same shape as input 3.
45 void Compute(tensorflow::OpKernelContext*
context)
override {
46 using namespace tensorflow;
// Require TensorFlow's int64 to have the same size as int64_t. Presumably the
// subclass implementations access int64 tensor data as int64_t — confirm.
47 static_assert(
sizeof(int64) ==
sizeof(int64_t),
48 "int64 type is not compatible");
// Input 0 (num_points) must be a scalar; otherwise an InvalidArgument error
// reporting the offending shape is raised.
50 const Tensor& num_points_tensor =
context->input(0);
52 TensorShapeUtils::IsScalar(num_points_tensor.shape()),
53 errors::InvalidArgument(
54 "num_points must be scalar, got shape ",
55 num_points_tensor.shape().DebugString()));
56 const int64 num_points = num_points_tensor.scalar<int64>()();
58 const Tensor& inp_neighbors_index =
context->input(1);
60 const Tensor& inp_neighbors_row_splits =
context->input(2);
62 const Tensor& inp_neighbors_attributes =
context->input(3);
// Named dimension for the shape-checking helpers; presumably shared between
// the neighbor inputs so that their neighbor counts must agree — confirm.
67 Dim num_neighbors(
"num_neighbors");
// Special case: input 3 (attributes) has an empty first dimension.
77 if (inp_neighbors_attributes.shape().dim_size(0) == 0) {
// Multiply num_attributes by every dimension of input 3 after the first,
// i.e. the number of attribute values per entry of the first dimension
// (presumably per neighbor).
81 for (
int i = 1; i < inp_neighbors_attributes.shape().dims(); ++i)
82 num_attributes *= inp_neighbors_attributes.shape().dim_size(i);
// Output 0 mirrors the shape of input 1.
// NOTE(review): prefer nullptr over 0 for these null-pointer initializers.
85 Tensor* neighbors_index = 0;
86 TensorShape neighbors_index_shape(inp_neighbors_index.shape());
88 context->allocate_output(0, neighbors_index_shape,
// Output 1 holds num_points + 1 row-split entries.
// NOTE(review): num_points is only checked to be a scalar above; a negative
// value would yield an invalid dimension here. Confirm that non-negativity is
// validated before this shape is constructed.
91 Tensor* neighbors_row_splits = 0;
92 TensorShape neighbors_row_splits_shape({num_points + 1});
94 context->allocate_output(1, neighbors_row_splits_shape,
95 &neighbors_row_splits));
// Output 2 mirrors the shape of input 3.
97 Tensor* neighbors_attributes = 0;
98 TensorShape neighbors_attributes_shape(
99 inp_neighbors_attributes.shape());
101 context->allocate_output(2, neighbors_attributes_shape,
102 &neighbors_attributes));
// Hand the inputs and the newly allocated outputs to the subclass
// implementation of Kernel().
104 Kernel(
context, inp_neighbors_index, inp_neighbors_row_splits,
105 inp_neighbors_attributes, num_attributes, *neighbors_index,
106 *neighbors_row_splits, *neighbors_attributes);
// Computation hook invoked at the end of Compute(). It is pure virtual, so a
// subclass must implement it (presumably one subclass per device, e.g.
// CPU/GPU — confirm).
//
// context                  - the op context passed through from Compute().
// inp_neighbors_index      - op input 1.
// inp_neighbors_row_splits - op input 2.
// inp_neighbors_attributes - op input 3.
// num_attributes           - value derived in Compute() from the dimensions
//                            of inp_neighbors_attributes after the first.
// neighbors_index          - op output 0, already allocated by Compute().
// neighbors_row_splits     - op output 1, already allocated by Compute().
// neighbors_attributes     - op output 2, already allocated by Compute().
// The outputs are passed by non-const reference; presumably the
// implementation fills them in place.
//
// NOTE(review): num_attributes is an int, while TensorShape dimension sizes
// are int64; an extremely large attribute shape could overflow — confirm this
// is acceptable.
110 virtual void Kernel(tensorflow::OpKernelContext*
context,
111 const tensorflow::Tensor& inp_neighbors_index,
112 const tensorflow::Tensor& inp_neighbors_row_splits,
113 const tensorflow::Tensor& inp_neighbors_attributes,
114 const int num_attributes,
115 tensorflow::Tensor& neighbors_index,
116 tensorflow::Tensor& neighbors_row_splits,
117 tensorflow::Tensor& neighbors_attributes) = 0;
#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor,...)
Definition: TorchHelper.h:244
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:205
ImGuiContext * context
Definition: Window.cpp:95
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:69
Definition: ShapeChecking.h:35