Inner Product

The inner product primitive (sometimes called a fully connected layer) treats each activation in the minibatch as a vector and computes its product with a 2D weights tensor, producing a 2D tensor as output.

Forward

Let \(\src\), \(\weights\), \(\bias\) and \(\dst\) be \(N \times IC\), \(OC \times IC\), \(OC\), and \(N \times OC\) tensors, respectively. Variable names follow the standard Conventions. Then:

\[\dst(n, oc) = \bias(oc) + \sum_{ic=0}^{IC-1} \src(n, ic) \cdot \weights(oc, ic)\]
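The formula above can be written as a naive loop nest. The following sketch is a plain C++ reference, not oneDNN's optimized implementation. It assumes dense row-major buffers in the nc (src, dst) and oi (weights) layouts, and the function name inner_product_fwd is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive reference for
//   dst(n, oc) = bias(oc) + sum_{ic} src(n, ic) * weights(oc, ic).
// Buffers are dense and row-major: src is N x IC (nc), weights is OC x IC
// (oi), bias holds OC values, and the returned dst is N x OC (nc).
std::vector<float> inner_product_fwd(const std::vector<float> &src,
                                     const std::vector<float> &weights,
                                     const std::vector<float> &bias,
                                     std::size_t N, std::size_t IC,
                                     std::size_t OC) {
    assert(src.size() == N * IC);
    assert(weights.size() == OC * IC);
    assert(bias.size() == OC);
    std::vector<float> dst(N * OC);
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t oc = 0; oc < OC; ++oc) {
            float acc = bias[oc]; // start from the bias term
            for (std::size_t ic = 0; ic < IC; ++ic)
                acc += src[n * IC + ic] * weights[oc * IC + ic];
            dst[n * OC + oc] = acc;
        }
    }
    return dst;
}
```

Each output element is one dot product per output channel. The whole operation is therefore src multiplied by the transpose of the oi weights matrix, plus a bias broadcast over the minibatch.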

In cases where the \(\src\) and \(\weights\) tensors have spatial dimensions, they are flattened to 2D. For example, if they are 4D \(N \times IC' \times IH \times IW\) and \(OC \times IC' \times KH \times KW\) tensors, then the formula above is applied with \(IC = IC' \cdot IH \cdot IW\). In such cases, the \(\src\) and \(\weights\) tensors must have equal spatial dimensions (e.g. \(KH = IH\) and \(KW = IW\) for 4D tensors).
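For dense tensors in matching layouts, flattening requires no data movement. In an nchw source and oihw weights, element (c, h, w) of one image sits at the same flat position ic = (c · IH + h) · IW + w in both buffers.

The following plain C++ sketch visits every (c, h, w) explicitly and asserts that the 4D offsets coincide with the 2D offsets for IC = IC' · IH · IW. Assumptions:

- IC' is spelled ICp in the code.
- The bias is omitted.
- The names offset4d and inner_product_fwd_4d are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Offset of element (a, b, c, d) in a dense row-major A x B x C x D buffer.
std::size_t offset4d(std::size_t a, std::size_t b, std::size_t c,
                     std::size_t d, std::size_t B, std::size_t C,
                     std::size_t D) {
    return ((a * B + b) * C + c) * D + d;
}

// Inner product (bias omitted) of a 4D nchw source (N x ICp x IH x IW) with
// 4D oihw weights (OC x ICp x KH x KW), where KH == IH and KW == IW.
std::vector<float> inner_product_fwd_4d(const std::vector<float> &src,
                                        const std::vector<float> &weights,
                                        std::size_t N, std::size_t ICp,
                                        std::size_t IH, std::size_t IW,
                                        std::size_t OC) {
    const std::size_t IC = ICp * IH * IW; // flattened channel count
    assert(src.size() == N * IC && weights.size() == OC * IC);
    std::vector<float> dst(N * OC);
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t oc = 0; oc < OC; ++oc) {
            float acc = 0.0f;
            for (std::size_t c = 0; c < ICp; ++c)
                for (std::size_t h = 0; h < IH; ++h)
                    for (std::size_t w = 0; w < IW; ++w) {
                        // Flattened index of (c, h, w) in the 2D view.
                        const std::size_t ic = (c * IH + h) * IW + w;
                        const std::size_t s = offset4d(n, c, h, w, ICp, IH, IW);
                        const std::size_t k = offset4d(oc, c, h, w, ICp, IH, IW);
                        // The 4D offsets equal the 2D (n, ic) and (oc, ic)
                        // offsets, so the buffers can be reused as-is.
                        assert(s == n * IC + ic && k == oc * IC + ic);
                        acc += src[s] * weights[k];
                    }
            dst[n * OC + oc] = acc;
        }
    return dst;
}
```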

Difference Between Forward Training and Forward Inference

There is no difference between the forward_training and forward_inference propagation kinds.

Backward

The backward propagation computes \(\diffsrc\) based on \(\diffdst\) and \(\weights\).
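Differentiating the forward formula with respect to \(\src\) gives \(\diffsrc(n, ic) = \sum_{oc=0}^{OC-1} \diffdst(n, oc) \cdot \weights(oc, ic)\); the bias does not contribute.

The following plain C++ sketch implements this formula. It assumes dense row-major nc and oi buffers. The name inner_product_bwd_data is hypothetical, and this is not oneDNN's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// diff_src(n, ic) = sum_{oc} diff_dst(n, oc) * weights(oc, ic)
// diff_dst is N x OC (nc), weights is OC x IC (oi), diff_src is N x IC (nc).
std::vector<float> inner_product_bwd_data(const std::vector<float> &diff_dst,
                                          const std::vector<float> &weights,
                                          std::size_t N, std::size_t IC,
                                          std::size_t OC) {
    assert(diff_dst.size() == N * OC && weights.size() == OC * IC);
    std::vector<float> diff_src(N * IC, 0.0f);
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t oc = 0; oc < OC; ++oc) {
            const float dd = diff_dst[n * OC + oc];
            // Spread this output gradient back over the input channels.
            for (std::size_t ic = 0; ic < IC; ++ic)
                diff_src[n * IC + ic] += dd * weights[oc * IC + ic];
        }
    return diff_src;
}
```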

The weights update computes \(\diffweights\) and \(\diffbias\) based on \(\diffdst\) and \(\src\).
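The weights update follows from the same formula:

- \(\diffweights(oc, ic) = \sum_{n=0}^{N-1} \diffdst(n, oc) \cdot \src(n, ic)\)
- \(\diffbias(oc) = \sum_{n=0}^{N-1} \diffdst(n, oc)\)

The following plain C++ sketch assumes dense row-major buffers. The struct WeightsGrad and the function inner_product_bwd_weights are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Gradients produced by the weights update.
struct WeightsGrad {
    std::vector<float> diff_weights; // OC x IC (oi)
    std::vector<float> diff_bias;    // OC
};

// diff_weights(oc, ic) = sum_{n} diff_dst(n, oc) * src(n, ic)
// diff_bias(oc)        = sum_{n} diff_dst(n, oc)
WeightsGrad inner_product_bwd_weights(const std::vector<float> &src,
                                      const std::vector<float> &diff_dst,
                                      std::size_t N, std::size_t IC,
                                      std::size_t OC) {
    assert(src.size() == N * IC && diff_dst.size() == N * OC);
    WeightsGrad g{std::vector<float>(OC * IC, 0.0f),
                  std::vector<float>(OC, 0.0f)};
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t oc = 0; oc < OC; ++oc) {
            const float dd = diff_dst[n * OC + oc];
            g.diff_bias[oc] += dd; // reduce over the minibatch
            for (std::size_t ic = 0; ic < IC; ++ic)
                g.diff_weights[oc * IC + ic] += dd * src[n * IC + ic];
        }
    return g;
}
```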

Note

The optimized memory formats for \(\src\) and \(\weights\) might differ between forward propagation, backward propagation, and weights update.

Execution Arguments

When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.

Primitive input/output    Execution argument index
----------------------    ------------------------
\(\src\)                  DNNL_ARG_SRC
\(\weights\)              DNNL_ARG_WEIGHTS
\(\bias\)                 DNNL_ARG_BIAS
\(\dst\)                  DNNL_ARG_DST
\(\diffsrc\)              DNNL_ARG_DIFF_SRC
\(\diffweights\)          DNNL_ARG_DIFF_WEIGHTS
\(\diffbias\)             DNNL_ARG_DIFF_BIAS
\(\diffdst\)              DNNL_ARG_DIFF_DST

Operation Details

N/A

Data Types

The inner product primitive supports the following combinations of data types for source, weights, destination, and bias:

Propagation          Source       Weights      Destination         Bias
-----------          ------       -------      -----------         ----
forward / backward   f32          f32          f32                 f32
forward              f16          f16          f16                 f16
forward              u8, s8       s8           u8, s8, s32, f32    u8, s8, s32, f32
forward              bf16         bf16         f32, bf16           f32, bf16
backward             f32, bf16    bf16         bf16                N/A
weights update       bf16         f32, bf16    bf16                f32, bf16
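In the u8/s8 row, the 8-bit products are summed in a 32-bit integer, and only the final value is converted to the destination type.

The following plain C++ sketch illustrates this for a u8 destination with an s32 bias. The order of operations is an assumption of the sketch, not oneDNN's exact quantization semantics:

1. The bias is added to the s32 accumulator.
2. The result is multiplied by a single output scale.
3. The scaled value is rounded to nearest.
4. The rounded value is saturated to [0, 255].

The function name inner_product_fwd_u8s8 is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// u8 src (N x IC) * s8 weights (OC x IC) + s32 bias (OC) -> u8 dst (N x OC).
std::vector<std::uint8_t> inner_product_fwd_u8s8(
        const std::vector<std::uint8_t> &src,
        const std::vector<std::int8_t> &weights,
        const std::vector<std::int32_t> &bias, float output_scale,
        std::size_t N, std::size_t IC, std::size_t OC) {
    assert(src.size() == N * IC && weights.size() == OC * IC);
    assert(bias.size() == OC);
    std::vector<std::uint8_t> dst(N * OC);
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t oc = 0; oc < OC; ++oc) {
            // Accumulate in 32 bits: 8-bit types cannot hold the products
            // or their sum.
            std::int32_t acc = bias[oc];
            for (std::size_t ic = 0; ic < IC; ++ic)
                acc += static_cast<std::int32_t>(src[n * IC + ic]) *
                       static_cast<std::int32_t>(weights[oc * IC + ic]);
            // Assumed conversion: scale, round to nearest, saturate to u8.
            float v = std::nearbyint(output_scale * static_cast<float>(acc));
            v = std::min(std::max(v, 0.0f), 255.0f);
            dst[n * OC + oc] = static_cast<std::uint8_t>(v);
        }
    return dst;
}
```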

Data Representation

Like other CNN primitives, the inner product primitive expects the following tensors:

Spatial   Source                                      Destination      Weights
-------   ------                                      -----------      -------
1D        \(N \times C \times W\)                     \(N \times C\)   \(OC \times IC \times KW\)
2D        \(N \times C \times H \times W\)            \(N \times C\)   \(OC \times IC \times KH \times KW\)
3D        \(N \times C \times D \times H \times W\)   \(N \times C\)   \(OC \times IC \times KD \times KH \times KW\)

The memory format of the data and weights memory objects is critical for inner product primitive performance. In the oneDNN programming model, the inner product primitive is one of the few primitives that support the placeholder memory format any (dnnl::memory::format_tag::any). It can therefore choose the data and weights memory formats based on the primitive parameters. When using any, first create an inner product primitive descriptor, then query it for the actual data and weights memory formats.

The table below shows the plain memory format combinations for which the inner product primitive is optimized. For the destination tensor (which is always \(N \times C\)), the memory format is always nc (ab).

Spatial   Source / Weights logical tensor   Implementation optimized for memory formats
-------   -------------------------------   -------------------------------------------
0D        NC / OI                           nc (ab) / oi (ab)
0D        NC / OI                           nc (ab) / io (ba)
1D        NCW / OIW                         ncw (abc) / oiw (abc)
1D        NCW / OIW                         nwc (acb) / wio (cba)
2D        NCHW / OIHW                       nchw (abcd) / oihw (abcd)
2D        NCHW / OIHW                       nhwc (acdb) / hwio (cdba)
3D        NCDHW / OIDHW                     ncdhw (abcde) / oidhw (abcde)
3D        NCDHW / OIDHW                     ndhwc (acdeb) / dhwio (cdeba)
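The channels-last rows pair an nhwc source with hwio weights (likewise nwc/wio, ndhwc/dhwio, and the 0D nc/io case). Under this pairing, both buffers order the reduction dimensions as (h, w, c). The nhwc source is then an N x (IH·IW·IC) row-major matrix, and the hwio weights are an (IH·IW·IC) x OC row-major matrix. The inner product therefore reduces to one plain matrix multiplication with no reordering.

The following plain C++ sketch checks this equivalence against the logical formula. It is an illustration, not oneDNN's implementation. The bias is omitted, and the names gemm and inner_product_nhwc_hwio are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plain GEMM: c (M x N) = a (M x K) * b (K x N), all dense row-major.
std::vector<float> gemm(const std::vector<float> &a, const std::vector<float> &b,
                        std::size_t M, std::size_t K, std::size_t N) {
    assert(a.size() == M * K && b.size() == K * N);
    std::vector<float> c(M * N, 0.0f);
    for (std::size_t m = 0; m < M; ++m)
        for (std::size_t k = 0; k < K; ++k)
            for (std::size_t n = 0; n < N; ++n)
                c[m * N + n] += a[m * K + k] * b[k * N + n];
    return c;
}

// Logical inner product (bias omitted): indexes src by (n, c, h, w) and
// weights by (oc, c, h, w), but reads them from nhwc and hwio buffers.
std::vector<float> inner_product_nhwc_hwio(const std::vector<float> &src,
                                           const std::vector<float> &wei,
                                           std::size_t N, std::size_t C,
                                           std::size_t H, std::size_t W,
                                           std::size_t OC) {
    std::vector<float> dst(N * OC, 0.0f);
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t oc = 0; oc < OC; ++oc)
            for (std::size_t c = 0; c < C; ++c)
                for (std::size_t h = 0; h < H; ++h)
                    for (std::size_t w = 0; w < W; ++w)
                        dst[n * OC + oc] +=
                                src[((n * H + h) * W + w) * C + c]      // nhwc
                                * wei[((h * W + w) * C + c) * OC + oc]; // hwio
    return dst;
}
```

For example, gemm(src, wei, N, IH * IW * IC, OC) yields the same N x OC result as inner_product_nhwc_hwio(src, wei, N, IC, IH, IW, OC).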

Post-ops and Attributes

The following attributes and post-ops are supported by the inner product primitive:

Propagation   Type        Operation      Description                                                                     Restrictions
-----------   ----        ---------      -----------                                                                     ------------
forward       attribute   Output scale   Scales the result of inner product by given scale factor(s)                     int8 inner products only
forward       post-op     Eltwise        Applies an elementwise operation to the result
forward       post-op     Sum            Adds the operation result to the destination tensor instead of overwriting it
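Post-ops are applied in the order in which they are appended to the post-op chain. The following plain C++ sketch spells out the table's semantics for a chain assumed to be a sum followed by a ReLU eltwise. The function apply_sum_relu and its sum_scale parameter (a weight on the previous destination value) are hypothetical names for illustration, not oneDNN API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Applies a post-op chain assumed to be [sum, eltwise ReLU] to freshly
// computed inner product results `acc`, updating `dst` in place.
//   sum:     v = acc + sum_scale * (value already in dst)
//   eltwise: v = max(v, 0)
void apply_sum_relu(const std::vector<float> &acc, std::vector<float> &dst,
                    float sum_scale) {
    assert(acc.size() == dst.size());
    for (std::size_t i = 0; i < acc.size(); ++i) {
        const float v = acc[i] + sum_scale * dst[i]; // sum: accumulate
        dst[i] = std::max(v, 0.0f);                  // eltwise: ReLU
    }
}
```

Swapping the order of the two post-ops would apply ReLU to the inner product result before adding the old destination values, which generally gives a different result.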

API

struct dnnl::inner_product_forward : public dnnl::primitive

Inner product forward propagation primitive.

Public Functions

inner_product_forward()

Default constructor. Produces an empty object.

inner_product_forward(const primitive_desc &pd)

Constructs an inner product forward propagation primitive.

Parameters
  • pd: Primitive descriptor for an inner product forward propagation primitive.

struct desc

Descriptor for an inner product forward propagation primitive.

Public Functions

desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)

Constructs a descriptor for an inner product forward propagation primitive with bias.

Note

All the memory descriptors may be initialized with the dnnl::memory::format_tag::any value of format_tag.

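Parameters
  • aprop_kind: Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

  • src_desc: Memory descriptor for src.

  • weights_desc: Memory descriptor for weights.

  • bias_desc: Memory descriptor for bias.

  • dst_desc: Memory descriptor for dst.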

desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)

Constructs a descriptor for an inner product forward propagation primitive without bias.

Note

All the memory descriptors may be initialized with the dnnl::memory::format_tag::any value of format_tag.

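Parameters
  • aprop_kind: Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

  • src_desc: Memory descriptor for src.

  • weights_desc: Memory descriptor for weights.

  • dst_desc: Memory descriptor for dst.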

struct primitive_desc : public dnnl::primitive_desc

Primitive descriptor for an inner product forward propagation primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty = false)

Constructs a primitive descriptor for an inner product forward propagation primitive.

Parameters
  • adesc: Descriptor for an inner product forward propagation primitive.

  • aengine: Engine to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty = false)

Constructs a primitive descriptor for an inner product forward propagation primitive.

Parameters
  • adesc: Descriptor for an inner product forward propagation primitive.

  • attr: Primitive attributes to use.

  • aengine: Engine to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const

Returns a source memory descriptor.

Return

Source memory descriptor.

Return

A zero memory descriptor if the primitive does not have a source parameter.

memory::desc weights_desc() const

Returns a weights memory descriptor.

Return

Weights memory descriptor.

Return

A zero memory descriptor if the primitive does not have a weights parameter.

memory::desc dst_desc() const

Returns a destination memory descriptor.

Return

Destination memory descriptor.

Return

A zero memory descriptor if the primitive does not have a destination parameter.

memory::desc bias_desc() const

Returns the bias memory descriptor.

Return

The bias memory descriptor.

Return

A zero memory descriptor if the primitive does not have a bias parameter.

struct dnnl::inner_product_backward_data : public dnnl::primitive

Inner product backward propagation primitive.

Public Functions

inner_product_backward_data()

Default constructor. Produces an empty object.

inner_product_backward_data(const primitive_desc &pd)

Constructs an inner product backward propagation primitive.

Parameters
  • pd: Primitive descriptor for an inner product backward propagation primitive.

struct desc

Descriptor for an inner product backward propagation primitive.

Public Functions

desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)

Constructs a descriptor for an inner product backward propagation primitive.

Note

All the memory descriptors may be initialized with the dnnl::memory::format_tag::any value of format_tag.

Parameters
  • diff_src_desc: Memory descriptor for diff src.

  • weights_desc: Memory descriptor for weights.

  • diff_dst_desc: Memory descriptor for diff dst.

struct primitive_desc : public dnnl::primitive_desc

Primitive descriptor for an inner product backward propagation primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)

Constructs a primitive descriptor for an inner product backward propagation primitive.

Parameters
  • adesc: Descriptor for an inner product backward propagation primitive.

  • aengine: Engine to use.

  • hint_fwd_pd: Primitive descriptor for an inner product forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)

Constructs a primitive descriptor for an inner product backward propagation primitive.

Parameters
  • adesc: Descriptor for an inner product backward propagation primitive.

  • attr: Primitive attributes to use.

  • aengine: Engine to use.

  • hint_fwd_pd: Primitive descriptor for an inner product forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc diff_src_desc() const

Returns a diff source memory descriptor.

Return

Diff source memory descriptor.

Return

A zero memory descriptor if the primitive does not have a diff source parameter.

memory::desc weights_desc() const

Returns a weights memory descriptor.

Return

Weights memory descriptor.

Return

A zero memory descriptor if the primitive does not have a weights parameter.

memory::desc diff_dst_desc() const

Returns a diff destination memory descriptor.

Return

Diff destination memory descriptor.

Return

A zero memory descriptor if the primitive does not have a diff destination parameter.

struct dnnl::inner_product_backward_weights : public dnnl::primitive

Inner product weights gradient primitive.

Public Functions

inner_product_backward_weights()

Default constructor. Produces an empty object.

inner_product_backward_weights(const primitive_desc &pd)

Constructs an inner product weights gradient primitive.

Parameters
  • pd: Primitive descriptor for an inner product weights gradient primitive.

struct desc

Descriptor for an inner product weights gradient primitive.

Public Functions

desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)

Constructs a descriptor for an inner product weights update primitive with bias.

Note

All the memory descriptors may be initialized with the dnnl::memory::format_tag::any value of format_tag.

Parameters
  • src_desc: Memory descriptor for src.

  • diff_weights_desc: Memory descriptor for diff weights.

  • diff_bias_desc: Memory descriptor for diff bias.

  • diff_dst_desc: Memory descriptor for diff dst.

desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)

Constructs a descriptor for an inner product weights update primitive without bias.

Note

All the memory descriptors may be initialized with the dnnl::memory::format_tag::any value of format_tag.

Parameters
  • src_desc: Memory descriptor for src.

  • diff_weights_desc: Memory descriptor for diff weights.

  • diff_dst_desc: Memory descriptor for diff dst.

struct primitive_desc : public dnnl::primitive_desc

Primitive descriptor for an inner product weights gradient primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)

Constructs a primitive descriptor for an inner product weights update primitive.

Parameters
  • adesc: Descriptor for an inner product weights gradient primitive.

  • aengine: Engine to use.

  • hint_fwd_pd: Primitive descriptor for an inner product forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)

Constructs a primitive descriptor for an inner product weights update primitive.

Parameters
  • adesc: Descriptor for an inner product weights gradient primitive.

  • attr: Primitive attributes to use.

  • aengine: Engine to use.

  • hint_fwd_pd: Primitive descriptor for an inner product forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const

Returns a source memory descriptor.

Return

Source memory descriptor.

Return

A zero memory descriptor if the primitive does not have a source parameter.

memory::desc diff_weights_desc() const

Returns a diff weights memory descriptor.

Return

Diff weights memory descriptor.

Return

A zero memory descriptor if the primitive does not have a diff weights parameter.

memory::desc diff_dst_desc() const

Returns a diff destination memory descriptor.

Return

Diff destination memory descriptor.

Return

A zero memory descriptor if the primitive does not have a diff destination parameter.

memory::desc diff_bias_desc() const

Returns the diff bias memory descriptor.

Return

The diff bias memory descriptor.

Return

A zero memory descriptor if the primitive does not have a diff bias parameter.