@@ -274,6 +274,264 @@ class CheckDatasetConsistencyOp : public Operator<CPUContext> {
  TreeIterator iterator_;
};

/**
 * Simple wrapper class allowing an easy traversal of the tensors representing
 * the hierarchical structure.
 */
class TreeWalker {
 public:
  TreeWalker(const vector<const Blob*>& inputs, TreeCursor& cursor)
      : inputs_(inputs), cursor_(cursor), sizes_(cursor.it.numOffsetFields()) {
    if (cursor.offsets.empty()) {
      cursor.offsets.assign(cursor.it.numOffsetFields(), 0);
    }

    for (int fieldId = 0; fieldId < cursor_.it.fields().size(); ++fieldId) {
      fields_.emplace_back(*this, fieldId);
    }

    gatherLengthData();

    gatherSizeLimits();

    // Invariant: the cursor is always one step ahead, so prevOffsets_ (set by
    // advance()) marks the start of the record currently exposed.
    advance();
  }

  // Returns the number of records in a dataset
  inline TOffset size() const {
    return limits_.at(0);
  }

  void advance() {
    prevOffsets_ = cursor_.offsets;
    cursor_.it.advance(lengths_, cursor_.offsets, sizes_, limits_, 1);
  }
 private:
  inline const TensorCPU& input(int32_t idx) const {
    return inputs_[idx]->Get<TensorCPU>();
  }

  // TODO: Change to fieldDesc
  inline const TreeIterator::FieldDesc& field(int idx) const {
    return cursor_.it.fields().at(idx);
  }

  inline int lengthIdx(int fieldId) const {
    return field(fieldId).lengthFieldId + 1;
  }

  inline TOffset offset(int fieldId) const {
    return prevOffsets_[lengthIdx(fieldId)];
  }

  std::vector<TIndex> fieldDim(int fieldId) const {
    auto tensorDim = input(fieldId).dims();
    tensorDim[0] = sizes_[lengthIdx(fieldId)];
    return tensorDim;
  }

  // Pointer to the current record's data for this field: skip offset(fieldId)
  // rows, each row holding size_from_dim(1) items.
  void* fieldPtr(int fieldId) const {
    auto& in = input(fieldId);
    return (char*)in.raw_data() +
        offset(fieldId) * in.size_from_dim(1) * in.meta().itemsize();
  }
 public:
  // Simple proxy class exposing a nicer API for field access
  class Field {
   public:
    Field(TreeWalker& walker, int fieldId)
        : walker_(walker), fieldId_(fieldId) {}

    inline std::vector<TIndex> dim() const {
      return walker_.fieldDim(fieldId_);
    }

    inline const TypeMeta& meta() const {
      return walker_.input(fieldId_).meta();
    }

    inline void* ptr() const {
      return walker_.fieldPtr(fieldId_);
    }

   private:
    const TreeWalker& walker_;
    const int fieldId_;
  };

  // Note that a reference is returned: if advance() is called, the fields will
  // be updated to reflect the new state.
  inline const std::vector<Field>& fields() const {
    return fields_;
  }
 private:
  void gatherLengthData() {
    static const TLength lenZero = 0;
    lengths_.resize(cursor_.it.numLengthFields());
    for (int i = 0; i < lengths_.size(); ++i) {
      auto& in = input(cursor_.it.lengthField(i).id);
      if (in.size() > 0) {
        lengths_[i] = in.data<int>();
      } else {
        lengths_[i] = &lenZero;
      }
    }
  }

  void gatherSizeLimits() {
    limits_.assign(sizes_.size(), std::numeric_limits<TOffset>::max());
    for (auto fieldId = 0; fieldId < cursor_.it.fields().size(); ++fieldId) {
      auto lengthFieldIdx = lengthIdx(fieldId);
      limits_[lengthFieldIdx] =
          std::min(limits_[lengthFieldIdx], (TOffset)input(fieldId).dims()[0]);
    }
  }

  const vector<const Blob*>& inputs_;
  TreeCursor& cursor_;
  std::vector<Field> fields_;

  std::vector<const TLength*> lengths_;
  std::vector<TOffset> limits_;
  std::vector<TOffset> sizes_;
  std::vector<TOffset> offsets_;
  std::vector<TOffset> prevOffsets_;
};
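
// Standalone sketch (illustrative only; not used by the operators below) of
// the pointer arithmetic behind TreeWalker::fieldPtr(): for a contiguous
// row-major buffer of shape (N, D), the record starting at row `offset`
// begins at base + offset * D * itemsize.
inline const float* rowPtrSketch(
    const std::vector<float>& data, // flattened N x D buffer
    TIndex D,
    TIndex offset) {
  const char* base = reinterpret_cast<const char*>(data.data());
  // Skip `offset` rows of D float items each, the same math as fieldPtr().
  return reinterpret_cast<const float*>(base + offset * D * sizeof(float));
}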

// Each element of a packed dataset is one SharedTensorVectorPtr: a shared
// vector holding one TensorCPU per field of the schema, for a single record.
using SharedTensorVectorPtr = std::shared_ptr<std::vector<TensorCPU>>;

class PackRecordsOp : public Operator<CPUContext> {
 public:
  PackRecordsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator(operator_def, ws),
        fields_(OperatorBase::GetRepeatedArgument<std::string>("fields")) {}

  bool RunOnDevice() override {
    // There should be one input per field
    CAFFE_ENFORCE_EQ(InputSize(), fields_.size());
    CAFFE_ENFORCE_EQ(OutputSize(), 1);

    TreeCursor cursor((TreeIterator(fields_)));

    TreeWalker walker(Inputs(), cursor);

    Output(0)->Resize(walker.size());

    auto* dst = Output(0)->mutable_data<SharedTensorVectorPtr>();

    for (int batchId = 0; batchId < walker.size(); ++batchId) {
      dst[batchId] = std::make_shared<std::vector<TensorCPU>>();
      dst[batchId]->reserve(walker.fields().size());

      for (const auto& field : walker.fields()) {
        dst[batchId]->emplace_back(field.dim());
        auto& tensor = dst[batchId]->back();
        context_.template CopyItems<CPUContext, CPUContext>(
            field.meta(),
            tensor.size(),
            field.ptr() /* src */,
            tensor.raw_mutable_data(field.meta()) /* dst */);
      }

      walker.advance();
    }

    return true;
  }

 private:
  std::vector<std::string> fields_;
};
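
// Minimal usage sketch (blob names and the field string are illustrative
// assumptions, not part of the operator's contract): packs a single dense
// field of shape (3, 2) into three records, each stored as a (1, 2) tensor
// inside a SharedTensorVectorPtr.
inline void packRecordsUsageSketch(Workspace* ws) {
  auto* dense = ws->CreateBlob("dense")->GetMutable<TensorCPU>();
  dense->Resize(3, 2);
  for (TIndex i = 0; i < dense->size(); ++i) {
    dense->mutable_data<float>()[i] = static_cast<float>(i);
  }

  OperatorDef def;
  def.set_type("PackRecords");
  def.add_input("dense");
  def.add_output("packed");
  auto* fields = def.add_arg();
  fields->set_name("fields");
  fields->add_strings("dense"); // field format as in CreateTreeCursor's doc
  auto op = CreateOperator(def, ws);
  CAFFE_ENFORCE(op->Run());
  // "packed" is now a 1-D tensor of 3 SharedTensorVectorPtr entries.
}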

class UnPackRecordsOp : public Operator<CPUContext> {
 public:
  UnPackRecordsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator(operator_def, ws),
        fields_(OperatorBase::GetRepeatedArgument<std::string>("fields")) {}

  bool RunOnDevice() override {
    const auto* inputs = Input(0).template data<SharedTensorVectorPtr>();
    const auto numRows = Input(0).size();

    CAFFE_ENFORCE_GE(numRows, 0);

    if (numRows == 0) {
      return true;
    }

    const auto& inputZero = inputs[0];
    CAFFE_ENFORCE(inputZero);

    const auto numTensors = inputZero->size();

    CAFFE_ENFORCE_EQ(numTensors, fields_.size());
    CAFFE_ENFORCE_EQ(numTensors, OutputSize());

    // Precompute the output sizes to avoid resizing
    std::vector<std::vector<TIndex>> outputDims(numTensors);

    for (int i = 0; i < numTensors; ++i) {
      outputDims[i] = inputs[0]->at(i).dims();
      outputDims[i][0] = 0;
    }

    for (int i = 0; i < numRows; ++i) {
      CAFFE_ENFORCE(inputs[i]);
      for (int j = 0; j < inputs[i]->size(); ++j) {
        const auto& input = inputs[i]->at(j);
        const auto& inputZeroTensor = inputZero->at(j);

        // Checks to ensure that dimensions/sizes match
        CAFFE_ENFORCE_EQ(inputZeroTensor.ndim(), input.ndim());
        CAFFE_ENFORCE(inputZeroTensor.meta() == input.meta());
        // Start from the second dimension, because we concat along the first.
        for (int k = 1; k < input.ndim(); ++k) {
          CAFFE_ENFORCE_EQ(input.dims()[k], inputZeroTensor.dims()[k]);
        }

        outputDims[j][0] += input.dim(0);
      }
    }

    // Resize to the final output size
    std::vector<void*> destinations(numTensors);
    for (int i = 0; i < numTensors; ++i) {
      Output(i)->Resize(outputDims[i]);
      destinations[i] = Output(i)->raw_mutable_data(inputZero->at(i).meta());
    }

    for (int i = 0; i < numRows; ++i) {
      for (int j = 0; j < numTensors; ++j) {
        const auto& input = inputs[i]->at(j);
        // Skip empty tensors
        if (input.size() == 0) {
          continue;
        }

        context_.CopyItems<CPUContext, CPUContext>(
            inputZero->at(j).meta(),
            input.size(),
            input.raw_data() /* src */,
            destinations[j] /* dst */);

        destinations[j] =
            (char*)destinations[j] + input.size() * inputZero->at(j).itemsize();
      }
    }

    return true;
  }

 private:
  std::vector<std::string> fields_;
};
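
// Companion sketch to packRecordsUsageSketch above (same illustrative names):
// unpacking the packed blob restores the original tensor for each field.
inline void unPackRecordsUsageSketch(Workspace* ws) {
  OperatorDef def;
  def.set_type("UnPackRecords");
  def.add_input("packed");
  def.add_output("dense_restored");
  auto* fields = def.add_arg();
  fields->set_name("fields");
  fields->add_strings("dense");
  auto op = CreateOperator(def, ws);
  CAFFE_ENFORCE(op->Run());
  // "dense_restored" is again a (3, 2) float tensor with the same contents
  // as the original "dense" blob.
}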

class ReadNextBatchOp : public Operator<CPUContext> {
 public:
  ReadNextBatchOp(const OperatorDef& operator_def, Workspace* ws)
@@ -803,6 +1061,8 @@ REGISTER_CPU_OPERATOR(CreateTensorVector, CreateTensorVectorOp<CPUContext>);
REGISTER_CPU_OPERATOR(TensorVectorSize, TensorVectorSizeOp<CPUContext>);
REGISTER_CPU_OPERATOR(ConcatTensorVector, ConcatTensorVectorOp<CPUContext>);
REGISTER_CPU_OPERATOR(CollectTensor, CollectTensorOp<CPUContext>);
REGISTER_CPU_OPERATOR(PackRecords, PackRecordsOp);
REGISTER_CPU_OPERATOR(UnPackRecords, UnPackRecordsOp);

OPERATOR_SCHEMA(CreateTreeCursor)
    .NumInputs(0)
@@ -1048,6 +1308,34 @@ output vectors.
)DOC")
    .Arg("num_to_collect", "The max number of tensors to collect");

OPERATOR_SCHEMA(PackRecords)
    .NumInputs(1, INT_MAX)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Given a dataset under a schema specified by the `fields` argument, packs all
the input tensors into one, where each tensor element represents a row of
data (a batch of size 1). This format allows easier use with the rest of the
Caffe2 operators.
)DOC")
    .Arg(
        "fields",
        "List of strings representing the field names in the format"
        " specified in the doc for CreateTreeCursor.")
    .Output(
        0,
        "tensor",
        "One dimensional tensor having a complex type of SharedTensorVectorPtr."
        " In order to reverse it back to the original input it has to be"
        " inserted into UnPackRecordsOp.");

OPERATOR_SCHEMA(UnPackRecords)
    .NumInputs(1)
    .NumOutputs(1, INT_MAX)
    .SetDoc(R"DOC(
Given a packed dataset (packed by the PackRecordsOp) and the `fields` argument
describing the dataset schema, returns the original (unpacked) dataset format.
The number of returned tensors equals the number of fields in the `fields`
argument.
)DOC")
    .Arg(
        "fields",
        "List of strings representing the field names in the format"
        " specified in the doc for CreateTreeCursor.");

SHOULD_NOT_DO_GRADIENT(CreateTreeCursor);
SHOULD_NOT_DO_GRADIENT(ResetCursor);
SHOULD_NOT_DO_GRADIENT(ReadNextBatch);
@@ -1060,9 +1348,12 @@ SHOULD_NOT_DO_GRADIENT(CreateTensorVector);
SHOULD_NOT_DO_GRADIENT(TensorVectorSize);
SHOULD_NOT_DO_GRADIENT(ConcatTensorVector);
SHOULD_NOT_DO_GRADIENT(CollectTensor);
SHOULD_NOT_DO_GRADIENT(UnPackRecords);
SHOULD_NOT_DO_GRADIENT(PackRecords);
} // namespace
CAFFE_KNOWN_TYPE(std::unique_ptr<TreeCursor>);
CAFFE_KNOWN_TYPE(TensorVectorPtr<CPUContext>);
CAFFE_KNOWN_TYPE(SharedTensorVectorPtr);

namespace {