Layer Normalization

A primitive to perform layer normalization. Normalization is performed over the last logical dimension of the data tensor.

Both forward and backward propagation primitives support in-place operation: src and dst can refer to the same memory for forward propagation, and diff_dst and diff_src can refer to the same memory for backward propagation.

You control the computations of the layer normalization primitives by specifying dnnl::normalization_flags values. For example, forward propagation can either compute the mean and variance or take them as arguments, and it can either scale and shift the result using the gamma and beta parameters or skip that step. Optionally, it can also perform a fused ReLU, which in the case of training also requires a workspace.
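The behavior described above can be illustrated with a small reference routine. This is a minimal sketch in plain C++, not the library's implementation: the function name layer_norm_forward_ref, its argument layout (a dense row-major buffer of rows x cols values), and the use of the biased variance are assumptions made for illustration. It normalizes each row in place, optionally applies gamma and beta, and either computes the statistics or takes them as inputs.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative reference only (hypothetical helper, not a oneDNN function).
// `data` holds `rows` rows of `cols` contiguous values and is normalized in
// place, mirroring the case where src and dst refer to the same memory.
// `gamma` and `beta` may be null, which corresponds to no scaling/shifting.
// When `use_given_stats` is true, `mean` and `variance` are read as inputs;
// otherwise they are computed per row and written back.
void layer_norm_forward_ref(std::vector<float> &data, std::size_t rows,
        std::size_t cols, float epsilon, const std::vector<float> *gamma,
        const std::vector<float> *beta, std::vector<float> &mean,
        std::vector<float> &variance, bool use_given_stats) {
    assert(cols > 0 && data.size() == rows * cols);
    assert(!gamma || gamma->size() == cols);
    assert(!beta || beta->size() == cols);
    if (use_given_stats) {
        assert(mean.size() == rows && variance.size() == rows);
    } else {
        mean.assign(rows, 0.0f);
        variance.assign(rows, 0.0f);
    }

    for (std::size_t r = 0; r < rows; ++r) {
        float *x = data.data() + r * cols;

        if (!use_given_stats) {
            // Mean and (biased) variance over the last dimension.
            float m = 0.0f;
            for (std::size_t c = 0; c < cols; ++c) m += x[c];
            m /= static_cast<float>(cols);
            float v = 0.0f;
            for (std::size_t c = 0; c < cols; ++c) v += (x[c] - m) * (x[c] - m);
            v /= static_cast<float>(cols);
            mean[r] = m;
            variance[r] = v;
        }

        const float inv_std = 1.0f / std::sqrt(variance[r] + epsilon);
        for (std::size_t c = 0; c < cols; ++c) {
            float y = (x[c] - mean[r]) * inv_std;
            if (gamma) y *= (*gamma)[c];
            if (beta) y += (*beta)[c];
            x[c] = y; // overwrite the source value (in-place operation)
        }
    }
}
```

Calling the routine with null gamma and beta and use_given_stats set to false corresponds to the simplest configuration: statistics are computed from the data and no scaling or shifting is applied. The fused ReLU and workspace handling are intentionally left out of the sketch.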

API

struct dnnl::layer_normalization_forward : public dnnl::primitive

Layer normalization forward propagation primitive.

Public Functions

layer_normalization_forward()

Default constructor. Produces an empty object.

layer_normalization_forward(const primitive_desc &pd)

Constructs a layer normalization forward propagation primitive.

Parameters
  • pd: Primitive descriptor for a layer normalization forward propagation primitive.

struct desc

Descriptor for a layer normalization forward propagation primitive.

Public Functions

desc(prop_kind prop_kind, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)

Constructs a descriptor for a layer normalization forward propagation primitive.

Inputs:

Outputs:

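Parameters
  • prop_kind: Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

  • data_desc: Source and destination memory descriptor.

  • stat_desc: Statistics memory descriptors.

  • epsilon: Layer normalization epsilon parameter.

  • flags: Layer normalization flags (dnnl::normalization_flags).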

desc(prop_kind prop_kind, const memory::desc &data_desc, float epsilon, normalization_flags flags)

Constructs a descriptor for a layer normalization forward propagation primitive.

Inputs:

Outputs:

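Parameters
  • prop_kind: Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

  • data_desc: Source and destination memory descriptor.

  • epsilon: Layer normalization epsilon parameter.

  • flags: Layer normalization flags (dnnl::normalization_flags).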

struct primitive_desc : public dnnl::primitive_desc

Primitive descriptor for a layer normalization forward propagation primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const desc &desc, const engine &engine, bool allow_empty = false)

Constructs a primitive descriptor for a layer normalization forward propagation primitive.

Parameters
  • desc: Descriptor for a layer normalization forward propagation primitive.

  • engine: Engine to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, bool allow_empty = false)

Constructs a primitive descriptor for a layer normalization forward propagation primitive.

Parameters
  • desc: Descriptor for a layer normalization forward propagation primitive.

  • attr: Primitive attributes to use.

  • engine: Engine to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const

Returns a source memory descriptor.

Return

Source memory descriptor, or a zero memory descriptor if the primitive does not have a source parameter.

memory::desc dst_desc() const

Returns a destination memory descriptor.

Return

Destination memory descriptor, or a zero memory descriptor if the primitive does not have a destination parameter.

memory::desc weights_desc() const

Returns a weights memory descriptor.

Return

Weights memory descriptor, or a zero memory descriptor if the primitive does not have a weights parameter.

memory::desc workspace_desc() const

Returns the workspace memory descriptor.

Return

Workspace memory descriptor, or a zero memory descriptor if the primitive does not require a workspace parameter.

memory::desc mean_desc() const

Returns memory descriptor for mean.

Return

Memory descriptor for mean.

memory::desc variance_desc() const

Returns memory descriptor for variance.

Return

Memory descriptor for variance.
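Because normalization runs over the last logical dimension, there is one mean and one variance value per remaining index of the data tensor. The sketch below shows that relationship between data dimensions and statistics dimensions. The helper name stat_dims_for is hypothetical and the requirement of at least two dimensions is an assumption of this example; in real code the descriptors returned by mean_desc() and variance_desc() should be queried rather than computed by hand.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the library): given the logical dimensions
// of the data tensor, return the dimensions a per-row statistics tensor would
// have, i.e. the same dimensions with the last (normalized) one dropped.
std::vector<std::int64_t> stat_dims_for(
        const std::vector<std::int64_t> &data_dims) {
    // Assumption for this sketch: the data tensor has at least two dimensions.
    assert(data_dims.size() >= 2);
    return std::vector<std::int64_t>(data_dims.begin(), data_dims.end() - 1);
}
```

For example, data of shape {3, 4, 16} would be paired with mean and variance tensors of shape {3, 4}.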

struct dnnl::layer_normalization_backward : public dnnl::primitive

Layer normalization backward propagation primitive.

Public Functions

layer_normalization_backward()

Default constructor. Produces an empty object.

layer_normalization_backward(const primitive_desc &pd)

Constructs a layer normalization backward propagation primitive.

Parameters
  • pd: Primitive descriptor for a layer normalization backward propagation primitive.

struct desc

Descriptor for a layer normalization backward propagation primitive.

Public Functions

desc(prop_kind prop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)

Constructs a descriptor for a layer normalization backward propagation primitive.

Inputs:

Outputs:

Parameters
  • prop_kind: Propagation kind. Possible values are dnnl::prop_kind::backward_data and dnnl::prop_kind::backward (diffs for all parameters are computed in this case).

  • diff_data_desc: Diff source and diff destination memory descriptor.

  • data_desc: Source memory descriptor.

  • stat_desc: Statistics memory descriptors.

  • epsilon: Layer normalization epsilon parameter.

  • flags: Layer normalization flags (dnnl::normalization_flags).
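The difference between the two propagation kinds can be sketched with a plain C++ reference routine. This is an illustration under stated assumptions, not the library's implementation: the name layer_norm_backward_ref and its argument layout are hypothetical, the mean and variance are assumed to have been computed from src by the forward pass (not supplied as global statistics), and the variance is assumed to be the biased one. Passing null for diff_gamma and diff_beta loosely corresponds to dnnl::prop_kind::backward_data, while passing both corresponds to dnnl::prop_kind::backward.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative reference only (hypothetical helper, not a oneDNN function).
// `diff` holds diff_dst on entry and is overwritten with diff_src, mirroring
// the case where diff_dst and diff_src refer to the same memory.
// `gamma` may be null (no scaling). `diff_gamma` and `diff_beta` may be null
// when only the data gradient is wanted.
void layer_norm_backward_ref(const std::vector<float> &src,
        std::vector<float> &diff, std::size_t rows, std::size_t cols,
        float epsilon, const std::vector<float> &mean,
        const std::vector<float> &variance, const std::vector<float> *gamma,
        std::vector<float> *diff_gamma, std::vector<float> *diff_beta) {
    assert(cols > 0 && src.size() == rows * cols && diff.size() == rows * cols);
    assert(mean.size() == rows && variance.size() == rows);
    assert(!gamma || gamma->size() == cols);
    if (diff_gamma) diff_gamma->assign(cols, 0.0f);
    if (diff_beta) diff_beta->assign(cols, 0.0f);

    for (std::size_t r = 0; r < rows; ++r) {
        const float *x = src.data() + r * cols;
        float *d = diff.data() + r * cols;
        const float inv_std = 1.0f / std::sqrt(variance[r] + epsilon);

        // First pass: accumulate parameter gradients and the two row sums
        // needed for the data gradient. Nothing is overwritten yet.
        float sum_g = 0.0f, sum_gx = 0.0f;
        for (std::size_t c = 0; c < cols; ++c) {
            const float x_hat = (x[c] - mean[r]) * inv_std;
            const float g = gamma ? d[c] * (*gamma)[c] : d[c];
            sum_g += g;
            sum_gx += g * x_hat;
            if (diff_gamma) (*diff_gamma)[c] += d[c] * x_hat;
            if (diff_beta) (*diff_beta)[c] += d[c];
        }
        const float mean_g = sum_g / static_cast<float>(cols);
        const float mean_gx = sum_gx / static_cast<float>(cols);

        // Second pass: replace each diff_dst value with its diff_src value.
        for (std::size_t c = 0; c < cols; ++c) {
            const float x_hat = (x[c] - mean[r]) * inv_std;
            const float g = gamma ? d[c] * (*gamma)[c] : d[c];
            d[c] = inv_std * (g - mean_g - x_hat * mean_gx);
        }
    }
}
```

The two-pass structure is what makes the in-place update safe: the row sums depend on every diff_dst value in the row, so they are gathered before any element is overwritten.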

desc(prop_kind prop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, float epsilon, normalization_flags flags)

Constructs a descriptor for a layer normalization backward propagation primitive.

Inputs:

Outputs:

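Parameters
  • prop_kind: Propagation kind. Possible values are dnnl::prop_kind::backward_data and dnnl::prop_kind::backward (diffs for all parameters are computed in this case).

  • diff_data_desc: Diff source and diff destination memory descriptor.

  • data_desc: Source memory descriptor.

  • epsilon: Layer normalization epsilon parameter.

  • flags: Layer normalization flags (dnnl::normalization_flags).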

struct primitive_desc : public dnnl::primitive_desc

Primitive descriptor for a layer normalization backward propagation primitive.

Public Functions

primitive_desc()

Default constructor. Produces an empty object.

primitive_desc(const desc &desc, const engine &engine, const layer_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)

Constructs a primitive descriptor for a layer normalization backward propagation primitive.

Parameters
  • desc: Descriptor for a layer normalization backward propagation primitive.

  • engine: Engine to use.

  • hint_fwd_pd: Primitive descriptor for a layer normalization forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

primitive_desc(const desc &desc, const primitive_attr &attr, const engine &engine, const layer_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty = false)

Constructs a primitive descriptor for a layer normalization backward propagation primitive.

Parameters
  • desc: Descriptor for a layer normalization backward propagation primitive.

  • attr: Primitive attributes to use.

  • engine: Engine to use.

  • hint_fwd_pd: Primitive descriptor for a layer normalization forward propagation primitive. It is used as a hint for deciding which memory format to use.

  • allow_empty: A flag signifying whether construction is allowed to fail without throwing an exception. In this case an empty object will be produced. This flag is optional and defaults to false.

memory::desc src_desc() const

Returns a source memory descriptor.

Return

Source memory descriptor, or a zero memory descriptor if the primitive does not have a source parameter.

memory::desc weights_desc() const

Returns a weights memory descriptor.

Return

Weights memory descriptor, or a zero memory descriptor if the primitive does not have a weights parameter.

memory::desc dst_desc() const

Returns a destination memory descriptor.

Return

Destination memory descriptor, or a zero memory descriptor if the primitive does not have a destination parameter.

memory::desc diff_src_desc() const

Returns a diff source memory descriptor.

Return

Diff source memory descriptor, or a zero memory descriptor if the primitive does not have a diff source parameter.

memory::desc diff_dst_desc() const

Returns a diff destination memory descriptor.

Return

Diff destination memory descriptor, or a zero memory descriptor if the primitive does not have a diff destination parameter.

memory::desc diff_weights_desc() const

Returns a diff weights memory descriptor.

Return

Diff weights memory descriptor, or a zero memory descriptor if the primitive does not have a diff weights parameter.

memory::desc mean_desc() const

Returns memory descriptor for mean.

Return

Memory descriptor for mean.

memory::desc variance_desc() const

Returns memory descriptor for variance.

Return

Memory descriptor for variance.

memory::desc workspace_desc() const

Returns the workspace memory descriptor.

Return

Workspace memory descriptor, or a zero memory descriptor if the primitive does not require a workspace parameter.