Lightweight Neural Network++ documentation

Visit the lwnnplus home page, or download the technical report in English or Italian.


trainer Class Reference

Class used for easily training a network with the standard techniques.

#include <trainer.h>


Public Member Functions

 trainer (network *net, const string &error_filename="", const string &accuracy_filename="")
 Constructor by a net.
 trainer (int inputs, int outputs, int hidden, const string &error_filename="", const string &accuracy_filename="")
 Constructor by number of inputs and outputs.
 trainer (const string &training_file, int hidden=0, const string &error_filename="", const string &accuracy_filename="", iomanage *iomanager=NULL)
 Constructor by specification file (training file).
 trainer (const trainer &t)
 Copy Constructor.
virtual ~trainer ()
 Destructor.
int get_current_epoch ()
 Retrieve current epoch.
void start_at_epoch (int epoch)
 Start counting the epochs from epoch.
void clear ()
 Set the current epoch to zero and clear best error and best epoch.
int get_max_epochs () const
 Get the max number of epochs.
void set_max_epochs (int nepochs)
 Set the max number of epochs.
float get_min_error () const
 Get the minimum error.
void set_min_error (float error)
 Set the min error.
bool set_wanted_accuracy (float accuracy)
 Set the wanted accuracy on validation set.
float get_wanted_accuracy ()
 Get wanted accuracy.
void set_error_filename (const string &filename)
 Set the filename of errors log.
string get_error_filename () const
 Get the filename of errors log.
void set_accuracy_filename (const string &filename)
 Set the filename of accuracy log.
string get_accuracy_filename () const
 Get filename of accuracy log.
network * get_network () const
 Get the network in the trainer.
bool set_network (network *newnet)
 Use a new network in the trainer.
bool set_network_best ()
 Use in the trainer the best network stored.
void set_training_validation (bool train_valid)
 Use validation for computing the error.
bool get_training_validation () const
 Are we computing the error on validation set?
void set_epochs_checking_error (int epochs)
 Set the number of epochs between two consecutive checks on the validation set.
int get_epochs_checking_error () const
 Get the number of epochs between two consecutive checks on the validation set.
int get_epochs_report () const
 Get the number of epochs between two consecutive reports on the validation set.
void set_epochs_report (int e)
 Set the number of epochs between two consecutive reports on the validation set.
void set_stop_on_overfit (bool stop, int batches_increasing=0, float valid_higher=0.0)
 Make the learning stop on overfit.
bool get_stop_on_overfit () const
 Get stop on overfit mode.
float get_best_error () const
 Get error on validation set of the best net found.
int get_best_epoch () const
 Get the epoch when the best net was saved.
short int get_accuracy_mode () const
 Get the accuracy mode.
virtual void set_accuracy_mode (short int mode)
 Set the accuracy mode.
float get_accuracy_on_training () const
 Get accuracy on training set.
float get_accuracy_on_validation () const
 Get accuracy on validation set.
float get_accuracy_on_certification () const
 Get accuracy on certification set.
float get_error_on_training () const
 Get error on training set.
float get_error_on_validation () const
 Get error on validation set.
float get_error_on_certification () const
 Get error on certification set.
int get_no_of_inputs () const
 Retrieve the number of inputs for the trainer.
int get_no_of_outputs () const
 Retrieve the number of outputs for the trainer.
short int get_stopping_cause () const
 Retrieve the cause of stopping of last training.
virtual string get_stopping_cause_string () const
 Retrieve a string describing the cause of stopping of last training.
void load_training (const string &filename)
 Load Training data from a file.
void load_validation (const string &filename)
 Load validation data from a file.
void load_certification (const string &filename)
 Load certification data from a file.
bool using_best_net () const
 Are we saving the best net?
int train_online (bool verbose)
 Train network in on-line mode until it reaches min_error or the max number of epochs, or another stopping condition holds.
int train_batch (bool verbose=false)
 Train network in batch mode until it reaches min_error or the max number of epochs, or another stopping condition holds.
int train_ssab (bool verbose=false, bool reset_ssab=false)
 Train network in batch + Super SAB mode until it reaches min_error or the max number of epochs, or another stopping condition holds.
int train_shuffle (bool verbose=false)
 Train network in on-line mode with shuffle for presenting training set until it reaches min_error or the max number of epochs, or another stopping condition holds.
int train (int mode=0, bool verbose=false)
 Train the network in some mode until it reaches min_error or the max number of epochs, or another stopping condition holds.
bool test ()
 Test error and accuracy on the training set.
bool validate ()
 Test error and accuracy on the validation set.
bool certificate ()
 Test error and accuracy on the certification set.
float show_results (bool verbose=false, int set=1) const
 Show results on the selected set computing the results on each of the inputs and comparing the results with the target.
void set_iomanager (iomanage *iomanager)
 Set iomanager for the trainer.
iomanage * get_iomanager () const
 Get current iomanager. The iomanager is an object of a class derived from iomanage that manages your file type.
const trainer & operator= (const trainer &other)
 Overloaded operator=.

Static Public Attributes

const short int NO_ACCURACY = 0
 Constant meaning "no accuracy".
const short int BINARY_ACCURACY = 1
 Constant for computing accuracy as O_i > .5 <==> T_i > .5.
const short int MAXGUESS_ACCURACY = 2
 Constant for computing accuracy as maxind(O) == maxind(T).
const short int STOPPING_UNDEFINED_CAUSE = 0
 Constant meaning there is no stopping cause defined.
const short int STOPPING_MAX_EPOCHS = 1
 Training stopped because max number of epochs was reached.
const short int STOPPING_MIN_ERROR_REACHED_ON_TRAINING = 2
 Training stopped because min error was reached on training set.
const short int STOPPING_MIN_ERROR_REACHED_ON_BEST_ERROR = 3
 Training stopped because the best error on validation set reached min error (in stop_on_overfit mode).
const short int STOPPING_OVERFIT = 4
 Training stopped because of overfit conditions.
const short int STOPPING_MIN_ERROR_REACHED_ON_VALIDATION = 5
 Training stopped because min error was reached on validation set.
const short int STOPPING_WANTED_ACCURACY_REACHED = 6
 Training stopped because wanted accuracy was reached on validation set.
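The two built-in accuracy definitions can be read as per-pattern checks. The sketch below is not the library's code. The helper names binary_check, maxind and maxguess_check are hypothetical. It assumes BINARY_ACCURACY requires every output to fall on the same side of 0.5 as its target, and that ties in maxind go to the first index.

```cpp
#include <cstddef>

// Hypothetical helpers (not part of lwnn++): one possible reading of
// BINARY_ACCURACY and MAXGUESS_ACCURACY for a single pattern.

// BINARY_ACCURACY (assumed): every output must be on the same side of 0.5 as its target.
bool binary_check(const float *output, const float *target, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        if ((output[i] > 0.5f) != (target[i] > 0.5f))
            return false;  // this output is on the wrong side of the threshold
    }
    return true;
}

// Index of the largest element; the first one wins on ties. Requires n > 0.
std::size_t maxind(const float *v, std::size_t n) {
    std::size_t best = 0;
    for (std::size_t i = 1; i < n; ++i)
        if (v[i] > v[best])
            best = i;
    return best;
}

// MAXGUESS_ACCURACY: the largest output must be in the same position as the largest target.
bool maxguess_check(const float *output, const float *target, std::size_t n) {
    return maxind(output, n) == maxind(target, n);
}
```

The second definition suits one-of-n class encodings, where only the position of the strongest output matters.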

Protected Member Functions

virtual void print_input (float *input) const
 Print input data.
virtual void print_output (float *output) const
 Print output data.
virtual bool check (float *output, float *target) const
 Check if output is right according to target.
virtual bool continue_training (int start)
 Implementation of the stopping condition for the training.
void set_best_epoch (int epoch)
 Set the best epoch.
void set_best_error_on_validation (float error)
 Set the best error on validation set.
void set_stopping_cause (short int cause)
 Set the stopping cause.
network * get_best_net () const
 Get the best network.
void set_best_net (network *best)
 Set the best network.


Detailed Description

Class used for easily training a network with the standard techniques.

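The trainer uses a network and holds three different sets of input / target pairs:

  • the training set, used to train the network (see load_training() and test());
  • the validation set, used to check error and accuracy during training and to detect overfitting (see load_validation() and validate());
  • the certification set, used to test the trained network (see load_certification() and certificate()).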

The format of the files where input/output pairs are stored depends on the iomanager, an object of class iomanage that you can set with the method set_iomanager(). By default the trainer uses an iomanager of class iomanagelwnnfann.

It is possible to write a class which inherits from trainer and customizes some of its features. In particular, you can override the method check(), which implements the definition of accuracy, and the method continue_training(), which implements the stopping condition.


Constructor & Destructor Documentation

trainer::trainer (network *net, const string &error_filename = "", const string &accuracy_filename = "")

Constructor by a net.

Parameters:
net Pointer to the network object to be trained
error_filename (default = "" means no log file) Name of log file for errors
accuracy_filename (default = "" means no log file) Name of log file for accuracy
If net==NULL throws a runtime_error exception.

trainer::trainer (int inputs, int outputs, int hidden, const string &error_filename = "", const string &accuracy_filename = "")

Constructor by number of inputs and outputs.

Parameters:
inputs Number of inputs
outputs Number of outputs
hidden Number of neurons in the hidden layer
error_filename (default = "" means no log file) Name of log file for errors
accuracy_filename (default = "" means no log file) Name of log file for accuracy
Creates a network with one hidden layer, using the logistic function as activation, and uses it for training.

trainer::trainer (const string &training_file, int hidden = 0, const string &error_filename = "", const string &accuracy_filename = "", iomanage *iomanager = NULL)

Constructor by specification file (training file).

Parameters:
training_file Name of the training file, from which the specifications are read
hidden (default = 0) Number of neurons in the hidden layer
error_filename (default = "" means no log file) Name of log file for errors
accuracy_filename (default = "" means no log file) Name of log file for accuracy
iomanager The iomanager to be used for reading. If NULL (default), an iomanager of class iomanagelwnnfann is used.
The number of inputs and outputs is written in the specifications. If hidden == 0, the number of hidden neurons is set to the number of inputs.

This constructor also constructs a network with 3 layers and with the right number of inputs and outputs.

trainer::trainer (const trainer &t)

Copy Constructor.

Parameters:
t The trainer you want to copy

virtual trainer::~trainer () [virtual]

Destructor.

When a trainer object is destroyed, the network is kept. You can therefore train a network inside a trainer, then delete the trainer and keep the network. If you want to delete both the trainer and the network, retrieve the network with get_network() first, then delete both the trainer and the network.


Member Function Documentation

int trainer::get_current_epoch () [inline]

Retrieve current epoch.

Returns:
number of times the network was trained on the training set
The current epoch can be reset to zero with clear(). It is also reset to zero when load_training() is called.

void trainer::start_at_epoch (int epoch) [inline]

Start counting the epochs from epoch.

Parameters:
epoch epoch number

void trainer::clear ()

Set the current epoch to zero and clear best error and best epoch.

This method also removes log files.

int trainer::get_max_epochs () const [inline]

Get the max number of epochs.

Returns:
max number of epochs

void trainer::set_max_epochs (int nepochs)

Set the max number of epochs.

Parameters:
nepochs Max number of epochs
If a train* method is called more than once, the epoch counter is not restarted from 0, but the max number of epochs applies to each training session separately. For example, you can set max epochs to 1000 and train the network for 1000 epochs. You can then call a train* method again and train the net for another 1000 epochs (from epoch 1001 to epoch 2000), unless another stopping condition holds first.
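The interaction between the epoch counter and max_epochs can be modelled in a few lines. This is a hypothetical model, not lwnn++ code. The struct epoch_counter and its members are illustrative names. Only the max-epochs condition is simulated, and it is measured from the epoch at which each session starts (the role that the start parameter of continue_training() appears to play).

```cpp
// Hypothetical model (not lwnn++ code) of the epoch counter across sessions:
// the counter keeps growing, while the max_epochs limit is measured from the
// epoch at which each session started.
struct epoch_counter {
    int current_epoch = 0;  // total epochs trained so far, never reset between sessions
    int max_epochs = 1000;  // limit applied to each session separately

    // Runs one session and returns the number of epochs trained in it.
    // Only the max-epochs stopping condition is modelled here.
    int train_session() {
        const int start = current_epoch;  // epoch count when this session begins
        while (current_epoch - start < max_epochs)
            ++current_epoch;  // one pass over the training set would happen here
        return current_epoch - start;
    }
};
```

Calling train_session() twice with max_epochs = 1000 returns 1000 both times and leaves current_epoch at 2000. This matches the example in the text.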

float trainer::get_min_error () const [inline]

Get the minimum error.

Returns:
Min error

void trainer::set_min_error (float error) [inline]

Set the min error.

Parameters:
error Min error
The min error can be evaluated on the training set or on the validation set (see set_training_validation()). When this error is reached, training stops.

bool trainer::set_wanted_accuracy (float accuracy)

Set the wanted accuracy on validation set.

Parameters:
accuracy (between 0 and 1). Values outside this interval disable wanted accuracy
Returns:
true if the value was set successfully
Training stops when this accuracy is reached. Accuracy is relative to the validation set, so you have to call load_validation() before.

The accuracy mode should also be set before (see set_accuracy_mode()).
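The wanted-accuracy rules can be summarised in a small struct. This is a hypothetical sketch, not the library's implementation. wanted_accuracy_rule is an illustrative name. It assumes that training stops as soon as the accuracy on the validation set is greater than or equal to the wanted value, and that -1 marks "not set", as get_wanted_accuracy() reports.

```cpp
// Hypothetical sketch (not lwnn++ code) of the wanted-accuracy rules:
// values outside [0, 1] disable the check, which is then reported as -1.
struct wanted_accuracy_rule {
    float wanted = -1.0f;  // -1 means "not set"

    void set(float accuracy) {
        wanted = (accuracy >= 0.0f && accuracy <= 1.0f) ? accuracy : -1.0f;
    }

    // Assumed stopping test: true once validation accuracy reaches the target.
    bool reached(float accuracy_on_validation) const {
        return wanted >= 0.0f && accuracy_on_validation >= wanted;
    }
};
```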

float trainer::get_wanted_accuracy () [inline]

Get wanted accuracy.

Returns:
accuracy between 0 and 1, or -1 if it is not set.
Training stops when this accuracy is reached. Accuracy is relative to the validation set.

void trainer::set_error_filename (const string &filename)

Set the filename of errors log.

Parameters:
filename string
Use filename="" to disable error logging

string trainer::get_error_filename () const [inline]

Get the filename of errors log.

Returns:
string

void trainer::set_accuracy_filename (const string &filename)

Set the filename of accuracy log.

Parameters:
filename string
Use filename="" to disable accuracy logging

string trainer::get_accuracy_filename () const [inline]

Get filename of accuracy log.

Returns:
string

network * trainer::get_network () const [inline]

Get the network in the trainer.

Returns:
Network

bool trainer::set_network (network *newnet)

Use a new network in the trainer.

Parameters:
newnet Pointer to a neural network
Returns:
true if newnet has the right number of inputs and outputs and has been set, false otherwise (the network is not changed)
If some training was done before, this method updates the values of errors and accuracy for the new network.

bool trainer::set_network_best ()

Use in the trainer the best network stored.

Returns:
true if operation had success
This method also deletes the old network, sets the current epoch to best epoch and error on validation set to best error.

Warning: old network is deleted after calling set_network_best() so you will need get_network() to retrieve the network!

You can't get direct access to the stored best network because of memory management. So if you need to work on both the overfitted network and the best network, call get_network() before calling this method and make a copy of that network:

network *copy = new network(*t->get_network()); // t points to your trainer; uses the network copy constructor

then call set_network_best(), and get_network() will return the best network.

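Warning 2: Before calling set_network_best(), check that using_best_net() returns true (otherwise no best network is stored). If you still need the current network, copy it as shown above.

After calling set_network_best(), if it returns true, get_network() returns the best network, the current epoch is the best epoch and the error on the validation set is the best error.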

void trainer::set_training_validation (bool train_valid)

Use validation for computing the error.

Parameters:
train_valid If true, use the validation set for computing the error; if false, use the training set
The validation set is checked every fixed number of epochs (see set_epochs_checking_error()).

bool trainer::get_training_validation () const [inline]

Are we computing the error on validation set?

Returns:
if true the trainer is using the validation set to compute the error in the training. If false it is using the training set.
If true the validation set is checked every get_epochs_checking_error() epochs.

void trainer::set_epochs_checking_error (int epochs)

Set the number of epochs between two consecutive checks on the validation set.

This value is used if get_training_validation() == true

int trainer::get_epochs_checking_error () const [inline]

Get the number of epochs between two consecutive checks on the validation set.

This value is used if get_training_validation() == true

int trainer::get_epochs_report () const [inline]

Get the number of epochs between two consecutive reports on the validation set.

This is the interval between two reports in verbose mode

Returns:
number of epochs

void trainer::set_epochs_report (int e)

Set the number of epochs between two consecutive reports on the validation set.

Parameters:
e number of epochs
This is the interval between two reports in verbose mode

void trainer::set_stop_on_overfit (bool stop, int batches_increasing = 0, float valid_higher = 0.0)

Make the learning stop on overfit.

Parameters:
stop If true the trainer stops the training on overfit i.e. when error on validation set starts to increase. If false, normal conditions are checked
batches_increasing (only meaningful if stop == true) (optional) sets the number of batches the error must be increasing to detect overfit
valid_higher (only meaningful if stop == true) (optional) sets how much the error on validation set must be higher than the error on training set as a stopping condition for overfitting
The best net is the net with the minimum error on the validation set. If stop_on_overfit is set to true, the two following conditions must hold to stop training before max_epochs:

  • for at least batches_increasing epochs, the error on the validation set keeps increasing;
  • the error on the validation set is at least valid_higher times greater than the error on the training set.

After the training you will have to call set_network_best() to use the best network found.
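A minimal sketch of the two overfit conditions is shown below. It assumes that validation errors are recorded once per check, and that "valid_higher times greater" means a plain multiplier (validation error >= valid_higher * training error). The function overfit_detected is hypothetical. It treats batches_increasing < 1 as "never detect", which may not match how the library handles the default value 0.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical sketch (not lwnn++ code) of the two overfit conditions.
// valid_errors holds the validation error measured at each check, oldest first.
bool overfit_detected(const std::vector<float> &valid_errors, float train_error,
                      int batches_increasing, float valid_higher) {
    if (batches_increasing < 1)
        return false;  // assumption: a non-positive window never signals overfit
    const std::size_t b = static_cast<std::size_t>(batches_increasing);
    const std::size_t n = valid_errors.size();
    if (n < b + 1)
        return false;  // not enough checks recorded yet
    // Condition 1: the validation error rose at each of the last b checks.
    for (std::size_t i = n - b; i < n; ++i)
        if (valid_errors[i] <= valid_errors[i - 1])
            return false;
    // Condition 2: the latest validation error is at least valid_higher times the training error.
    return valid_errors.back() >= valid_higher * train_error;
}
```

For example, with validation errors 0.30, 0.20, 0.22, 0.25, 0.27, a training error of 0.10, batches_increasing = 3 and valid_higher = 2, both conditions hold. With batches_increasing = 4, the first condition fails, because the error dropped from 0.30 to 0.20.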

bool trainer::get_stop_on_overfit () const [inline]

Get stop on overfit mode.

Returns:
If true the trainer stops the training on overfit i.e. when error on validation set increases (see set_stop_on_overfit()). If false, normal conditions are checked

float trainer::get_best_error () const [inline]

Get error on validation set of the best net found.

Returns:
best error on validation set. -1 if not using_best_net()

int trainer::get_best_epoch () const [inline]

Get the epoch when the best net was saved.

Returns:
epoch number

short int trainer::get_accuracy_mode () const [inline]

Get the accuracy mode.

Returns:
mode

virtual void trainer::set_accuracy_mode (short int mode) [virtual]

Set the accuracy mode.

Parameters:
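mode How to compute accuracy: trainer::NO_ACCURACY, trainer::BINARY_ACCURACY or trainer::MAXGUESS_ACCURACY

The default accuracy mode is 1 (trainer::BINARY_ACCURACY).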

This method is virtual because you might need to rewrite it to give a different meaning to the accuracy mode (and to allow more than 3 modes).

float trainer::get_accuracy_on_training () const [inline]

Get accuracy on training set.

test() should have been called before

float trainer::get_accuracy_on_validation () const [inline]

Get accuracy on validation set.

validate() should have been called before

float trainer::get_accuracy_on_certification () const [inline]

Get accuracy on certification set.

certificate() should have been called before

float trainer::get_error_on_training () const [inline]

Get error on training set.

test() should have been called before

float trainer::get_error_on_validation () const [inline]

Get error on validation set.

validate() should have been called before

float trainer::get_error_on_certification () const [inline]

Get error on certification set.

certificate() should have been called before

int trainer::get_no_of_inputs () const [inline]

Retrieve the number of inputs for the trainer.

Returns:
This must be the number of inputs of the network being trained and of every set of patterns.

int trainer::get_no_of_outputs () const [inline]

Retrieve the number of outputs for the trainer.

Returns:
This must be the number of outputs of the network being trained and of every set of patterns.

short int trainer::get_stopping_cause () const [inline]

Retrieve the cause of stopping of last training.

Returns:
code of the cause
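Can be one of:

  • trainer::STOPPING_UNDEFINED_CAUSE
  • trainer::STOPPING_MAX_EPOCHS
  • trainer::STOPPING_MIN_ERROR_REACHED_ON_TRAINING
  • trainer::STOPPING_MIN_ERROR_REACHED_ON_BEST_ERROR
  • trainer::STOPPING_OVERFIT
  • trainer::STOPPING_MIN_ERROR_REACHED_ON_VALIDATION
  • trainer::STOPPING_WANTED_ACCURACY_REACHED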

virtual string trainer::get_stopping_cause_string () const [virtual]

Retrieve a string describing the cause of stopping of last training.

Returns:
string description
This method is declared virtual, so you can override it if you have different stopping causes in a derived class which extends trainer.
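A derived class that adds its own stopping causes would also override get_stopping_cause_string(). The sketch below shows one possible mapping. The function names and message texts are hypothetical. The numeric codes are those of the STOPPING_* constants listed in the class reference. Code 7 is an invented example of an extra cause.

```cpp
#include <string>

// Hypothetical mapping from stopping-cause codes to descriptions.
// The codes match the STOPPING_* constants; the wording is illustrative.
std::string stopping_cause_string(short int cause) {
    switch (cause) {
    case 1: return "max number of epochs reached";                        // STOPPING_MAX_EPOCHS
    case 2: return "min error reached on training set";                   // STOPPING_MIN_ERROR_REACHED_ON_TRAINING
    case 3: return "min error reached on best error on validation set";   // STOPPING_MIN_ERROR_REACHED_ON_BEST_ERROR
    case 4: return "overfit detected";                                    // STOPPING_OVERFIT
    case 5: return "min error reached on validation set";                 // STOPPING_MIN_ERROR_REACHED_ON_VALIDATION
    case 6: return "wanted accuracy reached on validation set";           // STOPPING_WANTED_ACCURACY_REACHED
    default: return "undefined cause";  // STOPPING_UNDEFINED_CAUSE (0) or unknown codes
    }
}

// A derived class could handle its own extra code (7, hypothetical) and defer to the base mapping.
std::string extended_stopping_cause_string(short int cause) {
    if (cause == 7)
        return "no improvement for too many epochs";
    return stopping_cause_string(cause);
}
```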

void trainer::load_training (const string &filename)

Load training data from a file.

Parameters:
filename Filename to read from
The current epoch is set to 0 if the training set is loaded successfully.

bool trainer::using_best_net () const [inline]

Are we saving the best net?

Returns:
true if best net is used, false otherwise

int trainer::train_online (bool verbose)

Train network in on-line mode until it reaches min_error or the max number of epochs, or another stopping condition holds.

Parameters:
verbose If true writes errors (and accuracy, if an accuracy_mode is set) to stdout
Returns:
number of epochs of training in this session (this differs from get_current_epoch() if it is not the first training session)
After the training use get_stopping_cause() to know why training session ended

If current network is NULL throws a runtime_error exception

int trainer::train_batch (bool verbose = false)

Train network in batch mode until it reaches min_error or the max number of epochs, or another stopping condition holds.

Parameters:
verbose If true writes errors (and accuracy, if an accuracy_mode is set) to stdout
Returns:
number of epochs of training in this session (this differs from get_current_epoch() if it is not the first training session)
After the training use get_stopping_cause() to know why training session ended

If current network is NULL throws a runtime_error exception
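In the usual sense of these terms, on-line mode updates the weights after each pattern. Batch mode sums the updates over the whole training set and applies them once per epoch. The generic sketch below shows this on a single linear unit with the delta rule. It is not how lwnn++ implements its multilayer training. The names pattern, online_epoch and batch_epoch are hypothetical.

```cpp
#include <cstddef>
#include <vector>

// Generic illustration (not lwnn++ code) of on-line vs batch training,
// using a single linear unit y = w * x and the delta rule.
struct pattern { float x; float t; };

// On-line: the weight is updated after every pattern.
float online_epoch(float w, const std::vector<pattern> &set, float lr) {
    for (const pattern &p : set)
        w += lr * (p.t - w * p.x) * p.x;  // each update sees the weight changed by the previous one
    return w;
}

// Batch: the updates are accumulated over the whole set and applied once per epoch.
float batch_epoch(float w, const std::vector<pattern> &set, float lr) {
    float delta = 0.0f;
    for (const pattern &p : set)
        delta += lr * (p.t - w * p.x) * p.x;  // every term uses the same w
    return w + delta;
}
```

Take the patterns {x = 1, t = 2} and {x = 2, t = 4}, a learning rate of 0.1 and w = 0. One on-line epoch gives w = 0.92, while one batch epoch gives w = 1.0, because the second on-line update already sees the updated weight.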

int trainer::train_ssab (bool verbose = false, bool reset_ssab = false)

Train network in batch + Super SAB mode until it reaches min_error or the max number of epochs, or another stopping condition holds.

Parameters:
verbose If true writes errors (and accuracy, if an accuracy_mode is set) to stdout
reset_ssab (default = false) If set to true reset the learning rates of the last SuperSAB training to the global learning rate value
Returns:
number of epochs of training in this session (this differs from get_current_epoch() if it is not the first training session)
After the training use get_stopping_cause() to know why training session ended

If current network is NULL throws a runtime_error exception
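Super SAB gives each weight its own learning rate. The rate grows while the gradient keeps its sign and shrinks when the sign changes. reset_ssab puts all those rates back to the global learning rate. The sketch below is a simplified, generic version of that idea, not the lwnn++ code. The struct supersab_rates is hypothetical, and the factors 1.05 and 0.5 are assumed values. The momentum term and step undo of the original algorithm are left out.

```cpp
#include <cstddef>
#include <vector>

// Simplified per-weight learning-rate adaptation in the spirit of SuperSAB
// (generic sketch, not lwnn++ code; momentum and step undo are omitted).
struct supersab_rates {
    std::vector<float> rate;       // one learning rate per weight
    std::vector<float> prev_grad;  // gradient seen at the previous epoch
    float eta_up = 1.05f;          // growth factor when the gradient keeps its sign (assumed)
    float eta_down = 0.5f;         // shrink factor when the gradient changes sign (assumed)

    supersab_rates(std::size_t n, float global_rate)
        : rate(n, global_rate), prev_grad(n, 0.0f) {}

    // Analogue of reset_ssab == true: every rate goes back to the global learning rate.
    void reset(float global_rate) {
        for (float &r : rate) r = global_rate;
        for (float &g : prev_grad) g = 0.0f;
    }

    // Applies one batch update: w[i] -= rate[i] * grad[i].
    void step(std::vector<float> &w, const std::vector<float> &grad) {
        for (std::size_t i = 0; i < w.size(); ++i) {
            const float s = grad[i] * prev_grad[i];
            if (s > 0.0f)
                rate[i] *= eta_up;    // same sign as last time: accelerate
            else if (s < 0.0f)
                rate[i] *= eta_down;  // sign flip: the last step overshot, slow down
            w[i] -= rate[i] * grad[i];
            prev_grad[i] = grad[i];
        }
    }
};
```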

int trainer::train_shuffle (bool verbose = false)

Train network in on-line mode with shuffle for presenting training set until it reaches min_error or the max number of epochs, or another stopping condition holds.

Parameters:
verbose If true writes errors (and accuracy, if an accuracy_mode is set) to stdout
Returns:
number of epochs of training in this session (this differs from get_current_epoch() if it is not the first training session)
After the training use get_stopping_cause() to know why training session ended

If current network is NULL throws a runtime_error exception
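On-line training with shuffle presents the training patterns in random order, presumably with a new order at each epoch. A generic way to build that order in standard C++ is to shuffle a vector of indices. The function shuffled_order and the use of std::mt19937 are assumptions for illustration, not the lwnn++ implementation.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Generic sketch (not lwnn++ code): a random presentation order of the
// training patterns for one epoch of on-line training with shuffle.
std::vector<std::size_t> shuffled_order(std::size_t n_patterns, std::mt19937 &rng) {
    std::vector<std::size_t> order(n_patterns);
    std::iota(order.begin(), order.end(), std::size_t{0});  // 0, 1, ..., n-1
    std::shuffle(order.begin(), order.end(), rng);          // uniform random permutation
    return order;
}
```

Inside an epoch loop, you would call shuffled_order(n, rng) once per epoch and train on the patterns in the returned order. Reusing one generator keeps the sequence of permutations reproducible for a fixed seed.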

int trainer::train (int mode = 0, bool verbose = false)

Train the network in some mode until it reaches min_error or the max number of epochs, or another stopping condition holds.

Parameters:
mode 0 = online 1 = batch 2 = batch+ssab 3 = online with shuffle
verbose If true writes errors (and accuracy, if an accuracy_mode is set) to file
Returns:
number of epochs of training in this session (this differs from get_current_epoch() if it is not the first training session)
If current network is NULL throws a runtime_error exception

bool trainer::test ()

Test error and accuracy on the training set.

Returns:
true if the training set was previously loaded, false if not and could not compute the values
If current network is NULL throws a runtime_error exception

bool trainer::validate ()

Test error and accuracy on the validation set.

Returns:
true if the validation set was previously loaded, false if not and could not compute the values
If current network is NULL throws a runtime_error exception

bool trainer::certificate ()

Test error and accuracy on the certification set.

Returns:
true if the certification set was previously loaded, false if not and could not compute the values
If current network is NULL throws a runtime_error exception

float trainer::show_results (bool verbose = false, int set = 1) const

Show results on the selected set computing the results on each of the inputs and comparing the results with the target.

Parameters:
verbose If true shows all input/output pairs and targets If false shows only error and accuracy
set Which set: 0 = training, 1 = validation (default), 2 = certification
Returns:
mean error
If current accuracy_mode == trainer::BINARY_ACCURACY or trainer::MAXGUESS_ACCURACY prints the number of right and wrong answers

If current network is NULL throws a runtime_error exception

void trainer::set_iomanager (iomanage *iomanager)

Set iomanager for the trainer.

Parameters:
iomanager Pointer to an object of an iomanage-derived class
The iomanager is an object of a class derived from iomanage that manages your file type.

By default the trainer uses an iomanager of type iomanagelwnnfann.

If iomanager == NULL throws a runtime_error exception.

iomanage * trainer::get_iomanager () const [inline]

Get current iomanager. The iomanager is an object of a class derived from iomanage that manages your file type.

By default the trainer uses an iomanager of type iomanagelwnnfann.

const trainer& trainer::operator= (const trainer &other)

Overloaded operator=.

Parameters:
other other trainer to be copied

virtual void trainer::print_input (float *input) const [protected, virtual]

Print input data.

Parameters:
input vector of input
A way for printing input data. This class provides a simple format as vector of float. You might want to make this method do something nicer by deriving a class from trainer.

virtual void trainer::print_output (float *output) const [protected, virtual]

Print output data.

Parameters:
output vector of output
A way for printing output data. This class provides a simple format for printing as vector of float. You might want to make this method do something nicer by deriving a class from trainer.

virtual bool trainer::check (float *output, float *target) const [protected, virtual]

Check if output is right according to target.

Parameters:
output The output of the net
target The right target
Returns:
true if output should be counted as right for accuracy, false otherwise.
This method implements the definition of accuracy. In standard implementation it depends on accuracy_mode (see set_accuracy_mode()) but you can rewrite it in a derived class of trainer in order to customize the definition of accuracy.

For example, you could define an output to be right if the norm of the difference between output and target is less than a fixed threshold. Or, for a network which uses tanh as activation function, you could use -1 and 1 instead of 0 and 1 for binary code.
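The two custom definitions suggested above can be written as plain functions. In a class derived from trainer, their logic would go into an override of check(float *output, float *target) const, with the number of outputs taken from get_no_of_outputs(). Here n is passed explicitly, so the sketch compiles on its own. The names norm_check and tanh_binary_check and the threshold parameter are hypothetical.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical custom accuracy rules (not lwnn++ code).

// Right if the Euclidean norm of (output - target) is below a threshold.
bool norm_check(const float *output, const float *target, std::size_t n, float threshold) {
    float sum = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        const float d = output[i] - target[i];
        sum += d * d;
    }
    return std::sqrt(sum) < threshold;
}

// Binary code with -1 / 1 targets (e.g. a tanh output layer): threshold at 0 instead of 0.5.
bool tanh_binary_check(const float *output, const float *target, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i)
        if ((output[i] > 0.0f) != (target[i] > 0.0f))
            return false;  // this output has the wrong sign
    return true;
}
```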

virtual bool trainer::continue_training (int start) [protected, virtual]

Implementation of the stopping condition for the training.

Parameters:
start the starting epoch of the current training session
Returns:
true if training must continue, false if training must stop.
This method implements the stopping condition for all the train_* methods. If you want a different stopping condition you can write a derived class from trainer and rewrite this method.

The stopping condition has a lot of parameters in the standard implementation: be sure that you really need a different condition before starting to rewrite this method!

This is the most critical point in network training so... be careful! :)
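As an example of a different stopping condition, the hypothetical struct below implements a "patience" rule. It stops when the validation error has not improved by at least min_delta for patience consecutive checks, or when max_epochs epochs have elapsed in the session. It is a sketch, not the library's default condition. Unlike the real continue_training(int start), it receives the current epoch and validation error as arguments and assumes it is called once per validation check. In a derived class you would also record the reason with set_stopping_cause().

```cpp
// Hypothetical alternative stopping rule (not the lwnn++ default).
struct patience_rule {
    int max_epochs;                      // per-session epoch limit
    int patience;                        // checks allowed without improvement
    float min_delta;                     // minimum decrease that counts as improvement
    float best_error = -1.0f;            // -1 means "no check recorded yet"
    int checks_without_improvement = 0;

    patience_rule(int max_ep, int pat, float delta)
        : max_epochs(max_ep), patience(pat), min_delta(delta) {}

    // Mirrors continue_training(start): returns true if training must continue.
    bool continue_training(int start, int current_epoch, float valid_error) {
        if (current_epoch - start >= max_epochs)
            return false;  // session limit reached
        if (best_error < 0.0f || valid_error < best_error - min_delta) {
            best_error = valid_error;  // meaningful improvement: remember it
            checks_without_improvement = 0;
        } else {
            ++checks_without_improvement;
        }
        return checks_without_improvement < patience;
    }
};
```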

void trainer::set_best_epoch (int epoch) [inline, protected]

Set the best epoch.

Parameters:
epoch number of epoch
This method is declared protected because you might need it in the implementation of continue_training()

void trainer::set_best_error_on_validation (float error) [inline, protected]

Set the best error on validation set.

Parameters:
error best error on validation set
This method is declared protected because you might need it in the implementation of continue_training()

void trainer::set_stopping_cause (short int cause) [inline, protected]

Set the stopping cause.

Parameters:
cause Code of stopping cause
This method is declared protected because you might need it in the implementation of continue_training()

network * trainer::get_best_net () const [inline, protected]

Get the best network.

Returns:
Pointer to the best network saved. NULL if not using_best_net()
This method is declared protected because you might need it in the implementation of continue_training()

void trainer::set_best_net (network *best) [inline, protected]

Set the best network.

Parameters:
best Pointer to the best network
This method is declared protected because you might need it in the implementation of continue_training()


The documentation for this class was generated from the following file: trainer.h
Generated on Tue Oct 12 00:32:12 2004 for Lightweight Neural Network ++ by  doxygen 1.3.9