Open3D (C++ API)  0.15.1
InvertNeighborsListOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // The MIT License (MIT)
5 //
6 // Copyright (c) 2018-2021 www.open3d.org
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a copy
9 // of this software and associated documentation files (the "Software"), to deal
10 // in the Software without restriction, including without limitation the rights
11 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 // copies of the Software, and to permit persons to whom the Software is
13 // furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 // IN THE SOFTWARE.
25 // ----------------------------------------------------------------------------
26 
27 #pragma once
28 
29 #include "../TensorFlowHelper.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 
35 // namespace for code that is common for all kernels
36 namespace invert_neighbors_list_opkernel {
37 
38 // Base class with common code for the OpKernel implementations
39 class InvertNeighborsListOpKernel : public tensorflow::OpKernel {
40 public:
41  explicit InvertNeighborsListOpKernel(
42  tensorflow::OpKernelConstruction* construction)
43  : OpKernel(construction) {}
44 
45  void Compute(tensorflow::OpKernelContext* context) override {
46  using namespace tensorflow;
47  static_assert(sizeof(int64) == sizeof(int64_t),
48  "int64 type is not compatible");
49 
50  const Tensor& num_points_tensor = context->input(0);
51  OP_REQUIRES(context,
52  TensorShapeUtils::IsScalar(num_points_tensor.shape()),
53  errors::InvalidArgument(
54  "num_points must be scalar, got shape ",
55  num_points_tensor.shape().DebugString()));
56  const int64 num_points = num_points_tensor.scalar<int64>()();
57 
58  const Tensor& inp_neighbors_index = context->input(1);
59 
60  const Tensor& inp_neighbors_row_splits = context->input(2);
61 
62  const Tensor& inp_neighbors_attributes = context->input(3);
63 
64  // check input shapes
65  {
66  using namespace open3d::ml::op_util;
67  Dim num_neighbors("num_neighbors");
68 
69  CHECK_SHAPE(context, inp_neighbors_index, num_neighbors);
70  CHECK_SHAPE_IGNORE_LAST_DIMS(context, inp_neighbors_attributes,
71  num_neighbors || 0);
72  CHECK_SHAPE(context, inp_neighbors_row_splits, Dim());
73  }
74 
75  // compute the number of attributes for each neighbor
76  int num_attributes;
77  if (inp_neighbors_attributes.shape().dim_size(0) == 0) {
78  num_attributes = 0;
79  } else {
80  num_attributes = 1;
81  for (int i = 1; i < inp_neighbors_attributes.shape().dims(); ++i)
82  num_attributes *= inp_neighbors_attributes.shape().dim_size(i);
83  }
84 
85  Tensor* neighbors_index = 0;
86  TensorShape neighbors_index_shape(inp_neighbors_index.shape());
87  OP_REQUIRES_OK(context,
88  context->allocate_output(0, neighbors_index_shape,
89  &neighbors_index));
90 
91  Tensor* neighbors_row_splits = 0;
92  TensorShape neighbors_row_splits_shape({num_points + 1});
93  OP_REQUIRES_OK(context,
94  context->allocate_output(1, neighbors_row_splits_shape,
95  &neighbors_row_splits));
96 
97  Tensor* neighbors_attributes = 0;
98  TensorShape neighbors_attributes_shape(
99  inp_neighbors_attributes.shape());
100  OP_REQUIRES_OK(context,
101  context->allocate_output(2, neighbors_attributes_shape,
102  &neighbors_attributes));
103 
104  Kernel(context, inp_neighbors_index, inp_neighbors_row_splits,
105  inp_neighbors_attributes, num_attributes, *neighbors_index,
106  *neighbors_row_splits, *neighbors_attributes);
107  }
108 
109  // Function with the device specific code
110  virtual void Kernel(tensorflow::OpKernelContext* context,
111  const tensorflow::Tensor& inp_neighbors_index,
112  const tensorflow::Tensor& inp_neighbors_row_splits,
113  const tensorflow::Tensor& inp_neighbors_attributes,
114  const int num_attributes,
115  tensorflow::Tensor& neighbors_index,
116  tensorflow::Tensor& neighbors_row_splits,
117  tensorflow::Tensor& neighbors_attributes) = 0;
118 
119 private:
120 };
121 
122 } // namespace invert_neighbors_list_opkernel
#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor,...)
Definition: TorchHelper.h:244
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:205
ImGuiContext * context
Definition: Window.cpp:95
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:69
Definition: ShapeChecking.h:35