IgANet
IgANets - Isogeometric Analysis Networks
iganet::BatchNorm Class Reference

Batch Normalization as described in the paper by S. Ioffe and C. Szegedy (2015).

#include </home/runner/work/iganet/iganet/include/layer.hpp>

Inheritance: iganet::BatchNorm derives from iganet::ActivationFunction and additionally inherits members of iganet::utils::FullQualifiedName.

Public Member Functions

 BatchNorm (const torch::Tensor &running_mean, const torch::Tensor &running_var, const torch::Tensor &weight, const torch::Tensor &bias, double eps, double momentum, bool training=false)
 
 BatchNorm (const torch::Tensor &running_mean, const torch::Tensor &running_var, torch::nn::functional::BatchNormFuncOptions options={})
 
 ~BatchNorm () override=default
 
torch::Tensor apply (const torch::Tensor &input) const override
 Applies the activation function to the given input.
 
torch::nn::functional::BatchNormFuncOptions & options ()
 Returns non-constant reference to options.
 
const torch::nn::functional::BatchNormFuncOptions & options () const
 Returns constant reference to options.
 
virtual void pretty_print (std::ostream &os=Log(log::info)) const noexcept override
 Prints a string representation of the activation function to the given output stream.
 
torch::serialize::InputArchive & read (torch::serialize::InputArchive &archive, const std::string &key="batch_norm") override
 Reads the activation function from a torch::serialize::InputArchive object.
 
torch::Tensor & running_mean ()
 Returns non-constant reference to running mean.
 
const torch::Tensor & running_mean () const
 Returns constant reference to running mean.
 
torch::Tensor & running_var ()
 Returns non-constant reference to running variance.
 
const torch::Tensor & running_var () const
 Returns constant reference to running variance.
 
torch::serialize::OutputArchive & write (torch::serialize::OutputArchive &archive, const std::string &key="batch_norm") const override
 Writes the activation function into a torch::serialize::OutputArchive object.
 
- Public Member Functions inherited from iganet::ActivationFunction
virtual ~ActivationFunction ()=default
 

Private Attributes

torch::nn::functional::BatchNormFuncOptions options_
 
torch::Tensor running_mean_
 
torch::Tensor running_var_
 

Additional Inherited Members

- Protected Member Functions inherited from iganet::utils::FullQualifiedName
virtual const std::string & name () const noexcept
 Returns the full qualified name of the object.
 
- Protected Attributes inherited from iganet::utils::FullQualifiedName
at::optional< std::string > name_
 String storing the full qualified name of the object.
 

Detailed Description

Batch Normalization as described in the paper by S. Ioffe and C. Szegedy:

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift, https://arxiv.org/abs/1502.03167
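The transform from the paper can be illustrated without libtorch. The following is a minimal, dependency-free sketch of evaluation-mode batch normalization, assuming an N x C input stored row-major in a std::vector<double> and one mean, variance, weight, and bias per channel. The function name batch_norm_eval is hypothetical and not part of iganet; the default eps of 1e-5 mirrors the libtorch default for torch::nn::functional::BatchNormFuncOptions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: normalizes an N x C batch (row-major) with fixed
// per-channel statistics, i.e. the evaluation-mode transform
//   y = (x - mean) / sqrt(var + eps) * weight + bias
std::vector<double> batch_norm_eval(const std::vector<double>& x,
                                    std::size_t channels,
                                    const std::vector<double>& mean,
                                    const std::vector<double>& var,
                                    const std::vector<double>& weight,
                                    const std::vector<double>& bias,
                                    double eps = 1e-5) {
  assert(channels > 0 && x.size() % channels == 0);
  assert(mean.size() == channels && var.size() == channels);
  assert(weight.size() == channels && bias.size() == channels);
  std::vector<double> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    const std::size_t c = i % channels;  // channel index of element i
    y[i] = (x[i] - mean[c]) / std::sqrt(var[c] + eps) * weight[c] + bias[c];
  }
  return y;
}
```

For example, with two samples and two channels, x = {1, 10, 3, 20}, means {2, 15}, variances {1, 25}, unit weights, zero biases, and eps = 0, each channel is mapped to {-1, 1}.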

Constructor & Destructor Documentation

◆ BatchNorm() [1/2]

iganet::BatchNorm::BatchNorm ( const torch::Tensor &  running_mean,
const torch::Tensor &  running_var,
torch::nn::functional::BatchNormFuncOptions  options = {} 
)
inline, explicit

◆ BatchNorm() [2/2]

iganet::BatchNorm::BatchNorm ( const torch::Tensor &  running_mean,
const torch::Tensor &  running_var,
const torch::Tensor &  weight,
const torch::Tensor &  bias,
double  eps,
double  momentum,
bool  training = false 
)
inline, explicit
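The momentum and training arguments govern how the running statistics evolve while training. A minimal sketch, assuming the class follows the libtorch convention of an exponential moving average, running = (1 - momentum) * running + momentum * batch statistic, where the running variance is fed with the unbiased batch variance (division by N - 1). RunningStats and update_running_stats are hypothetical names used only for illustration; the default momentum of 0.1 mirrors libtorch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical container for per-channel running statistics.
struct RunningStats {
  std::vector<double> mean;
  std::vector<double> var;
};

// Updates running statistics from an N x C batch (row-major) using
//   running = (1 - momentum) * running + momentum * batch_statistic,
// with the unbiased batch variance (divides by N - 1).
void update_running_stats(RunningStats& stats, const std::vector<double>& x,
                          std::size_t channels, double momentum = 0.1) {
  assert(channels > 0 && x.size() % channels == 0);
  assert(stats.mean.size() == channels && stats.var.size() == channels);
  const std::size_t n = x.size() / channels;  // batch size
  assert(n > 1);  // the unbiased variance needs at least two samples
  for (std::size_t c = 0; c < channels; ++c) {
    double sum = 0.0;
    for (std::size_t k = 0; k < n; ++k) sum += x[k * channels + c];
    const double batch_mean = sum / n;
    double sq = 0.0;
    for (std::size_t k = 0; k < n; ++k) {
      const double d = x[k * channels + c] - batch_mean;
      sq += d * d;
    }
    const double batch_var = sq / (n - 1);  // unbiased estimate
    stats.mean[c] = (1.0 - momentum) * stats.mean[c] + momentum * batch_mean;
    stats.var[c] = (1.0 - momentum) * stats.var[c] + momentum * batch_var;
  }
}
```

A larger momentum makes the running statistics track recent batches more closely; a smaller one averages over a longer history.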

◆ ~BatchNorm()

iganet::BatchNorm::~BatchNorm ( )
override, default

Member Function Documentation

◆ apply()

torch::Tensor iganet::BatchNorm::apply ( const torch::Tensor &  input) const
inline, override, virtual

Applies the activation function to the given input.

Implements iganet::ActivationFunction.
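Assuming apply forwards the input, running_mean_, running_var_, and options_ to torch::nn::functional::batch_norm, its result depends on the training flag: in evaluation mode the stored running statistics are used, while in training mode each channel is normalized with the statistics of the current batch. The sketch below shows only the training-mode normalization in plain C++; batch_norm_train is a hypothetical name, and note that the normalization itself uses the biased batch variance (division by N).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: training-mode normalization of an N x C batch
// (row-major). Each channel is normalized with its own batch mean and biased
// batch variance, then scaled by weight and shifted by bias.
std::vector<double> batch_norm_train(const std::vector<double>& x,
                                     std::size_t channels,
                                     const std::vector<double>& weight,
                                     const std::vector<double>& bias,
                                     double eps = 1e-5) {
  assert(channels > 0 && x.size() % channels == 0);
  assert(weight.size() == channels && bias.size() == channels);
  const std::size_t n = x.size() / channels;  // batch size
  std::vector<double> y(x.size());
  for (std::size_t c = 0; c < channels; ++c) {
    double mean = 0.0;
    for (std::size_t k = 0; k < n; ++k) mean += x[k * channels + c];
    mean /= n;
    double var = 0.0;
    for (std::size_t k = 0; k < n; ++k) {
      const double d = x[k * channels + c] - mean;
      var += d * d;
    }
    var /= n;  // biased estimate, used for the normalization itself
    const double inv_std = 1.0 / std::sqrt(var + eps);
    for (std::size_t k = 0; k < n; ++k) {
      const std::size_t i = k * channels + c;
      y[i] = (x[i] - mean) * inv_std * weight[c] + bias[c];
    }
  }
  return y;
}
```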

◆ options() [1/2]

torch::nn::functional::BatchNormFuncOptions & iganet::BatchNorm::options ( )
inline

Returns non-constant reference to options.

◆ options() [2/2]

const torch::nn::functional::BatchNormFuncOptions & iganet::BatchNorm::options ( ) const
inline

Returns constant reference to options.

◆ pretty_print()

virtual void iganet::BatchNorm::pretty_print ( std::ostream &  os = Log(log::info)) const
inline, override, virtual, noexcept

Prints a string representation of the activation function to the given output stream.

Implements iganet::ActivationFunction.

◆ read()

torch::serialize::InputArchive & iganet::BatchNorm::read ( torch::serialize::InputArchive &  archive,
const std::string &  key = "batch_norm" 
)
inline, override, virtual

Reads the activation function from a torch::serialize::InputArchive object.

Implements iganet::ActivationFunction.

◆ running_mean() [1/2]

torch::Tensor & iganet::BatchNorm::running_mean ( )
inline

Returns non-constant reference to running mean.

◆ running_mean() [2/2]

const torch::Tensor & iganet::BatchNorm::running_mean ( ) const
inline

Returns constant reference to running mean.

◆ running_var() [1/2]

torch::Tensor & iganet::BatchNorm::running_var ( )
inline

Returns non-constant reference to running variance.

◆ running_var() [2/2]

const torch::Tensor & iganet::BatchNorm::running_var ( ) const
inline

Returns constant reference to running variance.

◆ write()

torch::serialize::OutputArchive & iganet::BatchNorm::write ( torch::serialize::OutputArchive &  archive,
const std::string &  key = "batch_norm" 
) const
inline, override, virtual

Writes the activation function into a torch::serialize::OutputArchive object.

Implements iganet::ActivationFunction.
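Both read() and write() take a key that defaults to "batch_norm". One plausible reason for such a key is namespacing, so that several layers can be stored in one archive. The following dependency-free sketch illustrates that idea only: the Archive alias, BatchNormState, write_state, read_state, and the "<key>.<member>" naming scheme are all hypothetical stand-ins, and the entry names iganet actually writes into a torch::serialize::OutputArchive are not documented here.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for an archive: maps string keys to flat value arrays.
using Archive = std::map<std::string, std::vector<double>>;

// Hypothetical subset of the layer state that needs to be persisted.
struct BatchNormState {
  std::vector<double> running_mean;
  std::vector<double> running_var;
};

// Stores each member under "<key>.<member>", so layers written with
// different keys do not overwrite each other.
void write_state(Archive& archive, const BatchNormState& s,
                 const std::string& key = "batch_norm") {
  assert(s.running_mean.size() == s.running_var.size());  // one pair per channel
  archive[key + ".running_mean"] = s.running_mean;
  archive[key + ".running_var"] = s.running_var;
}

// Reads back the entries written by write_state; throws std::out_of_range
// if the key is missing.
BatchNormState read_state(const Archive& archive,
                          const std::string& key = "batch_norm") {
  return {archive.at(key + ".running_mean"), archive.at(key + ".running_var")};
}
```

Writing two layers with the keys "bn1" and "bn2" into the same archive yields four independent entries that can be read back separately.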

Member Data Documentation

◆ options_

torch::nn::functional::BatchNormFuncOptions iganet::BatchNorm::options_
private

◆ running_mean_

torch::Tensor iganet::BatchNorm::running_mean_
private

◆ running_var_

torch::Tensor iganet::BatchNorm::running_var_
private

The documentation for this class was generated from the following file: