29 #include "../TensorFlowHelper.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/lib/core/errors.h"
37 : OpKernel(construction) {
38 using namespace tensorflow;
40 OP_REQUIRES_OK(construction,
41 construction->GetAttr(
"nsample", &
nsample));
42 OP_REQUIRES_OK(construction, construction->GetAttr(
"radius", &
radius));
45 errors::InvalidArgument(
"BallQuery expects positive nsample"));
49 using namespace tensorflow;
51 const Tensor& inp_tensor =
context->input(0);
54 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
55 errors::InvalidArgument(
"BallQuery expects "
56 "(batch_size,num_points,3) inp shape"));
57 int batch_size = inp_tensor.shape().dim_size(0);
58 int pts_size = inp_tensor.shape().dim_size(1);
59 auto inp_flat = inp_tensor.flat<
float>();
60 const float* inp = &(inp_flat(0));
62 const Tensor& center_tensor =
context->input(1);
64 center_tensor.dims() == 3 &&
65 center_tensor.shape().dim_size(2) == 3,
66 errors::InvalidArgument(
68 "(batch_size,num_points,3) center shape"));
69 int ball_size = center_tensor.shape().dim_size(1);
70 auto center_flat = center_tensor.flat<
float>();
71 const float* center = &(center_flat(0));
76 0, TensorShape{batch_size, ball_size, nsample},
78 auto out_flat = out_tensor->flat<
int>();
79 int* out = &(out_flat(0));
ImGuiContext * context
Definition: Window.cpp:95
Definition: BallQueryOpKernel.h:34
int nsample
Definition: BallQueryOpKernel.h:96
BallQueryOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: BallQueryOpKernel.h:36
void Compute(tensorflow::OpKernelContext *context) override
Definition: BallQueryOpKernel.h:48
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx)=0
float radius
Definition: BallQueryOpKernel.h:97