Open3D (C++ API)  0.17.0
NmsOpKernel.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// Copyright (c) 2018-2023 www.open3d.org
5// SPDX-License-Identifier: MIT
6// ----------------------------------------------------------------------------
7
8#pragma once
9
10//#include "open3d/ml/impl/misc/VoxelPooling.h"
12#include "tensorflow/core/framework/op.h"
13#include "tensorflow/core/framework/op_kernel.h"
14#include "tensorflow/core/lib/core/errors.h"
15
17// namespace for code that is common for all kernels
18namespace nms_opkernel {
19
20class OutputAllocator {
21public:
22 OutputAllocator(tensorflow::OpKernelContext* context) : context(context) {}
23
24 void AllocKeepIndices(int64_t** ptr, int64_t num) {
25 using namespace tensorflow;
26 *ptr = nullptr;
27 Tensor* tensor = 0;
28 TensorShape shape({num});
29 OP_REQUIRES_OK(context, context->allocate_output(0, shape, &tensor));
30 auto flat_tensor = tensor->flat<int64>();
31 *ptr = (int64_t*)flat_tensor.data();
32 }
33
34private:
35 tensorflow::OpKernelContext* context;
36};
37
38// Base class with common code for the OpKernel implementations
39class NmsOpKernel : public tensorflow::OpKernel {
40public:
41 explicit NmsOpKernel(tensorflow::OpKernelConstruction* construction)
42 : OpKernel(construction) {
43 OP_REQUIRES_OK(construction,
44 construction->GetAttr("nms_overlap_thresh",
45 &nms_overlap_thresh));
46 }
47
48 void Compute(tensorflow::OpKernelContext* context) override {
49 using namespace tensorflow;
50 const Tensor& boxes = context->input(0);
51 const Tensor& scores = context->input(1);
52
53 {
54 using namespace open3d::ml::op_util;
55 Dim num_points("num_points");
56 Dim five(5, "five");
57 CHECK_SHAPE(context, boxes, num_points, five);
58 CHECK_SHAPE(context, scores, num_points);
59 }
60
61 Kernel(context, boxes, scores);
62 }
63
64 // Function with the device specific code
65 virtual void Kernel(tensorflow::OpKernelContext* context,
66 const tensorflow::Tensor& boxes,
67 const tensorflow::Tensor& scores) = 0;
68
69protected:
70 float nms_overlap_thresh;
71};
72
73} // namespace nms_opkernel
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:186
ImGuiContext * context
Definition: Window.cpp:76
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:50
Definition: ShapeChecking.h:16