Pooling

The pooling primitive performs forward or backward max or average pooling on 1D, 2D, or 3D spatial data.

The pooling operation is defined by the following formulas. The formulas are shown for 2D spatial data only; they generalize straightforwardly to lower and higher dimensions. Variable names follow the standard Conventions.

Forward

Max pooling:

\[\dst(n, c, oh, ow) = \max\limits_{kh, kw} \left( \src(n, c, oh \cdot SH + kh - PH_L, ow \cdot SW +kw - PW_L) \right)\]
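The max pooling formula can be read as a plain loop nest. The following is a minimal reference sketch, not the library implementation: the `Pool2D` struct and the `max_pool_fwd` function are hypothetical names introduced for illustration, the buffers are assumed to be dense NCHW `float` arrays, and window positions that fall into the padding are assumed to be skipped.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper type: parameters of one 2D pooling problem.
struct Pool2D {
    int N, C, H, W;  // source shape
    int OH, OW;      // destination spatial shape
    int KH, KW;      // kernel (window) size
    int SH, SW;      // strides
    int PH_L, PW_L;  // padding at the low-index (top/left) side
};

// Naive forward max pooling over dense NCHW float buffers.
// Assumption: window cells outside the source are ignored.
std::vector<float> max_pool_fwd(const std::vector<float> &src, const Pool2D &p) {
    assert(src.size() == static_cast<std::size_t>(p.N) * p.C * p.H * p.W);
    std::vector<float> dst(static_cast<std::size_t>(p.N) * p.C * p.OH * p.OW);
    for (int n = 0; n < p.N; ++n) {
        for (int c = 0; c < p.C; ++c) {
            const std::size_t plane = static_cast<std::size_t>(n) * p.C + c;
            for (int oh = 0; oh < p.OH; ++oh) {
                for (int ow = 0; ow < p.OW; ++ow) {
                    float best = std::numeric_limits<float>::lowest();
                    for (int kh = 0; kh < p.KH; ++kh) {
                        for (int kw = 0; kw < p.KW; ++kw) {
                            // Source coordinates from the formula above.
                            const int ih = oh * p.SH + kh - p.PH_L;
                            const int iw = ow * p.SW + kw - p.PW_L;
                            if (ih < 0 || ih >= p.H || iw < 0 || iw >= p.W) continue;
                            const float v = src[(plane * p.H + ih) * p.W + iw];
                            if (v > best) best = v;
                        }
                    }
                    dst[(plane * p.OH + oh) * p.OW + ow] = best;
                }
            }
        }
    }
    return dst;
}
```

For a single 4 x 4 plane holding the values 1 to 16 in row-major order, a 2 x 2 window with stride 2 and no padding selects the bottom-right element of each quadrant.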

Average pooling:

\[\dst(n, c, oh, ow) = \frac{1}{DENOM} \sum\limits_{kh, kw} \src(n, c, oh \cdot SH + kh - PH_L, ow \cdot SW +kw - PW_L)\]

Here output spatial dimensions are calculated similarly to how they are done for Convolution and Deconvolution.
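As a small illustration of that calculation, the sketch below assumes the usual convolution-style formula without dilation, \(O = \lfloor (I - K + P_L + P_R) / S \rfloor + 1\). The function name `pooled_extent` and the symbol \(P_R\) (padding at the high-index side) are introduced here for illustration only.

```cpp
#include <cassert>

// Output extent along one spatial axis:
//   O = (I - K + P_L + P_R) / S + 1   (integer division rounds down here
//   because the numerator is asserted to be non-negative).
// Assumption: no dilation.
int pooled_extent(int in, int kernel, int stride, int pad_l, int pad_r) {
    assert(stride > 0);
    assert(in + pad_l + pad_r >= kernel);
    return (in - kernel + pad_l + pad_r) / stride + 1;
}
```

For example, an input extent of 7 with a kernel of 3, a stride of 2, and one element of padding on each side gives an output extent of 4.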

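Average pooling supports two algorithms:

  • pooling_avg_include_padding, in which case \(DENOM = KH \cdot KW\).

  • pooling_avg_exclude_padding, in which case \(DENOM\) equals the size of the overlap between the averaging window and the image.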
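The effect of the denominator is easiest to see at the border of a padded image. The sketch below is a naive reference for one H x W plane with a square window, not the library implementation. The function name `avg_pool_plane` and the `include_padding` flag are hypothetical; the flag switches between a denominator equal to the full window size and a denominator equal to the number of window cells that overlap the source.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive forward average pooling on one H x W plane with a K x K window.
// include_padding == true : DENOM = K * K
// include_padding == false: DENOM = number of window cells inside the source
std::vector<float> avg_pool_plane(const std::vector<float> &src, int H, int W,
        int OH, int OW, int K, int S, int P, bool include_padding) {
    assert(src.size() == static_cast<std::size_t>(H) * W);
    std::vector<float> dst(static_cast<std::size_t>(OH) * OW);
    for (int oh = 0; oh < OH; ++oh) {
        for (int ow = 0; ow < OW; ++ow) {
            float sum = 0.0f;
            int overlap = 0;
            for (int kh = 0; kh < K; ++kh) {
                for (int kw = 0; kw < K; ++kw) {
                    const int ih = oh * S + kh - P;
                    const int iw = ow * S + kw - P;
                    // Padded cells contribute zero to the sum.
                    if (ih < 0 || ih >= H || iw < 0 || iw >= W) continue;
                    sum += src[static_cast<std::size_t>(ih) * W + iw];
                    ++overlap;
                }
            }
            const int denom = include_padding ? K * K : overlap;
            assert(denom > 0);  // a window entirely inside padding is not handled
            dst[static_cast<std::size_t>(oh) * OW + ow] = sum / static_cast<float>(denom);
        }
    }
    return dst;
}
```

On a 2 x 2 plane filled with 4.0, a 2 x 2 window with stride 1 and padding 1 yields 1.0 in the corner output when padding is included in the denominator and 4.0 when it is excluded; the center output is 4.0 in both cases.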

Difference Between Forward Training and Forward Inference

Max pooling requires a workspace for the forward_training propagation kind, and does not require it for forward_inference (see details below).

Backward

The backward propagation computes \(\diffsrc(n, c, h, w)\) based on \(\diffdst(n, c, oh, ow)\) and, for max pooling, the workspace.
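For max pooling, the backward pass routes each destination gradient to the source element that produced the maximum. The sketch below illustrates this on one H x W plane. It is a reference under stated assumptions: the functions `max_pool_fwd_ws` and `max_pool_bwd` are hypothetical, and the explicit index vector `ws` only stands in for the library workspace, whose real format is opaque.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Forward max pooling on one H x W plane (K x K window, stride S, padding P).
// For every output cell, ws receives the flat source index of the maximum.
// Assumption: this index vector is an illustrative stand-in for the workspace.
std::vector<float> max_pool_fwd_ws(const std::vector<float> &src, int H, int W,
        int OH, int OW, int K, int S, int P, std::vector<int> &ws) {
    assert(src.size() == static_cast<std::size_t>(H) * W);
    std::vector<float> dst(static_cast<std::size_t>(OH) * OW);
    ws.assign(dst.size(), -1);
    for (int oh = 0; oh < OH; ++oh) {
        for (int ow = 0; ow < OW; ++ow) {
            float best = std::numeric_limits<float>::lowest();
            int best_idx = -1;
            for (int kh = 0; kh < K; ++kh) {
                for (int kw = 0; kw < K; ++kw) {
                    const int ih = oh * S + kh - P;
                    const int iw = ow * S + kw - P;
                    if (ih < 0 || ih >= H || iw < 0 || iw >= W) continue;
                    const int idx = ih * W + iw;
                    const float v = src[static_cast<std::size_t>(idx)];
                    if (best_idx < 0 || v > best) {
                        best = v;
                        best_idx = idx;
                    }
                }
            }
            const std::size_t o = static_cast<std::size_t>(oh) * OW + ow;
            dst[o] = best;
            ws[o] = best_idx;
        }
    }
    return dst;
}

// Backward max pooling: each diff_dst value is accumulated into the source
// cell recorded in ws; all other diff_src cells stay zero.
std::vector<float> max_pool_bwd(const std::vector<float> &diff_dst,
        const std::vector<int> &ws, int H, int W) {
    assert(diff_dst.size() == ws.size());
    std::vector<float> diff_src(static_cast<std::size_t>(H) * W, 0.0f);
    for (std::size_t i = 0; i < diff_dst.size(); ++i) {
        if (ws[i] >= 0) diff_src[static_cast<std::size_t>(ws[i])] += diff_dst[i];
    }
    return diff_src;
}
```

Accumulating with `+=` matters when windows overlap (stride smaller than the kernel), because one source element can then be the maximum of several windows.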

Execution Arguments

When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.

| Primitive input/output | Execution argument index |
|------------------------|--------------------------|
| \(\src\)               | DNNL_ARG_SRC             |
| \(\dst\)               | DNNL_ARG_DST             |
| workspace              | DNNL_ARG_WORKSPACE       |
| \(\diffsrc\)           | DNNL_ARG_DIFF_SRC        |
| \(\diffdst\)           | DNNL_ARG_DIFF_DST        |

Operation Details

  1. During training, max pooling requires a workspace on the forward (forward_training) and backward passes to save the indices where a maximum was found. The workspace format is opaque, and the indices cannot be restored from it. However, one can use backward pooling to perform up-sampling (used in some detection topologies). The workspace memory can be created from the memory descriptor returned by dnnl::pooling_forward::primitive_desc::workspace_desc().

  2. A user can use the memory format tag any for the dst memory descriptor when creating a pooling forward propagation primitive. The library then derives the appropriate format from the src memory descriptor; the src memory descriptor itself must be fully defined. Similarly, a user can use the memory format tag any for the diff_src memory descriptor when creating a pooling backward propagation primitive.

Data Type Support

The pooling primitive supports the following combinations of data types:

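| Propagation        | Source / Destination | Accumulation data type (used for average pooling only) |
|--------------------|----------------------|--------------------------------------------------------|
| forward / backward | f32, bf16            | f32                                                    |
| forward            | f16                  | f16                                                    |
| forward            | s8, u8, s32          | s32                                                    |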
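The accumulation column matters for integer data: a sum of 8-bit values quickly exceeds the 8-bit range, so the window is summed in a wider type before dividing. The sketch below illustrates this for s8 input with an s32 accumulator. The function name `avg_s8` is hypothetical, and rounding to the nearest integer is an assumption of this sketch, not a statement about the library's rounding mode.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Average of the s8 values of one pooling window using an s32 accumulator.
// Assumption: the result is rounded to the nearest integer (ties away from zero).
std::int8_t avg_s8(const std::vector<std::int8_t> &window) {
    assert(!window.empty());
    std::int32_t acc = 0;  // wide accumulator: the sum may not fit in 8 bits
    for (std::int8_t v : window) acc += v;
    const float mean = static_cast<float>(acc) / static_cast<float>(window.size());
    // The mean of s8 values always lies within the s8 range.
    return static_cast<std::int8_t>(std::lround(mean));
}
```

For a window of four values equal to 100, the intermediate sum is 400, which does not fit in s8, while the averaged result of 100 does.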

Data Representation

Source, Destination, and Their Gradients

Like other CNN primitives, the pooling primitive expects data to be an \(N \times C \times W\) tensor for the 1D spatial case, an \(N \times C \times H \times W\) tensor for the 2D spatial case, and an \(N \times C \times D \times H \times W\) tensor for the 3D spatial case.

The pooling primitive is optimized for the following memory formats:

| Spatial | Logical tensor | Data type   | Implementations optimized for memory formats |
|---------|----------------|-------------|----------------------------------------------|
| 1D      | NCW            | f32         | ncw (abc), nwc (acb), optimized^             |
| 1D      | NCW            | s32, s8, u8 | nwc (acb), optimized^                        |
| 2D      | NCHW           | f32         | nchw (abcd), nhwc (acdb), optimized^         |
| 2D      | NCHW           | s32, s8, u8 | nhwc (acdb), optimized^                      |
| 3D      | NCDHW          | f32         | ncdhw (abcde), ndhwc (acdeb), optimized^     |
| 3D      | NCDHW          | s32, s8, u8 | ndhwc (acdeb), optimized^                    |

Here optimized^ means the format that comes out of any preceding compute-intensive primitive.
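The tag pairs in the table, such as nchw (abcd) and nhwc (acdb), name the physical order in which the logical dimensions are stored, with the rightmost letter varying fastest. As a sketch of what that means for addressing, the hypothetical helpers below compute the offset of the same logical element \((n, c, h, w)\) in dense buffers of either layout.

```cpp
#include <cassert>
#include <cstddef>

// Offset of logical element (n, c, h, w) in a dense nchw (abcd) buffer:
// width is the innermost (fastest-varying) dimension.
std::size_t offset_nchw(int n, int c, int h, int w, int C, int H, int W) {
    assert(c < C && h < H && w < W);
    return ((static_cast<std::size_t>(n) * C + c) * H + h) * W + w;
}

// Offset of the same element in a dense nhwc (acdb) buffer:
// channels are the innermost (fastest-varying) dimension.
std::size_t offset_nhwc(int n, int c, int h, int w, int C, int H, int W) {
    assert(c < C && h < H && w < W);
    return ((static_cast<std::size_t>(n) * H + h) * W + w) * C + c;
}
```

With C = 3 and H = W = 2, moving one step along the channel axis advances the offset by 4 elements in nchw but by only 1 element in nhwc.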

Post-ops and Attributes

The pooling primitive does not support any post-ops or attributes.

API

struct dnnl::pooling_forward : public dnnl::primitive

Pooling forward propagation primitive.

Public Functions

pooling_forward()

Default constructor. Produces an empty object.

pooling_forward(const primitive_desc &pd)

Constructs a pooling forward propagation primitive.

Parameters
  • pd: Primitive descriptor for a pooling forward propagation primitive.

struct desc

Descriptor for a pooling forward propagation primitive.

Public Functions

desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)

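Constructs a descriptor for a pooling forward propagation primitive.

Parameters
  • aprop_kind: Propagation kind, either forward_training or forward_inference.

  • aalgorithm: Pooling algorithm kind (max or average pooling).

  • src_desc: Source memory descriptor.

  • dst_desc: Destination memory descriptor.

  • strides: Strides for each spatial dimension.

  • kernel: Kernel (window) size for each spatial dimension.

  • padding_l: Padding at the low-index side of each spatial dimension.

  • padding_r: Padding at the high-index side of each spatial dimension.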

struct primitive_desc : public dnnl::primitive_desc

Primitive descriptor for a pooling forward propagation primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty = false)

Constructs a primitive descriptor for a pooling forward propagation primitive.

Parameters
  • adesc: Descriptor for a pooling forward propagation primitive.

  • aengine: Engine to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty = false)

Constructs a primitive descriptor for a pooling forward propagation primitive.

Parameters
  • adesc: Descriptor for a pooling forward propagation primitive.

  • attr: Primitive attributes to use.

  • aengine: Engine to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const

Returns a source memory descriptor.

Return

Source memory descriptor, or a zero memory descriptor if the primitive does not have a source parameter.

memory::desc dst_desc() const

Returns a destination memory descriptor.

Return

Destination memory descriptor, or a zero memory descriptor if the primitive does not have a destination parameter.

memory::desc workspace_desc() const

Returns the workspace memory descriptor.

Return

Workspace memory descriptor, or a zero memory descriptor if the primitive does not require a workspace parameter.

struct dnnl::pooling_backward : public dnnl::primitive

Pooling backward propagation primitive.

Public Functions

pooling_backward()

Default constructor. Produces an empty object.

pooling_backward(const primitive_desc &pd)

Constructs a pooling backward propagation primitive.

Parameters
  • pd: Primitive descriptor for a pooling backward propagation primitive.

struct desc

Descriptor for a pooling backward propagation primitive.

Public Functions

desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)

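Constructs a descriptor for a pooling backward propagation primitive.

Parameters
  • aalgorithm: Pooling algorithm kind (max or average pooling).

  • diff_src_desc: Diff source memory descriptor.

  • diff_dst_desc: Diff destination memory descriptor.

  • strides: Strides for each spatial dimension.

  • kernel: Kernel (window) size for each spatial dimension.

  • padding_l: Padding at the low-index side of each spatial dimension.

  • padding_r: Padding at the high-index side of each spatial dimension.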

struct primitive_desc : public dnnl::primitive_desc

Primitive descriptor for a pooling backward propagation primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const desc &adesc, const engine &aengine, const pooling_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)

Constructs a primitive descriptor for a pooling backward propagation primitive.

Parameters
  • adesc: Descriptor for a pooling backward propagation primitive.

  • aengine: Engine to use.

  • hint_fwd_pd: Primitive descriptor for a pooling forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const pooling_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)

Constructs a primitive descriptor for a pooling backward propagation primitive.

Parameters
  • adesc: Descriptor for a pooling backward propagation primitive.

  • attr: Primitive attributes to use.

  • aengine: Engine to use.

  • hint_fwd_pd: Primitive descriptor for a pooling forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc diff_src_desc() const

Returns a diff source memory descriptor.

Return

Diff source memory descriptor, or a zero memory descriptor if the primitive does not have a diff source parameter.

memory::desc diff_dst_desc() const

Returns a diff destination memory descriptor.

Return

Diff destination memory descriptor, or a zero memory descriptor if the primitive does not have a diff destination parameter.

memory::desc workspace_desc() const

Returns the workspace memory descriptor.

Return

Workspace memory descriptor, or a zero memory descriptor if the primitive does not require a workspace parameter.