Visit the lwnnplus home page, download the technical report in English or Italian.
#include <trainer.h>
Public Member Functions | |
trainer (network *net, const string &error_filename="", const string &accuracy_filename="") | |
Constructor by a net. | |
trainer (int inputs, int outputs, int hidden, const string &error_filename="", const string &accuracy_filename="") | |
Constructor by number of inputs and outputs. | |
trainer (const string &training_file, int hidden=0, const string &error_filename="", const string &accuracy_filename="", iomanage *iomanager=NULL) | |
Constructor by specification file (training file). | |
trainer (const trainer &t) | |
Copy Constructor. | |
virtual | ~trainer () |
Destructor. | |
int | get_current_epoch () |
Retrieve current epoch. | |
void | start_at_epoch (int epoch) |
Start counting the epochs from epoch. | |
void | clear () |
Set the current epoch at zero and clear best error and best epoch. | |
int | get_max_epochs () const |
Get the max number of epochs. | |
void | set_max_epochs (int nepochs) |
Set the max number of epochs. | |
float | get_min_error () const |
Get the minimum error. | |
void | set_min_error (float error) |
Set the min error. | |
bool | set_wanted_accuracy (float accuracy) |
Set the wanted accuracy on validation set. | |
float | get_wanted_accuracy () |
Get wanted accuracy. | |
void | set_error_filename (const string &filename) |
Set the filename of errors log. | |
string | get_error_filename () const |
Get the filename of errors log. | |
void | set_accuracy_filename (const string &filename) |
Set the filename of accuracy log. | |
string | get_accuracy_filename () const |
Get filename of accuracy log. | |
network * | get_network () const |
Get the network in the trainer. | |
bool | set_network (network *newnet) |
Use a new network in the trainer. | |
bool | set_network_best () |
Use in the trainer the best network stored. | |
void | set_training_validation (bool train_valid) |
Use validation for computing the error. | |
bool | get_training_validation () const |
Are we computing the error on validation set? | |
void | set_epochs_checking_error (int epochs) |
Set the number of epochs between a couple of validations on validation set. | |
int | get_epochs_checking_error () const |
Get the number of epochs between a couple of validations on validation set. | |
int | get_epochs_report () const |
Get the number of epochs between a couple of reports on validation set. | |
void | set_epochs_report (int e) |
Set the number of epochs between a couple of reports on validation set. | |
void | set_stop_on_overfit (bool stop, int batches_increasing=0, float valid_higher=0.0) |
Make the learning stop on overfit. | |
bool | get_stop_on_overfit () const |
Get stop on overfit mode. | |
float | get_best_error () const |
Get error on validation set of the best net found. | |
int | get_best_epoch () const |
Get the epoch when the best net was saved. | |
short int | get_accuracy_mode () const |
Get the accuracy mode. | |
virtual void | set_accuracy_mode (short int mode) |
Set the accuracy mode. | |
float | get_accuracy_on_training () const |
Get accuracy on training set. | |
float | get_accuracy_on_validation () const |
Get accuracy on validation set. | |
float | get_accuracy_on_certification () const |
Get accuracy on certification set. | |
float | get_error_on_training () const |
Get error on training set. | |
float | get_error_on_validation () const |
Get error on validation set. | |
float | get_error_on_certification () const |
Get error on certification set. | |
int | get_no_of_inputs () const |
Retrieve the number of inputs for the trainer. | |
int | get_no_of_outputs () const |
Retrieve the number of outputs for the trainer. | |
short int | get_stopping_cause () const |
Retrieve the cause of stopping of last training. | |
virtual string | get_stopping_cause_string () const |
Retrieve a string describing the cause of stopping of last training. | |
void | load_training (const string &filename) |
Load Training data from a file. | |
void | load_validation (const string &filename) |
Load validation data from a file. | |
void | load_certification (const string &filename) |
Load certification data from a file. | |
bool | using_best_net () const |
Are we saving the best net? | |
int | train_online (bool verbose) |
Train network in on-line mode until it reaches min_error or the max number of epochs, or another stopping condition holds. | |
int | train_batch (bool verbose=false) |
Train network in batch mode until it reaches min_error or the max number of epochs, or another stopping condition holds. | |
int | train_ssab (bool verbose=false, bool reset_ssab=false) |
Train network in batch + Super SAB mode until it reaches min_error or the max number of epochs, or another stopping condition holds. | |
int | train_shuffle (bool verbose=false) |
Train network in on-line mode with shuffle for presenting training set until it reaches min_error or the max number of epochs, or another stopping condition holds. | |
int | train (int mode=0, bool verbose=false) |
Train the network in some mode until it reaches min_error or the max number of epochs, or another stopping condition holds. | |
bool | test () |
Test error and accuracy on the training set. | |
bool | validate () |
Test error and accuracy on the validation set. | |
bool | certificate () |
Test error and accuracy on the certification set. | |
float | show_results (bool verbose=false, int set=1) const |
Show results on the selected set computing the results on each of the inputs and comparing the results with the target. | |
void | set_iomanager (iomanage *iomanager) |
Set iomanager for the trainer. | |
iomanage * | get_iomanager () const |
Get current iomanager. The iomanager is an object of a class derived from iomanage that manages your file type. | |
const trainer & | operator= (const trainer &other) |
Overloaded operator=. | |
Static Public Attributes | |
const short int | NO_ACCURACY = 0 |
Constant meaning "no accuracy". | |
const short int | BINARY_ACCURACY = 1 |
Constant for computing accuracy as O_i > .5 <==> T_i > .5. | |
const short int | MAXGUESS_ACCURACY = 2 |
Constant for computing accuracy as maxind(O) == maxind(T). | |
const short int | STOPPING_UNDEFINED_CAUSE = 0 |
Constant meaning there is no stopping cause defined. | |
const short int | STOPPING_MAX_EPOCHS = 1 |
Training stopped because max number of epochs was reached. | |
const short int | STOPPING_MIN_ERROR_REACHED_ON_TRAINING = 2 |
Training stopped because min error was reached on training set. | |
const short int | STOPPING_MIN_ERROR_REACHED_ON_BEST_ERROR = 3 |
Training stopped because min error was reached on best error on validation set (on stop_on_overfit mode). | |
const short int | STOPPING_OVERFIT = 4 |
Training stopped because of overfit conditions. | |
const short int | STOPPING_MIN_ERROR_REACHED_ON_VALIDATION = 5 |
Training stopped because min error was reached on validation set. | |
const short int | STOPPING_WANTED_ACCURACY_REACHED = 6 |
Training stopped because wanted accuracy was reached on validation set. | |
Protected Member Functions | |
virtual void | print_input (float *input) const |
Print input data. | |
virtual void | print_output (float *output) const |
Print output data. | |
virtual bool | check (float *output, float *target) const |
Check if output is right according to target. | |
virtual bool | continue_training (int start) |
Implementation of the stopping condition for the training. | |
void | set_best_epoch (int epoch) |
Set the best epoch. | |
void | set_best_error_on_validation (float error) |
Set the best error on validation set. | |
void | set_stopping_cause (short int cause) |
Set the stopping cause. | |
network * | get_best_net () const |
Get the best network. | |
void | set_best_net (network *best) |
Set the best network. |
The trainer uses a network and holds three different sets of input / target pairs:
The format of the files where input/output pairs are stored depends on the iomanager, an object of class iomanage you can set by the method set_iomanager(). By default the trainer uses an iomanager of class iomanagelwnnfann.
It is possible to write a class which inherits from trainer and which customizes some of its features. In particular it is possible to rewrite the code of method check() which implements the definition of accuracy, and of the method continue_training(), which implements the stopping condition.
|
Constructor by a net.
|
|
Constructor by number of inputs and outputs.
|
|
Constructor by specification file (training file).
This constructor also constructs a network with 3 layers, and with the right number of inputs and outputs. |
|
Copy Constructor.
|
|
Destructor. When a trainer object is destroyed, the network is kept, so you can train a network inside a trainer, then delete the trainer and keep the network. If you want to delete the trainer and the network, be sure to get the network by the get_network() method and destroy both the net and the trainer. |
|
Retrieve current epoch.
|
|
Start counting the epochs from epoch.
|
|
Set the current epoch at zero and clear best error and best epoch. This method also removes log files. |
|
Get the max number of epochs.
|
|
Set the max number of epochs.
|
|
Get the minimum error.
|
|
Set the min error.
|
|
Set the wanted accuracy on validation set.
Accuracy mode should also be set beforehand. |
|
Get wanted accuracy.
|
|
Set the filename of errors log.
|
|
Get the filename of errors log.
|
|
Set the filename of accuracy log.
|
|
Get filename of accuracy log.
|
|
Get the network in the trainer.
|
|
Use a new network in the trainer.
|
|
Use in the trainer the best network stored.
Warning: the old network is deleted after calling set_network_best(), so you will need get_network() to retrieve the network! You can't get direct access to the stored best network because of memory management. So if you need to work on both the network with overfit and the best network, call get_network() before calling this method and make a copy of that network: network* copy = new network(trainer->get_network()); then call set_network_best(), and get_network() will return the best network. Warning 2: Before calling set_network_best():
After calling set_network_best(), if it returns true:
|
|
Use validation for computing the error.
|
|
Are we computing the error on validation set?
|
|
Set the number of epochs between a couple of validations on validation set. This value is used if get_training_validation() == true |
|
Get the number of epochs between a couple of validations on validation set. This value is used if get_training_validation() == true |
|
Get the number of epochs between a couple of reports on validation set. This is the interval between two reports in verbose mode.
|
|
Set the number of epochs between a couple of reports on validation set.
|
|
Make the learning stop on overfit.
After the training you will have to call set_network_best() to use the best network found. |
|
Get stop on overfit mode.
|
|
Get error on validation set of the best net found.
|
|
Get the epoch when the best net was saved.
|
|
Get the accuracy mode.
|
|
Set the accuracy mode.
Default accuracy mode is 1. This method is virtual because you might need to rewrite it to give a different meaning to the accuracy mode (and to allow more than 3 modes). |
|
Get accuracy on training set. test() should have been called before |
|
Get accuracy on validation set. validate() should have been called before |
|
Get accuracy on certification set. certificate() should have been called before |
|
Get error on training set. test() should have been called before |
|
Get error on validation set. validate() should have been called before |
|
Get error on certification set. certificate() should have been called before |
|
Retrieve the number of inputs for the trainer.
|
|
Retrieve the number of outputs for the trainer.
|
|
Retrieve the cause of stopping of last training.
|
|
Retrieve a string describing the cause of stopping of last training.
|
|
Load Training data from a file.
|
|
Are we saving the best net?
|
|
Train network in on-line mode until it reaches min_error or the max number of epochs, or another stopping condition holds.
If current network is NULL throws a runtime_error exception |
|
Train network in batch mode until it reaches min_error or the max number of epochs, or another stopping condition holds.
If current network is NULL throws a runtime_error exception |
|
Train network in batch + Super SAB mode until it reaches min_error or the max number of epochs, or another stopping condition holds.
If current network is NULL throws a runtime_error exception |
|
Train network in on-line mode with shuffle for presenting training set until it reaches min_error or the max number of epochs, or another stopping condition holds.
If current network is NULL throws a runtime_error exception |
|
Train the network in some mode until it reaches min_error or the max number of epochs, or another stopping condition holds.
|
|
Test error and accuracy on the training set.
|
|
Test error and accuracy on the validation set.
|
|
Test error and accuracy on the certification set.
|
|
Show results on the selected set computing the results on each of the inputs and comparing the results with the target.
If current network is NULL throws a runtime_error exception |
|
Set iomanager for the trainer.
By default the trainer uses an iomanager of type iomanagelwnnfann. If iomanager == NULL throws a runtime_error exception. |
|
Get current iomanager. The iomanager is an object of a class derived from iomanage that manages your file type. By default the trainer uses an iomanager of type iomanagelwnnfann. |
|
Overloaded operator=.
|
|
Print input data.
|
|
Print output data.
|
|
Check if output is right according to target.
For example, you could define an output to be right if the norm of the difference between output and target is less than a fixed threshold or, for a network which uses tanh as activation function, you could use -1 and 1 instead of 0 and 1 for binary code. |
|
Implementation of the stopping condition for the training.
The stopping condition has a lot of parameters in the standard implementation: be sure that you really need a different condition before starting to rewrite this method! This is the most critical point in network training so... be careful! :) |
|
Set the best epoch.
|
|
Set the best error on validation set.
|
|
Set the stopping cause.
|
|
Get the best network.
|
|
Set the best network.
|