29 #include "../TensorFlowHelper.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/lib/core/errors.h"
37 tensorflow::OpKernelConstruction*
context)
38 : tensorflow::OpKernel(
context) {
39 using namespace tensorflow;
43 errors::InvalidArgument(
44 "TrilinearDevoxelize expects positive resolution"));
48 using namespace tensorflow;
49 const Tensor& coords =
context->input(0);
51 context, coords.dims() == 3 && coords.shape().dim_size(1) == 3,
52 errors::InvalidArgument(
"TrilinearDevoxelize expects "
53 "(batch_size, 3, N) coordinate shape"));
54 const Tensor& feat =
context->input(1);
55 OP_REQUIRES(
context, feat.dims() == 5,
56 errors::InvalidArgument(
"TrilinearDevoxelize expects "
57 "5 dimensions for features"));
59 int batch_size = coords.shape().dim_size(0);
60 int num_points = coords.shape().dim_size(2);
61 int feat_dim = feat.shape().dim_size(1);
63 auto coords_flat = coords.flat<
float>();
64 auto feat_flat = feat.flat<
float>();
66 const float* inp_coords = &(coords_flat(0));
67 const float* inp_feat = &(feat_flat(0));
72 0, TensorShape{batch_size, feat_dim, num_points},
77 1, TensorShape{batch_size, 8, num_points},
82 2, TensorShape{batch_size, 8, num_points},
84 auto flat_0 = out_tensor_0->flat<
float>();
85 auto flat_1 = out_tensor_1->flat<
int>();
86 auto flat_2 = out_tensor_2->flat<
float>();
88 float* out_0 = &(flat_0(0));
89 int* out_1 = &(flat_1(0));
90 float* out_2 = &(flat_2(0));
94 r *
r *
r,
true, inp_coords, inp_feat, out_1, out_2, out_0);
97 r *
r *
r,
false, inp_coords, inp_feat, out_1, out_2, out_0);
122 tensorflow::OpKernelConstruction*
context)
123 : tensorflow::OpKernel(
context) {
124 using namespace tensorflow;
128 errors::InvalidArgument(
129 "TrilinearDevoxelizeGrad expects positive resolution"));
133 using namespace tensorflow;
134 const Tensor& grad_y =
context->input(0);
137 errors::InvalidArgument(
"TrilinearDevoxelizeGrad expects "
138 "(batch_size, C, N) gradient shape"));
139 const Tensor& inds =
context->input(1);
141 context, inds.dims() == 3 && inds.shape().dim_size(1) == 8,
142 errors::InvalidArgument(
"TrilinearDevoxelizeGrad expects "
143 "(batch_size, 8, N) indices shape"));
144 const Tensor& wgts =
context->input(2);
146 context, wgts.dims() == 3 && wgts.shape().dim_size(1) == 8,
147 errors::InvalidArgument(
"TrilinearDevoxelizeGrad expects "
148 "(batch_size, 8, N) weights shape"));
150 int batch_size = grad_y.shape().dim_size(0);
151 int num_points = grad_y.shape().dim_size(2);
152 int feat_dim = grad_y.shape().dim_size(1);
154 auto grad_y_flat = grad_y.flat<
float>();
155 auto inds_flat = inds.flat<
int>();
156 auto wgts_flat = wgts.flat<
float>();
158 const float* inp_grad_y = &(grad_y_flat(0));
159 const int* inp_inds = &(inds_flat(0));
160 const float* inp_wgts = &(wgts_flat(0));
165 0, TensorShape{batch_size, feat_dim, r, r, r},
167 auto flat_tensor = out_tensor->flat<
float>();
169 float* out = &(flat_tensor(0));
172 inp_wgts, inp_grad_y, out);
ImGuiContext * context
Definition: Window.cpp:95
Definition: TrilinearDevoxelizeKernel.h:119
int r
Definition: TrilinearDevoxelizeKernel.h:186
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:132
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r3, const int *inds, const float *wgts, const float *grad_y, float *grad_x)=0
TrilinearDevoxelizeGradOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:121
Definition: TrilinearDevoxelizeKernel.h:34
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r, int r2, int r3, bool training, const float *coords, const float *feat, int *inds, float *wgts, float *outs)=0
TrilinearDevoxelizeOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:36
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:47
int r
Definition: TrilinearDevoxelizeKernel.h:115
bool is_training
Definition: TrilinearDevoxelizeKernel.h:116