Concat

A primitive to concatenate data along an arbitrary dimension.
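
The following is a minimal plain-C++ sketch, not oneDNN code, of what concatenation along one dimension means for dense row-major buffers. The helper name `concat_reference` is hypothetical. It assumes every source has the same rank and the same size in every dimension except the concatenation dimension. The tensor is split into an "outer" part (dimensions before the axis) and an "inner" part (dimensions after it). For each outer index, it appends each source's contiguous slab in source order.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical reference helper (not part of oneDNN): concatenates dense,
// row-major float tensors along dimension `axis`.
std::vector<float> concat_reference(
        const std::vector<std::vector<std::size_t>> &src_dims,
        const std::vector<std::vector<float>> &src_data, std::size_t axis) {
    if (src_dims.empty() || src_dims.size() != src_data.size())
        throw std::invalid_argument("need one dims vector per source");
    const std::size_t ndims = src_dims[0].size();
    if (axis >= ndims) throw std::invalid_argument("axis out of range");

    // outer = product of dims before `axis`; inner = product of dims after it.
    std::size_t outer = 1, inner = 1;
    for (std::size_t d = 0; d < axis; ++d) outer *= src_dims[0][d];
    for (std::size_t d = axis + 1; d < ndims; ++d) inner *= src_dims[0][d];

    // Every source must match the first one outside `axis` and hold
    // exactly outer * dims[axis] * inner elements.
    for (std::size_t s = 0; s < src_dims.size(); ++s) {
        const auto &dims = src_dims[s];
        if (dims.size() != ndims) throw std::invalid_argument("rank mismatch");
        for (std::size_t d = 0; d < ndims; ++d)
            if (d != axis && dims[d] != src_dims[0][d])
                throw std::invalid_argument("non-axis dims must match");
        if (src_data[s].size() != outer * dims[axis] * inner)
            throw std::invalid_argument("data size does not match dims");
    }

    std::vector<float> dst;
    // For each outer index, append each source's slab of dims[axis] * inner
    // contiguous elements, in source order.
    for (std::size_t o = 0; o < outer; ++o) {
        for (std::size_t s = 0; s < src_dims.size(); ++s) {
            const std::size_t slab = src_dims[s][axis] * inner;
            const float *begin = src_data[s].data() + o * slab;
            dst.insert(dst.end(), begin, begin + slab);
        }
    }
    return dst;
}
```

For example, concatenating a 2x2 tensor `{1, 2, 3, 4}` and a 2x1 tensor `{5, 6}` along dimension 1 gives the 2x3 tensor `{1, 2, 5, 3, 4, 6}`.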

API

struct dnnl::concat : public dnnl::primitive

Tensor concatenation (concat) primitive.

Public Functions

concat()

Default constructor. Produces an empty object.

concat(const primitive_desc &pd)

Constructs a concatenation primitive.

Parameters
  • pd: Primitive descriptor for a concatenation primitive.

struct primitive_desc : public dnnl::primitive_desc_base

Primitive descriptor for a concat primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const memory::desc &dst, int concat_dimension, const std::vector<memory::desc> &srcs, const engine &engine, const primitive_attr &attr = primitive_attr())

Constructs a primitive descriptor for an out-of-place concatenation primitive.

Parameters
  • dst: Destination memory descriptor.

  • concat_dimension: Index of the dimension along which the source tensors are concatenated. The order of dimensions is logical and does not depend on the memory format.

  • srcs: Vector of source memory descriptors.

  • engine: Engine to perform the operation on.

  • attr: Primitive attributes to use (optional).
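
The note on concat_dimension says the index is logical, independent of memory format. The plain-C++ sketch below (not oneDNN code) illustrates what that means. It stores both sources and the destination in an NHWC-style layout, while indices stay in logical (N, C, H, W) order. Concatenating channels therefore still uses dimension 1, even though channels are the innermost physical dimension. The helpers `nhwc_strides`, `offset`, and `concat_channels_nhwc` are hypothetical, and the sketch assumes two sources that agree on N, H, and W.

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Logical dims are always ordered (N, C, H, W); the physical layout is
// encoded only by per-dimension strides.
using Dims4 = std::array<std::size_t, 4>;

// Strides of a dense NHWC buffer, listed in logical (N, C, H, W) order.
Dims4 nhwc_strides(const Dims4 &d) {
    const std::size_t C = d[1], H = d[2], W = d[3];
    return {H * W * C, 1, W * C, C};
}

// Physical offset of a logical index under the given strides.
std::size_t offset(const Dims4 &idx, const Dims4 &strides) {
    std::size_t off = 0;
    for (std::size_t i = 0; i < 4; ++i) off += idx[i] * strides[i];
    return off;
}

// Hypothetical helper: concatenates two NHWC-stored tensors over logical
// dimension 1 (channels). Assumes a and b agree on N, H and W.
std::vector<float> concat_channels_nhwc(const Dims4 &a_dims,
        const std::vector<float> &a, const Dims4 &b_dims,
        const std::vector<float> &b) {
    Dims4 dst_dims = a_dims;
    dst_dims[1] = a_dims[1] + b_dims[1];
    const Dims4 sa = nhwc_strides(a_dims), sb = nhwc_strides(b_dims),
                sd = nhwc_strides(dst_dims);
    std::vector<float> dst(
            dst_dims[0] * dst_dims[1] * dst_dims[2] * dst_dims[3]);
    for (std::size_t n = 0; n < dst_dims[0]; ++n)
        for (std::size_t c = 0; c < dst_dims[1]; ++c)
            for (std::size_t h = 0; h < dst_dims[2]; ++h)
                for (std::size_t w = 0; w < dst_dims[3]; ++w) {
                    // Channels [0, a_C) come from a, the remaining ones from b.
                    const bool from_a = c < a_dims[1];
                    const Dims4 src_idx = {n, from_a ? c : c - a_dims[1], h, w};
                    dst[offset({n, c, h, w}, sd)] = from_a
                            ? a[offset(src_idx, sa)]
                            : b[offset(src_idx, sb)];
                }
    return dst;
}
```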

primitive_desc(int concat_dimension, const std::vector<memory::desc> &srcs, const engine &engine, const primitive_attr &attr = primitive_attr())

Constructs a primitive descriptor for an out-of-place concatenation primitive.

This version derives the destination memory descriptor automatically.

Parameters
  • concat_dimension: Index of the dimension along which the source tensors are concatenated. The order of dimensions is logical and does not depend on the memory format.

  • srcs: Vector of source memory descriptors.

  • engine: Engine to perform the operation on.

  • attr: Primitive attributes to use (optional).
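
The next plain-C++ sketch shows the shape rule such a derivation has to satisfy. It covers dimensions only, not the layout the library may choose. The destination matches the sources in every dimension except concat_dimension, where the sizes add up. The function `derive_concat_dims` is hypothetical and is not part of the oneDNN API. It uses `std::vector<std::int64_t>`, assumed here to mirror `memory::dims`.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using dims_t = std::vector<std::int64_t>;

// Hypothetical helper (not a oneDNN function): computes the destination
// dims for concatenating `srcs` over `concat_dimension`.
dims_t derive_concat_dims(int concat_dimension, const std::vector<dims_t> &srcs) {
    if (srcs.empty()) throw std::invalid_argument("at least one source is required");
    const std::size_t ndims = srcs[0].size();
    if (concat_dimension < 0 || static_cast<std::size_t>(concat_dimension) >= ndims)
        throw std::invalid_argument("concat_dimension out of range");
    const auto axis = static_cast<std::size_t>(concat_dimension);

    dims_t dst = srcs[0];
    dst[axis] = 0;  // accumulated below
    for (const auto &s : srcs) {
        if (s.size() != ndims)
            throw std::invalid_argument("all sources must have the same rank");
        for (std::size_t d = 0; d < ndims; ++d) {
            if (d == axis)
                dst[d] += s[d];  // sizes add up along the concat dimension
            else if (s[d] != dst[d])
                throw std::invalid_argument("sources differ outside concat_dimension");
        }
    }
    return dst;
}
```

For instance, sources with dims {2, 3, 4} and {2, 5, 4} concatenated over dimension 1 yield a destination with dims {2, 8, 4}.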

memory::desc src_desc(int idx = 0) const

Returns a source memory descriptor.

Return

The source memory descriptor with index idx, or a zero memory descriptor if the primitive does not have a source parameter with that index.

Parameters
  • idx: Source index.

memory::desc dst_desc() const

Returns a destination memory descriptor.

Return

The destination memory descriptor, or a zero memory descriptor if the primitive does not have a destination parameter.