Lightweight Neural Network++ documentation

Visit the lwnnplus home page, or download the technical report in English or Italian.

Main Page | Class Hierarchy | Class List | File List | Class Members

trainer.h

00001 #ifndef TRAINER_H
00002 #define TRAINER_H
00003 /*
00004  * Lightweight Neural Net ++ - Trainer class
00005  * http://lwneuralnetplus.sourceforge.net/
00006  *
00007  * This C++ library provides the class trainer, which implements
00008  * the most used techniques for training a network
00009  *
00010  * By Lorenzo Masetti <lorenzo.masetti@libero.it> and Luca Cinti <lucacinti@supereva.it>
00011  *
00012  * This library is free software; you can redistribute it and/or
00013  * modify it under the terms of the GNU Lesser General Public
00014  * License as published by the Free Software Foundation; either
00015  * version 2.1 of the License, or (at your option) any later version.
00016  * 
00017  */
00018 #include <string>
00019 #include <stdexcept>
00020 #include <fstream>
00021 using namespace std;
00022 
00023 #include "network.h"
00024 #include "iomanage.h"
00025 #include "iomanagelwnnfann.h"
00026 
/*
 * trainer - declaration of the network trainer class.
 *
 * What the declaration itself shows:
 *  - state: a network pointer (net), a second network pointer (best_net),
 *    an iomanage pointer, three pattern sets (training / validation /
 *    certification) each held as float** input and target arrays, and the
 *    stopping parameters max_epochs, min_error and wanted_accuracy;
 *  - a copy constructor, an assignment operator and a virtual destructor
 *    are all user-declared;
 *  - several members are virtual (set_accuracy_mode,
 *    get_stopping_cause_string, print_input, print_output, check,
 *    continue_training), so the class can be subclassed.
 *
 * Many one-line accessors and mutators are defined inline after the class
 * in this file; everything else is only declared here.
 *
 * NOTE(review): the leading five-digit numbers are the line numbers of the
 * listing; they skip ahead before most public members, presumably where
 * documentation comments were omitted from the listing. Confirm member
 * semantics against the implementation before relying on names alone.
 */
00046 class trainer
00047 {
00048 public:
00049   /* Constructors */
00050 
        /* Takes an existing network pointer; both filenames default to "". */
00058   trainer (network * net, const string & error_filename =
00059            "", const string & accuracy_filename = "");
00060 
        /* Takes three integer sizes instead of a network.
         * NOTE(review): presumably a network is built internally from
         * inputs/outputs/hidden - confirm in the implementation. */
00071     trainer (int inputs, int outputs, int hidden,
00072              const string & error_filename =
00073              "", const string & accuracy_filename = "");
00074 
        /* Takes a training file name; hidden defaults to 0, both filenames
         * to "" and iomanager to NULL. */
00088     trainer (const string & training_file, int hidden =
00089              0, const string & error_filename =
00090              "", const string & accuracy_filename = "",
00091              iomanage* iomanager = NULL);
00092 
        /* Copy constructor. */
00096     trainer (const trainer& t);
00097 
00098 
        /* Virtual destructor. */
00108     virtual ~trainer ();
00109 
00110   /* Accessors and Mutators */
00111 
00112 
        /* NOTE(review): get_current_epoch() and get_wanted_accuracy() are the
         * only accessors in this class that are not const-qualified. */
00121   int get_current_epoch();
00122   
00126   void start_at_epoch(int epoch);
00127 
00132   void clear();
00133 
00134 
00138   int get_max_epochs () const;
00139 
00153   void set_max_epochs (int nepochs);
00154 
00158   float get_min_error () const;
00159 
00167   void set_min_error (float error);
00168 
00181   bool set_wanted_accuracy( float accuracy);
00182 
00183 
00190   float get_wanted_accuracy();
00191 
00192 
00198   void set_error_filename (const string & filename);
00199 
        /* Returns the file name by value (a copy). */
00203   string get_error_filename() const;
00204 
00205 
00211   void set_accuracy_filename (const string & filename);
00212 
        /* Returns the file name by value (a copy). */
00216   string get_accuracy_filename() const;
00217 
00221   network *get_network () const;
00222 
00232   bool set_network (network * newnet);
00233 
00276   bool set_network_best();
00277 
00278 
00286   void set_training_validation (bool train_valid);
00287 
00296   bool get_training_validation () const;
00297 
00303   void set_epochs_checking_error (int epochs);
00304 
00310   int get_epochs_checking_error () const;
00311 
00312 
00319   int get_epochs_report () const;
00320 
00327   void  set_epochs_report (int e);
00328 
        /* batches_increasing defaults to 0 and valid_higher to 0.0. */
00352   void set_stop_on_overfit (bool stop, int batches_increasing =
00353                             0, float valid_higher = 0.0);
00354 
00361   bool get_stop_on_overfit () const;
00362 
00363 
00367   float get_best_error () const;
00368 
00372   int get_best_epoch () const;
00373 
00380   short int get_accuracy_mode () const;
00381 
00396   virtual void set_accuracy_mode (short int mode);
00397 
00398   /* Errors and accuracies Accessors */
00399 
00404   float get_accuracy_on_training () const;
00405 
00410   float get_accuracy_on_validation () const;
00411 
00416   float get_accuracy_on_certification () const;
00417 
00422   float get_error_on_training () const;
00423 
00424 
00429   float get_error_on_validation () const;
00430 
00431 
00436   float get_error_on_certification () const;
00437 
00442   int get_no_of_inputs () const;
00443 
00444 
00449   int get_no_of_outputs () const;
00450 
00462   short int get_stopping_cause() const;
00463 
00470   virtual string get_stopping_cause_string() const;
00471 
00472 
00473   /* Methods to load training, validation, certification set */
00474 
00481   void load_training (const string & filename);
00482 
00485   void load_validation (const string & filename);
00486 
00489   void load_certification (const string & filename);
00490 
00491 
00492 
00496   bool using_best_net () const;
00497 
        /* Training entry points. Note that train_online() is the only one
         * whose verbose parameter has no default value. */
00511   int train_online (bool verbose);
00512 
00513 
00527   int train_batch (bool verbose=false);
00528 
00529 
00545   int train_ssab (bool verbose=false, bool reset_ssab = false);
00546 
00562   int train_shuffle(bool verbose=false);
00563 
00575   int train (int mode=0, bool verbose=false);
00576 
00577 
00585   bool test ();
00586 
00593   bool validate ();
00594 
00601   bool certificate ();
00602 
00614   float show_results (bool verbose=false, int set = 1) const;
00615 
00629   void set_iomanager(iomanage* iomanager);
00630  
00639   iomanage* get_iomanager() const;
00640 
        /* Copy assignment; note that it returns a const reference. */
00644   const trainer& operator= (const trainer& other);
00645 
        /* Integral class constants.
         * NOTE(review): going by the names, the *_ACCURACY values are
         * presumably the modes used with set_accuracy_mode() /
         * get_accuracy_mode(), and the STOPPING_* values the codes reported
         * by get_stopping_cause() - confirm in the implementation. */
00647   static const short int NO_ACCURACY = 0;
00649   static const short int BINARY_ACCURACY = 1;
00650 
00652   static const short int MAXGUESS_ACCURACY = 2;
00653 
00655   static const short int STOPPING_UNDEFINED_CAUSE = 0;
00657   static const short int STOPPING_MAX_EPOCHS = 1;
00659   static const short int STOPPING_MIN_ERROR_REACHED_ON_TRAINING = 2;
00661   static const short int STOPPING_MIN_ERROR_REACHED_ON_BEST_ERROR = 3;
00663   static const short int STOPPING_OVERFIT = 4;
00665   static const short int STOPPING_MIN_ERROR_REACHED_ON_VALIDATION = 5;
00667   static const short int  STOPPING_WANTED_ACCURACY_REACHED = 6;
00668 protected:
        /* Virtual hooks that a subclass may override. */
00677   virtual void print_input(float* input) const;
00678 
00687   virtual void print_output(float* output) const;
00688 
00708   virtual bool check (float *output, float *target) const;
00709 
00727   virtual bool continue_training (int start);
00728 
00729   /* These methods are declared protected because you might
00730    * need them in the implementation of continue_training()
00731    */ 
00732 
00739   void set_best_epoch(int epoch);
00740   
00741 
00748   void set_best_error_on_validation(float error);
00749 
00756   void set_stopping_cause(short int cause);
00757 
00764   network *get_best_net () const;
00765 
00766 
00773   void set_best_net(network* best);
00774 
00775 private:
00776   /* Print accuracy on file */
00777   void print_accuracy () const;
00778   void set_defaults (bool construct_iomanager = true);
00779   void destroy (int n, float **input, float **output);
        /* input and target are passed as references to float** so the
         * callee can reseat the caller's pointers. */
00780   void allocate_data (int npattern, float **&input, float **&target);
00781   void verbose_report ();
00782 
00783   float get_accuracy (int npatterns, float **input, float **target) const;
00784   float get_error (int npatterns, float **input, float **target) const;
        /* error and accuracy are output parameters (non-const references). */
00785   void compute_error_and_accuracy (float &error, float &accuracy,
00786                                    int npatterns, float **input,
00787                                    float **target);
00788 
00789   float compute_accuracy(int npatterns, float** input, float** target);
00790 
00791   void compute_accuracy_on_training();
00792 
00793   void reports (bool verbose);
00794 
00795   float show_on_set (bool verbose,int npatterns, float **input, float **target) const;
00796 
00797   static void print_vector(float* v, int n);
00798   void check_files();
00799 
00800   void copy_data(int n, float** input, float** target, float** srcinput, float** srctarget );
00801   void copy(const trainer& t);
00802   void free_trainer();
00803   
        /* Extra printing helpers, declared only when DEBUG is defined. */
00804 #ifdef DEBUG
00805   void print_input (float **, int np) const;
00806   void print_target (float **patt, int np) const;
00807 #endif
00808 
00809 
00810   /* Input length */
00811   int ninput;
00812   /* Output length */
00813   int noutput;
00814   /* Number of training patterns */
00815   int npattern_training;
00816   /* Number of validation patterns */
00817   int npattern_validation;
00818   /* Number of certification patterns */
00819   int npattern_certification;
00820 
00821   /* current epoch */
00822   int epoch;
00823   
00824   /* computing error on validation set? */
00825   bool train_valid;
00826   /* the error on the validation set is computed once every this many epochs */
00827   int epochs_checking_error;
00828 
00829   /* epochs of report */
00830   int epochs_report;
00831 
00832   /* stop on overfit */
00833   bool stop_on_overfit;
00834   /* how much higher must the error on the validation set be to stop? */
00835   float _higher_valid;
        /* accuracy mode, as returned by get_accuracy_mode() */
00836   short int accuracy_mode;
00837 
00838   /* Wanted accuracy on validation set. Training stops if this accuracy is reached */
00839   float wanted_accuracy;
00840 
00841 
00842   /* how many batches must pass with the validation error increasing
00843    * before overfitting is assumed?
00844    */
00845   int _batches_not_saving;
00846   /* epoch of the best net */
00847   int _best_epoch;
00848   
00849   /* stopping cause */
00850   short int stopping_cause;
00851 
        /* Latest error / accuracy figures, exposed by the inline
         * get_error_on_*() and get_accuracy_on_*() accessors. */
00852   float error_on_training;
00853   float error_on_validation;
00854   float best_error_on_validation;
00855   float error_on_certification;
00856   float accuracy_on_training;
00857   float accuracy_on_validation;
00858   float accuracy_on_certification;
00859 
00860   /* Training set */
00861   float **training_set_input;
00862   float **training_set_target;
00863   /* Validation set */
00864   float **validation_set_input;
00865   float **validation_set_target;
00866   /* Certification set */
00867   float **certification_set_input;
00868   float **certification_set_target;
00869   /* Max number of epochs */
00870   int max_epochs;
00871   /* Min average error */
00872   float min_error;
00873   /* File used to store error trend */
00874   string error_filename;
00875   /* File used to store accuracy trend */
00876   string accuracy_filename;
00877   /* Neural network to train */
00878   network *net;
00879 
00880   /* Best Neural Network */
00881   network* best_net;
00882 
00883   /* Iomanager used to read patterns */
00884   iomanage* iomanager;
00885 };
00886 
00887 /****************************************
00888  * IMPLEMENTATION OF INLINE FUNCTIONS
00889  * ACCESSORS AND MUTATORS
00890  ****************************************/
00891 
/* Returns the min_error member (commented in the class as the minimum average error). */
00892 inline float
00893 trainer::get_min_error () const
00894 {
00895   return min_error;
00896 }
00897 
00898 
/* Returns the max_epochs member (commented in the class as the maximum number of epochs). */
00899 inline int
00900 trainer::get_max_epochs () const
00901 {
00902   return max_epochs;
00903 }
00904 
/* Returns the stored net pointer (the network to train) as-is; no copy of the network is made. */
00905 inline network * trainer::get_network () const
00906 {
00907   return net;
00908 }
00909 
/* Stores error into the min_error member; the value is not validated here. */
00910 inline void
00911 trainer::set_min_error (float error)
00912 {
00913   min_error = error;
00914 }
00915 
/* Returns the train_valid flag (commented in the class as "computing error on validation set?"). */
00916 inline bool trainer::get_training_validation () const
00917 {
00918   return train_valid;
00919 }
00920 
/* Returns true when the best_net pointer is non-NULL, false otherwise. */
00921 inline bool trainer::using_best_net () const
00922 {
00923   return (best_net != NULL);
00924 }
00925 
00926 
/* Returns the best_error_on_validation member. */
00927 inline float
00928 trainer::get_best_error () const
00929 {
00930   return best_error_on_validation;
00931 }
00932 
00933 
00934 
00935 
/* Returns the stop_on_overfit flag. */
00936 inline bool trainer::get_stop_on_overfit () const
00937 {
00938   return stop_on_overfit;
00939 }
00940 
/* Returns epochs_checking_error, which the class comments describe as the
 * number of epochs between error computations on the validation set. */
00941 inline int
00942 trainer::get_epochs_checking_error () const
00943 {
00944   return epochs_checking_error;
00945 }
00946 
/* Returns the epochs_report member (commented in the class as "epochs of report"). */
00947 inline
00948 int trainer::get_epochs_report () const {
00949   return epochs_report;
00950 }
00951 
/* Returns the _best_epoch member (commented in the class as the epoch of the best net). */
00952 inline int
00953 trainer::get_best_epoch () const
00954 {
00955   return _best_epoch;
00956 }
00957 
/* Returns the accuracy_mode member.
 * NOTE(review): presumably one of NO_ACCURACY / BINARY_ACCURACY /
 * MAXGUESS_ACCURACY - confirm against set_accuracy_mode(). */
00958 inline short int
00959 trainer::get_accuracy_mode () const
00960 {
00961   return accuracy_mode;
00962 }
00963 
/* Returns the stored error_on_training value; nothing is recomputed here. */
00964 inline float
00965 trainer::get_error_on_training () const
00966 {
00967   return error_on_training;
00968 }
00969 
/* Returns the stored error_on_validation value; nothing is recomputed here. */
00970 inline float
00971 trainer::get_error_on_validation () const
00972 {
00973   return error_on_validation;
00974 }
00975 
/* Returns the stored error_on_certification value; nothing is recomputed here. */
00976 inline float
00977 trainer::get_error_on_certification () const
00978 {
00979   return error_on_certification;
00980 }
00981 
/* Returns the stored accuracy_on_training value; nothing is recomputed here. */
00982 inline float
00983 trainer::get_accuracy_on_training () const
00984 {
00985   return accuracy_on_training;
00986 }
00987 
/* Returns the stored accuracy_on_validation value; nothing is recomputed here. */
00988 inline float
00989 trainer::get_accuracy_on_validation () const
00990 {
00991   return accuracy_on_validation;
00992 }
00993 
/* Returns the stored accuracy_on_certification value; nothing is recomputed here. */
00994 inline float
00995 trainer::get_accuracy_on_certification () const
00996 {
00997   return accuracy_on_certification;
00998 }
00999 
/* Returns the ninput member (commented in the class as the input length). */
01000 inline
01001 int trainer::get_no_of_inputs () const {
01002   return ninput;
01003 }
01004 
01005 
/* Returns the noutput member (commented in the class as the output length). */
01006 inline
01007 int trainer::get_no_of_outputs () const {
01008   return noutput;
01009 }
01010   
/* Returns the stored iomanager pointer as-is (commented in the class as the
 * iomanager used to read patterns). */
01011 inline
01012 iomanage* trainer::get_iomanager() const {
01013   return iomanager;
01014 }
01015 
/* Returns a copy of error_filename (commented in the class as the file used
 * to store the error trend). */
01016 inline
01017 string trainer::get_error_filename() const {
01018   return error_filename;
01019 }
01020 
/* Returns a copy of accuracy_filename (commented in the class as the file
 * used to store the accuracy trend). */
01021 inline
01022 string trainer::get_accuracy_filename() const {
01023   return accuracy_filename;
01024 }
01025 
01026 
01027 
01028 
/* Stores epoch into _best_epoch. The parameter shadows the epoch member,
 * which this function does not touch. */
01029 inline
01030 void trainer::set_best_epoch(int epoch) {
01031   _best_epoch = epoch;
01032 }
01033   
01034 
/* Stores error into best_error_on_validation; the value is not validated here. */
01035 inline
01036 void trainer::set_best_error_on_validation(float error) {
01037   best_error_on_validation=error;
01038 }
01039 
/* Returns the epoch member (commented in the class as the current epoch).
 * NOTE(review): this accessor only reads state but is not const-qualified,
 * unlike most other getters; making it const also requires changing the
 * in-class declaration. */
01040 inline
01041 int trainer::get_current_epoch() {
01042   return epoch;
01043 }
01044 
/* Sets the epoch member to e. The in-class declaration names this
 * parameter epoch; here it is e, so it does not shadow the member. */
01045 inline
01046 void trainer::start_at_epoch(int e) {
01047   epoch = e;
01048 }
01049 
/* Returns the stopping_cause member.
 * NOTE(review): presumably one of the STOPPING_* class constants - confirm
 * where stopping_cause is assigned. */
01050 inline
01051 short int trainer::get_stopping_cause() const {
01052   return stopping_cause;
01053 }
01054 
/* Stores cause into stopping_cause; the value is not checked against the
 * STOPPING_* constants here. */
01055 inline
01056 void trainer::set_stopping_cause(short int cause) {
01057   stopping_cause= cause;
01058 }
01059 
/* Returns the stored best_net pointer as-is (may be NULL; see using_best_net()). */
01060 inline
01061 network* trainer::get_best_net() const {
01062   return best_net;
01063 }
01064 
/* Overwrites the best_net pointer with best. The previously stored pointer
 * is not freed by this function. */
01065 inline 
01066 void trainer::set_best_net(network* best) {
01067   best_net = best;
01068 }
01069 
01070 
/* Returns the wanted_accuracy member (commented in the class as the wanted
 * accuracy on the validation set).
 * NOTE(review): this accessor only reads state but is not const-qualified,
 * unlike most other getters; making it const also requires changing the
 * in-class declaration. */
01071 inline
01072 float trainer::get_wanted_accuracy() {
01073   return wanted_accuracy;
01074 }
01075 
01076 #endif
01077 
01078 
01079 
01080 

Generated on Tue Oct 12 00:32:11 2004 for Lightweight Neural Network ++ by  doxygen 1.3.9