Visit the lwnnplus home page to download the technical report in English or Italian.
00001 #ifndef TRAINER_H 00002 #define TRAINER_H 00003 /* 00004 * Lightweight Neural Net ++ - Trainer class 00005 * http://lwneuralnetplus.sourceforge.net/ 00006 * 00007 * This C++ library provides the class trainer wich implements a 00008 * the most used techniques for training a network 00009 * 00010 * By Lorenzo Masetti <lorenzo.masetti@libero.it> and Luca Cinti <lucacinti@supereva.it> 00011 * 00012 * This library is free software; you can redistribute it and/or 00013 * modify it under the terms of the GNU Lesser General Public 00014 * License as published by the Free Software Foundation; either 00015 * version 2.1 of the License, or (at your option) any later version. 00016 * 00017 */ 00018 #include <string> 00019 #include <stdexcept> 00020 #include <fstream> 00021 using namespace std; 00022 00023 #include "network.h" 00024 #include "iomanage.h" 00025 #include "iomanagelwnnfann.h" 00026 00046 class trainer 00047 { 00048 public: 00049 /* Constructors */ 00050 00058 trainer (network * net, const string & error_filename = 00059 "", const string & accuracy_filename = ""); 00060 00071 trainer (int inputs, int outputs, int hidden, 00072 const string & error_filename = 00073 "", const string & accuracy_filename = ""); 00074 00088 trainer (const string & training_file, int hidden = 00089 0, const string & error_filename = 00090 "", const string & accuracy_filename = "", 00091 iomanage* iomanager = NULL); 00092 00096 trainer (const trainer& t); 00097 00098 00108 virtual ~trainer (); 00109 00110 /* Accessors and Mutators */ 00111 00112 00121 int get_current_epoch(); 00122 00126 void start_at_epoch(int epoch); 00127 00132 void clear(); 00133 00134 00138 int get_max_epochs () const; 00139 00153 void set_max_epochs (int nepochs); 00154 00158 float get_min_error () const; 00159 00167 void set_min_error (float error); 00168 00181 bool set_wanted_accuracy( float accuracy); 00182 00183 00190 float get_wanted_accuracy(); 00191 00192 00198 void set_error_filename (const string & 
filename); 00199 00203 string get_error_filename() const; 00204 00205 00211 void set_accuracy_filename (const string & filename); 00212 00216 string get_accuracy_filename() const; 00217 00221 network *get_network () const; 00222 00232 bool set_network (network * newnet); 00233 00276 bool set_network_best(); 00277 00278 00286 void set_training_validation (bool train_valid); 00287 00296 bool get_training_validation () const; 00297 00303 void set_epochs_checking_error (int epochs); 00304 00310 int get_epochs_checking_error () const; 00311 00312 00319 int get_epochs_report () const; 00320 00327 void set_epochs_report (int e); 00328 00352 void set_stop_on_overfit (bool stop, int batches_increasing = 00353 0, float valid_higher = 0.0); 00354 00361 bool get_stop_on_overfit () const; 00362 00363 00367 float get_best_error () const; 00368 00372 int get_best_epoch () const; 00373 00380 short int get_accuracy_mode () const; 00381 00396 virtual void set_accuracy_mode (short int mode); 00397 00398 /* Errors and accuracies Accessors */ 00399 00404 float get_accuracy_on_training () const; 00405 00410 float get_accuracy_on_validation () const; 00411 00416 float get_accuracy_on_certification () const; 00417 00422 float get_error_on_training () const; 00423 00424 00429 float get_error_on_validation () const; 00430 00431 00436 float get_error_on_certification () const; 00437 00442 int get_no_of_inputs () const; 00443 00444 00449 int get_no_of_outputs () const; 00450 00462 short int get_stopping_cause() const; 00463 00470 virtual string get_stopping_cause_string() const; 00471 00472 00473 /* Methods to load training, validation, certification set */ 00474 00481 void load_training (const string & filename); 00482 00485 void load_validation (const string & filename); 00486 00489 void load_certification (const string & filename); 00490 00491 00492 00496 bool using_best_net () const; 00497 00511 int train_online (bool verbose); 00512 00513 00527 int train_batch (bool verbose=false); 00528 
00529 00545 int train_ssab (bool verbose=false, bool reset_ssab = false); 00546 00562 int train_shuffle(bool verbose=false); 00563 00575 int train (int mode=0, bool verbose=false); 00576 00577 00585 bool test (); 00586 00593 bool validate (); 00594 00601 bool certificate (); 00602 00614 float show_results (bool verbose=false, int set = 1) const; 00615 00629 void set_iomanager(iomanage* iomanager); 00630 00639 iomanage* get_iomanager() const; 00640 00644 const trainer& operator= (const trainer& other); 00645 00647 static const short int NO_ACCURACY = 0; 00649 static const short int BINARY_ACCURACY = 1; 00650 00652 static const short int MAXGUESS_ACCURACY = 2; 00653 00655 static const short int STOPPING_UNDEFINED_CAUSE = 0; 00657 static const short int STOPPING_MAX_EPOCHS = 1; 00659 static const short int STOPPING_MIN_ERROR_REACHED_ON_TRAINING = 2; 00661 static const short int STOPPING_MIN_ERROR_REACHED_ON_BEST_ERROR = 3; 00663 static const short int STOPPING_OVERFIT = 4; 00665 static const short int STOPPING_MIN_ERROR_REACHED_ON_VALIDATION = 5; 00667 static const short int STOPPING_WANTED_ACCURACY_REACHED = 6; 00668 protected: 00677 virtual void print_input(float* input) const; 00678 00687 virtual void print_output(float* output) const; 00688 00708 virtual bool check (float *output, float *target) const; 00709 00727 virtual bool continue_training (int start); 00728 00729 /* These methods are declared protected because you might 00730 * need them in the implementation of continue_training() 00731 */ 00732 00739 void set_best_epoch(int epoch); 00740 00741 00748 void set_best_error_on_validation(float error); 00749 00756 void set_stopping_cause(short int cause); 00757 00764 network *get_best_net () const; 00765 00766 00773 void set_best_net(network* best); 00774 00775 private: 00776 /* Print accuracy on file */ 00777 void print_accuracy () const; 00778 void set_defaults (bool construct_iomanager = true); 00779 void destroy (int n, float **input, float **output); 00780 
void allocate_data (int npattern, float **&input, float **&target); 00781 void verbose_report (); 00782 00783 float get_accuracy (int npatterns, float **input, float **target) const; 00784 float get_error (int npatterns, float **input, float **target) const; 00785 void compute_error_and_accuracy (float &error, float &accuracy, 00786 int npatterns, float **input, 00787 float **target); 00788 00789 float compute_accuracy(int npatterns, float** input, float** target); 00790 00791 void compute_accuracy_on_training(); 00792 00793 void reports (bool verbose); 00794 00795 float show_on_set (bool verbose,int npatterns, float **input, float **target) const; 00796 00797 static void print_vector(float* v, int n); 00798 void check_files(); 00799 00800 void copy_data(int n, float** input, float** target, float** srcinput, float** srctarget ); 00801 void copy(const trainer& t); 00802 void free_trainer(); 00803 00804 #ifdef DEBUG 00805 void print_input (float **, int np) const; 00806 void print_target (float **patt, int np) const; 00807 #endif 00808 00809 00810 /* Input length */ 00811 int ninput; 00812 /* Output length */ 00813 int noutput; 00814 /* Number of trainig patterns */ 00815 int npattern_training; 00816 /* Number of validation patterns */ 00817 int npattern_validation; 00818 /* Number of certification patterns */ 00819 int npattern_certification; 00820 00821 /* current epoch */ 00822 int epoch; 00823 00824 /* computing error on validation set? */ 00825 bool train_valid; 00826 /* each this number of epochs error is computed on validation set */ 00827 int epochs_checking_error; 00828 00829 /* epochs of report */ 00830 int epochs_report; 00831 00832 /* stop on overfit */ 00833 bool stop_on_overfit; 00834 /* how much higher must be the error on validation set for stopping ? */ 00835 float _higher_valid; 00836 short int accuracy_mode; 00837 00838 /* Wanted accuracy on validation set. 
Training stops if this accuracy is reached */ 00839 float wanted_accuracy; 00840 00841 00842 /* how many batches should pass with validation error increasing to 00843 * suppose overfitting? 00844 */ 00845 int _batches_not_saving; 00846 /* epoch of the best net */ 00847 int _best_epoch; 00848 00849 /* stopping cause */ 00850 short int stopping_cause; 00851 00852 float error_on_training; 00853 float error_on_validation; 00854 float best_error_on_validation; 00855 float error_on_certification; 00856 float accuracy_on_training; 00857 float accuracy_on_validation; 00858 float accuracy_on_certification; 00859 00860 /* Training set */ 00861 float **training_set_input; 00862 float **training_set_target; 00863 /* Validation set */ 00864 float **validation_set_input; 00865 float **validation_set_target; 00866 /* Certification set */ 00867 float **certification_set_input; 00868 float **certification_set_target; 00869 /* Max number of epochs */ 00870 int max_epochs; 00871 /* Min average error */ 00872 float min_error; 00873 /* File used to store error trend */ 00874 string error_filename; 00875 /* File used to store accuracy trend */ 00876 string accuracy_filename; 00877 /* Neural network to train */ 00878 network *net; 00879 00880 /* Best Neural Network */ 00881 network* best_net; 00882 00883 /* Iomanager used to read patterns */ 00884 iomanage* iomanager; 00885 }; 00886 00887 /**************************************** 00888 * IMPLEMENTATION OF INLINE FUNCTIONS 00889 * ACCESSORS AND MUTATORS 00890 ****************************************/ 00891 00892 inline float 00893 trainer::get_min_error () const 00894 { 00895 return min_error; 00896 } 00897 00898 00899 inline int 00900 trainer::get_max_epochs () const 00901 { 00902 return max_epochs; 00903 } 00904 00905 inline network * trainer::get_network () const 00906 { 00907 return net; 00908 } 00909 00910 inline void 00911 trainer::set_min_error (float error) 00912 { 00913 min_error = error; 00914 } 00915 00916 inline bool 
trainer::get_training_validation () const 00917 { 00918 return train_valid; 00919 } 00920 00921 inline bool trainer::using_best_net () const 00922 { 00923 return (best_net != NULL); 00924 } 00925 00926 00927 inline float 00928 trainer::get_best_error () const 00929 { 00930 return best_error_on_validation; 00931 } 00932 00933 00934 00935 00936 inline bool trainer::get_stop_on_overfit () const 00937 { 00938 return stop_on_overfit; 00939 } 00940 00941 inline int 00942 trainer::get_epochs_checking_error () const 00943 { 00944 return epochs_checking_error; 00945 } 00946 00947 inline 00948 int trainer::get_epochs_report () const { 00949 return epochs_report; 00950 } 00951 00952 inline int 00953 trainer::get_best_epoch () const 00954 { 00955 return _best_epoch; 00956 } 00957 00958 inline short int 00959 trainer::get_accuracy_mode () const 00960 { 00961 return accuracy_mode; 00962 } 00963 00964 inline float 00965 trainer::get_error_on_training () const 00966 { 00967 return error_on_training; 00968 } 00969 00970 inline float 00971 trainer::get_error_on_validation () const 00972 { 00973 return error_on_validation; 00974 } 00975 00976 inline float 00977 trainer::get_error_on_certification () const 00978 { 00979 return error_on_certification; 00980 } 00981 00982 inline float 00983 trainer::get_accuracy_on_training () const 00984 { 00985 return accuracy_on_training; 00986 } 00987 00988 inline float 00989 trainer::get_accuracy_on_validation () const 00990 { 00991 return accuracy_on_validation; 00992 } 00993 00994 inline float 00995 trainer::get_accuracy_on_certification () const 00996 { 00997 return accuracy_on_certification; 00998 } 00999 01000 inline 01001 int trainer::get_no_of_inputs () const { 01002 return ninput; 01003 } 01004 01005 01006 inline 01007 int trainer::get_no_of_outputs () const { 01008 return noutput; 01009 } 01010 01011 inline 01012 iomanage* trainer::get_iomanager() const { 01013 return iomanager; 01014 } 01015 01016 inline 01017 string 
trainer::get_error_filename() const { 01018 return error_filename; 01019 } 01020 01021 inline 01022 string trainer::get_accuracy_filename() const { 01023 return accuracy_filename; 01024 } 01025 01026 01027 01028 01029 inline 01030 void trainer::set_best_epoch(int epoch) { 01031 _best_epoch = epoch; 01032 } 01033 01034 01035 inline 01036 void trainer::set_best_error_on_validation(float error) { 01037 best_error_on_validation=error; 01038 } 01039 01040 inline 01041 int trainer::get_current_epoch() { 01042 return epoch; 01043 } 01044 01045 inline 01046 void trainer::start_at_epoch(int e) { 01047 epoch = e; 01048 } 01049 01050 inline 01051 short int trainer::get_stopping_cause() const { 01052 return stopping_cause; 01053 } 01054 01055 inline 01056 void trainer::set_stopping_cause(short int cause) { 01057 stopping_cause= cause; 01058 } 01059 01060 inline 01061 network* trainer::get_best_net() const { 01062 return best_net; 01063 } 01064 01065 inline 01066 void trainer::set_best_net(network* best) { 01067 best_net = best; 01068 } 01069 01070 01071 inline 01072 float trainer::get_wanted_accuracy() { 01073 return wanted_accuracy; 01074 } 01075 01076 #endif 01077 01078 01079 01080