Base Classes and Functions
- training.train_model(model: NetBase, nb_epochs: int, learning_rate: float, loss_fn: nn.Loss, loader_tuple: tuple[DataLoader, DataLoader], params: GeneratorType | str = None, print_bool: bool = True, remove_bool: bool = True, comment: str = None)
Function for training a PyTorch NetBase model. For training PyTorch Lightning models, check lightning_objects.
This function will create a folder with reports about the training and the model, according to the information contained in the model’s report manager (see ReportManager for more).
The training can be stopped at any point by the user through a KeyboardInterrupt, which will conclude the training correctly, without losing progress.
- Parameters:
model – The actual network to be trained.
nb_epochs – Number of epochs for training.
learning_rate – The learning rate.
loss_fn – The Loss Function to use.
loader_tuple – (training, validation) DataLoader objects.
params – Either the specific parameters that should be trained, if different from the whole network, or the params argument for the model’s get_params() method. The latter use is recommended, since in that case the params information can be included in the training report.
print_bool – Whether to print the training and validation errors during iterations.
remove_bool – Whether to delete the epoch files after the training is complete.
comment – A comment to add to the training report.
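The interruption behaviour described above (stopping with a KeyboardInterrupt while keeping the progress made so far) follows a common Python pattern. The following is a minimal, torch-free sketch of that pattern, not train_model()’s actual code. It assumes a hypothetical per-epoch callback run_epoch(epoch) that returns a (training error, validation error) pair.

```python
from typing import Callable


def run_epochs(nb_epochs: int,
               run_epoch: Callable[[int], tuple[float, float]]):
    """Hypothetical sketch: run up to nb_epochs epochs, stopping cleanly on Ctrl+C.

    Returns the history of completed epochs and an optional comment
    explaining why the training ended early.
    """
    history = []
    comment = None
    try:
        for epoch in range(nb_epochs):
            # Only fully completed epochs are appended to the history.
            history.append(run_epoch(epoch))
    except KeyboardInterrupt:
        # The progress made so far is kept; the interruption is recorded
        # so that it could be added to a training report.
        comment = f"Training interrupted by the user after {len(history)} epoch(s)."
    # In both cases the caller can now finalize the training
    # (e.g. pick the best epoch) using `history`.
    return history, comment
```

Catching the interrupt around the whole loop, rather than inside each epoch, means an epoch cut short halfway never reaches the history.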
- training.report(cls_instance) → str
Function for creating a report about an object.
Returns a string that should describe all of the information needed to initialize an identical instance. It should also contain complementary information that may be useful to the user. This function itself only calls the object’s report method; if none is defined, it returns a string with general information about the object.
The class’s report method must not take any arguments.
- Parameters:
cls_instance – The class instance we want to make a report of.
- Returns:
A string with the report.
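The dispatch described above, which uses the object’s own report method when it exists and otherwise falls back to a generic description, can be sketched as follows. The name make_report and the exact fallback format are hypothetical. The real function’s fallback output is not specified here.

```python
def make_report(cls_instance) -> str:
    """Hypothetical sketch of a report() dispatcher.

    Calls the object's own argument-free ``report`` method when one is
    defined; otherwise falls back to a generic description.
    """
    method = getattr(cls_instance, "report", None)
    if callable(method):
        return method()
    # Fallback (format assumed): class name plus the instance attributes.
    attrs = getattr(cls_instance, "__dict__", {})
    return f"{type(cls_instance).__name__}: {attrs!r}"


class WithReport:
    def report(self) -> str:
        return "WithReport(lr=0.01)"


class WithoutReport:
    def __init__(self):
        self.lr = 0.01
```

Calling make_report(WithReport()) returns the class’s own string. Calling make_report(WithoutReport()) falls back to the attribute dump.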
- class training.NetBase(*args, **kwargs)
Base class for neural networks.
This class’s definition includes general methods that should be inherited or redefined by the neural networks. It is also recommended that child classes pass their *args and **kwargs when calling super().
- args
The model’s initialization *args.
- kwargs
The model’s initialization **kwargs.
- is_classifier
A bool stating whether the model is a classifier. This is used to determine whether to calculate the model’s accuracy (and a few other things) during training. It defaults to False to avoid raising errors in train_model(). It can, of course, be overridden either by the child class or by passing is_classifier=True as a keyword argument.
- trainclass
The model’s associated TrainingClass, which stores information about its training as well as its report manager.
- manager
A shortcut for self.trainclass.manager, which is the model’s ReportManager. It stores information about the report folder where the model’s information is saved.
Note
Why define the model’s manager inside its TrainingClass?
This is because a few of TrainingClass’s methods use paths that are stored in the model’s ReportManager. If the ReportManager were outside it, these paths would have to be redefined inside the training class. That would be redundant, and it could cause problems if the user ever decides (or needs) to change the report folder paths.
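The handling of args, kwargs and is_classifier described above can be illustrated with a torch-free stand-in. BaseSketch and Classifier are hypothetical classes, not the library’s NetBase. They only show how *args/**kwargs can be stored and forwarded, and how a child class can change the is_classifier default while still letting callers override it.

```python
class BaseSketch:
    """Torch-free stand-in for the NetBase attribute handling (hypothetical).

    Stores the initialization *args/**kwargs and exposes is_classifier,
    which defaults to False.
    """

    def __init__(self, *args, is_classifier: bool = False, **kwargs):
        self.args = args
        self.kwargs = kwargs
        self.is_classifier = is_classifier


class Classifier(BaseSketch):
    def __init__(self, *args, **kwargs):
        # Default to True for this child class, but still let the caller
        # override it with is_classifier=False.
        kwargs.setdefault("is_classifier", True)
        # Forward *args/**kwargs to the base class, as recommended above.
        super().__init__(*args, **kwargs)
```

Using kwargs.setdefault instead of assigning self.is_classifier = True after super().__init__() keeps the keyword-argument override working.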
- report() → str
Creates a short report about the network.
Returns a string that should describe all of the information needed to reproduce the neural network’s initialization, as well as its representation (.__repr__), which is defined by PyTorch’s nn.Module.
This method shouldn’t take any arguments and can be called through the report() function.
- reset()
Resets the neural network’s parameters. This method can be useful when training the same network multiple times.
- get_params(params: str) → GeneratorType
Placeholder method for fetching a subset of the net’s parameters.
It should take an argument that specifies which parameters to select and return a GeneratorType object. See the example below for a possible implementation.
Examples
import torch.nn as nn
from types import GeneratorType

from training import NetBase


class ConvModel(NetBase):
    def __init__(self, **kwargs):
        super(ConvModel, self).__init__(**kwargs)
        # net 1
        self.conv_net_1 = nn.Sequential(
            nn.Conv2d(1, 32, (3, 3)),
            nn.MaxPool2d((2, 2), (2, 2))
        )
        self.lin_net_1 = nn.Sequential(
            nn.Linear(32 * 15 * 15, 369),
            nn.Softmax(dim=1)
        )
        # net 2
        self.conv_net_2 = nn.Sequential(
            nn.Conv2d(1, 64, (3, 3)),
            nn.MaxPool2d((2, 2), (2, 2))
        )
        self.lin_net_2 = nn.Sequential(
            nn.Linear(64 * 15 * 15, 369),
            nn.Softmax(dim=1)
        )

    def get_params(self, params: str = None) -> GeneratorType:
        if params == '1':
            # Parameters of the first sub-network only
            yield from self.conv_net_1.parameters()
            yield from self.lin_net_1.parameters()
        elif params == '2':
            # Parameters of the second sub-network only
            yield from self.conv_net_2.parameters()
            yield from self.lin_net_2.parameters()
        else:
            # get_params is a generator (it contains yield), so a plain
            # "return self.parameters()" would yield nothing here.
            yield from self.parameters()
- Parameters:
params – A value that can help specify which parameters to select.
- Returns:
A GeneratorType object with the parameters.
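Because a get_params() implementation like the one above contains yield, Python treats the whole method as a generator function. This matters for the fallback branch. Inside a generator, return some_iterable only ends the iteration and yields nothing, whereas yield from delegates to the iterable. The following is a small pure-Python illustration. The function names are hypothetical.

```python
def broken(params=None):
    if params == "a":
        yield 1
    else:
        # Inside a generator, "return value" just stops the iteration:
        # the value is NOT handed to the caller as items.
        return [10, 20]


def fixed(params=None):
    if params == "a":
        yield 1
    else:
        # Delegate to the other iterable instead.
        yield from [10, 20]
```

With these definitions, list(broken()) is empty, while list(fixed()) contains both values.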
- class training.TrainingClass(delete_state_dicts: int = 1, **report_manager_kwargs)
Class for storing training information.
This class was created to store information about multiple trainings, in case the model is fine-tuned (and thus trained multiple times, possibly with different data and training conditions).
- train_list
List of TrainingClass.TrainData objects with the data for each individual training.
- nb_epochs
Total number of epochs completed (int).
- train_error
Training error for each epoch (Tensor).
- valid_error
Validation error for each epoch (Tensor).
- train_accur
Training accuracy for each epoch (Tensor).
- valid_accur
Validation accuracy for each epoch (Tensor).
- train_epoch_idx
The epoch indices corresponding to the entries of the train_error and train_accur tensors. This attribute exists because the PyTorch Lightning implementation may produce non-integer (float) epoch values (at least in the case of valid_epoch_idx).
- valid_epoch_idx
The epoch indices corresponding to the entries of the valid_error and valid_accur tensors. This attribute exists because the PyTorch Lightning implementation may produce non-integer (float) epoch values.
- best_epoch
The number of the best epoch (int).
- best_state
The best state_dict out of all training epochs (collections.OrderedDict).
- delete_state_dicts
Controls whether state_dicts from previous training loops are kept. Its value indicates how many trainings back we delete:
- 0 means we don’t delete any.
- 1 means we delete the last one (so none are kept).
- 2 means we delete the second to last (we always keep the most recent).
- 3 means we delete the third to last (we keep the two most recent), and so on.
This is implemented by the following lines of code, from this class’s finish_training() method:
if del_dicts:
    try:
        self.train_list[-del_dicts].state_dict_list = []
    except IndexError:
        pass
Attention
Note that this will not delete the TrainData object of the respective training loop, so all of its other information is kept. Keep in mind as well that picking a value \(x\) greater than one means that all epochs from the \(x-1\) most recent training loops will be kept, not only the last \(x-1\) epochs of the most recent training loop.
- manager
The ReportManager object that manages the model’s report folder.
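To see which trainings keep their state_dicts for a given delete_state_dicts value, the rule quoted in finish_training() can be replayed on plain Python data. This minimal sketch uses dicts as hypothetical stand-ins for TrainData objects. It assumes the rule runs once at the end of every training, right after that training is appended.

```python
def simulate_deletion(nb_trainings: int, del_dicts: int) -> list[bool]:
    """Return, for each training, whether its state_dict_list survives.

    Replays the rule quoted above: after each training finishes, the
    state_dict_list of train_list[-del_dicts] is emptied.
    """
    train_list = []
    for i in range(nb_trainings):
        # Each training starts with some saved state_dicts (placeholders here).
        train_list.append({"state_dict_list": [f"state_{i}"]})
        # 0 means "never delete"; the guard also avoids train_list[-0],
        # which would be train_list[0].
        if del_dicts:
            try:
                train_list[-del_dicts]["state_dict_list"] = []
            except IndexError:
                # Fewer than del_dicts trainings so far: nothing to delete.
                pass
    return [bool(t["state_dict_list"]) for t in train_list]
```

With four trainings, delete_state_dicts=3 leaves [False, False, True, True]: only the two most recent trainings keep their state_dicts, which matches the description above.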
- add_training(nb_epochs: int, learning_rate: float, loss_fn: nn.Loss, file_name: str, loader_tuple: tuple[DataLoader], comment: str = None)
Adds a TrainingClass.TrainData object to the train_list attribute.
- Parameters:
nb_epochs – Number of epochs for training.
learning_rate – Learning rate.
loss_fn – The loss function used.
file_name – The file name picked for the training.
Warning: file_name might become deprecated in the future.
loader_tuple – (training, validation) DataLoader objects.
- finish_training(remove_bool: bool = False, plot_accuracy: bool = False)
Method for adapting the class’s attributes after training.
This adds the last training’s data to the rest of the training data, updating the best epoch and state_dict.
- plot(save_bool: bool = False, block: bool = False, var: str = 'error')
Plots error or accuracy graphs for all trainings.
- Parameters:
save_bool – Whether to save the graph as a file.
block – Whether displaying the graph should stop the code from continuing.
var – Which graph to plot (“error” for the error or “accur” for the accuracy).
- report() → str
Creates a report that describes the overall training.
- Returns:
A string with a short report containing information about the trainings.
- class TrainData(nb_epochs: int, learning_rate: float, loss_fn: nn.Loss, file_name: str)
An object for storing information about one particular training loop.
Note
Some of the attributes listed below are defined as lists during initialization, but become Tensors when the finish_training() method is called.
- Parameters:
nb_epochs – Number of epochs for training.
learning_rate – Learning rate.
loss_fn – The loss function used.
file_name – The file name picked for the training.
Warning: file_name might become deprecated in the future.
- train_error
Training error for each epoch (Tensor).
- valid_error
Validation error for each epoch (Tensor).
- train_accur
Training accuracy for each epoch (Tensor).
- valid_accur
Validation accuracy for each epoch (Tensor).
- train_epoch_idx
The epoch indices corresponding to the entries of the train_error and train_accur tensors. This attribute exists because the PyTorch Lightning implementation may produce non-integer (float) epoch values (at least in the case of valid_epoch_idx).
- valid_epoch_idx
The epoch indices corresponding to the entries of the valid_error and valid_accur tensors. This attribute exists because the PyTorch Lightning implementation may produce non-integer (float) epoch values.
- state_dict_list
A list with the state_dict of each epoch. It may be reset after training if the model’s associated TrainingClass has a delete_state_dicts attribute with a value different from 0. See that class’s documentation for more information on how this works.
- comment
A string with an optional comment to be added to the training report (through the add_comment() method), in particular when the training is interrupted by the user through a KeyboardInterrupt.
- best_epoch
The number of the best epoch (int).
- best_state
The best state_dict (collections.OrderedDict).
- add_valid_epoch(error: Tensor, accur: Tensor, comment: str = None, epoch_nb: int = None) → None
Method used to add validation epoch loss information.
This is called by the PyTorch Lightning custom model’s lightning_objects.LitConvNet.on_validation_epoch_end() method.
- Parameters:
error – Validation error.
accur – Validation accuracy.
comment – A string with a comment to be added to the training report.
epoch_nb – The epoch number, which may be a float if the Lightning Trainer’s val_check_interval flag is a float.
- add_epoch(error_tuple: tuple[Tensor, Tensor], accur_tuple: tuple[Tensor, Tensor], state_dict: OrderedDict, comment: str = None) → None
Method for adding a training + validation epoch’s information to the class.
This is used by the train_model() function when training PyTorch Module models. For PyTorch Lightning models, the add_train_epoch() and add_valid_epoch() methods are called directly.
- Parameters:
error_tuple – Training and validation errors.
accur_tuple – Training and validation accuracies.
state_dict – The model’s current state_dict.
comment – A string with a comment to be added to the training report.
- finish_training()
Method for adapting the class’s attributes after training.
This turns the training and validation errors and accuracies (originally lists) into tensors. It also determines the best epoch of this training and stores its state dict separately.
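The per-training bookkeeping described for add_epoch() and finish_training() can be sketched without PyTorch. TrainDataSketch is a hypothetical class in which plain lists stand in for the Tensors. The selection criterion for the best epoch is not stated in this documentation, so the sketch assumes it is the epoch with the lowest validation error.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class TrainDataSketch:
    """Hypothetical, torch-free stand-in for one training loop's bookkeeping.

    Plain lists replace the Tensors, and the best epoch is ASSUMED to be
    the one with the lowest validation error.
    """
    train_error: list = field(default_factory=list)
    valid_error: list = field(default_factory=list)
    state_dict_list: list = field(default_factory=list)
    best_epoch: Optional[int] = None  # 0-based epoch index in this sketch
    best_state: Any = None

    def add_epoch(self, error_tuple: tuple, state_dict: Any) -> None:
        # error_tuple = (training error, validation error)
        train_err, valid_err = error_tuple
        self.train_error.append(train_err)
        self.valid_error.append(valid_err)
        self.state_dict_list.append(state_dict)

    def finish_training(self) -> None:
        if not self.valid_error:
            return  # nothing was recorded (e.g. interrupted before epoch 1)
        # Index of the smallest validation error; ties go to the earliest epoch.
        best_idx = min(range(len(self.valid_error)), key=self.valid_error.__getitem__)
        self.best_epoch = best_idx
        self.best_state = self.state_dict_list[best_idx]
```

Storing one state_dict per epoch is what makes picking best_state after the fact possible. It is also why the delete_state_dicts mechanism exists to free that memory.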
- class training.ReportManager(dirname: str = None, report_dir: str = '_Reports', complete_path: str = None)
Class for managing a model’s report folder and report files.
This class is instantiated when a NetBase object is initialized. It is stored in the model’s TrainingClass object (trainclass attribute), although it can also be accessed through the model’s manager attribute (property).
The class creates the model’s report folder, where information about its training will be stored (such as its epochs’ state dictionaries, images, reports and the best iteration’s model), and manages the actual text reports.
- Parameters:
dirname –
The name of the model’s report directory. Defaults to the current date and time (datetime.today()) if nothing is passed.
If the specified dirname already exists, the current date and time will be appended to it to distinguish the reports, and the class will print a note about it.
report_dir – The name of the Reports folder (where the individual model report directories are stored). Default is "_Reports".
complete_path –
If the user wants to use a base folder different from the current one, a path can be specified and the report manager will change into it. The model’s reports will be stored in complete_path/report_dir/dirname. If nothing is specified, the current directory is used as the complete path.
Warning
Picking a different complete_path will change the current working directory (to the specified complete_path), and this change is not reverted.
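The naming rule for dirname (a timestamp by default, and a timestamp suffix when the folder already exists) can be sketched as below. pick_dirname is a hypothetical helper, not part of ReportManager. Replacing ':' with '_' is inferred from names like "ConvModel - 2021-11-10 12_43_10.517134" in the examples on this page.

```python
import os
from datetime import datetime
from typing import Optional


def pick_dirname(report_root: str, dirname: Optional[str] = None) -> str:
    """Hypothetical sketch of how a unique report-folder name can be chosen."""
    # str(datetime) looks like "2021-11-10 12:43:10.517134"; colons are
    # replaced (as in the example names), e.g. since Windows forbids them.
    stamp = str(datetime.today()).replace(":", "_")
    if dirname is None:
        return stamp
    if os.path.exists(os.path.join(report_root, dirname)):
        new_name = f"{dirname} - {stamp}"
        print(f'Report "{dirname}" already exists, creating a new name: {new_name}')
        return new_name
    return dirname
```

Appending a timestamp instead of overwriting guarantees that an earlier model’s reports are never silently replaced.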
- report_dir
The specified report_dir.
- base_path
The specified complete_path (if none is specified, it is set to the current directory).
- dirname
The final dirname picked for the model’s report folder.
- path
f".//{report_dir}/{dirname}"
- files
A set containing the names of the text report files. It is a set because, during multiple trainings, the same files may be edited multiple times, which would otherwise result in multiple entries with the same name.
- chdir(path: str)
Changes the report directory according to the given final path.
This method may create the directories needed to reach the path given by the user. If more than two directories need to be created, the user will be prompted for confirmation.
- Parameters:
path – The path to the final model report folder.
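The "prompt if more than two directories must be created" rule can be sketched with the standard library. missing_dirs and confirm_and_create are hypothetical helpers, not chdir()’s actual code. Unlike chdir(), this sketch does not change the current working directory.

```python
import os


def missing_dirs(path: str) -> list[str]:
    """Hypothetical helper: the directories that would have to be created
    to reach `path`, from the outermost to the innermost."""
    missing = []
    current = os.path.abspath(path)
    # Walk up the tree until an existing directory is found.
    while not os.path.exists(current):
        missing.append(current)
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root
            break
        current = parent
    return list(reversed(missing))


def confirm_and_create(path: str, max_silent: int = 2) -> bool:
    """Create `path`, asking for confirmation if more than `max_silent`
    directories would have to be created."""
    to_create = missing_dirs(path)
    if len(to_create) > max_silent:
        answer = input(f"{len(to_create)} directories will be created. Continue? [y/n] ")
        if answer.strip().lower() != "y":
            return False
    os.makedirs(path, exist_ok=True)
    return True
```

Counting the missing directories before creating any of them lets a typo in a long path be caught before it leaves a trail of empty folders.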
- class File(filename, method, base_dir, dirname)
Class for creating text report files when entered.
This class is implicitly entered when ReportManager is called:
>>> manager = ReportManager(dirname='ConvModel')
>>> with manager("Report.txt", 'w') as f:
...     # It looks like we are using manager.__enter__ because
...     # of the 'with' statement. However, ReportManager doesn't
...     # have an .__enter__ method defined.
...     # manager("Report.txt", 'w') is actually
...     # manager.__call__("Report.txt", 'w'), which returns a
...     # ReportManager.File object,
...     # which is then entered because of the 'with' statement.
...     f.write("Hello There")
...
The File object then creates the desired text report file directly in the model’s report directory. See the examples below.
Examples
>>> manager = ReportManager(dirname='ConvModel')
>>> with manager("Report.txt", 'w') as f:
...     f.write("Hello There")
...
>>> # The "Report.txt" file was written directly in
>>> # './/_Reports/ConvModel'
>>> # Another example:
>>> model = NetBase(dirname='ConvModel')  # Toy model
Report "ConvModel" already exists, creating a new name: ConvModel - 2021-11-10 12_43_10.517134
>>> with model.manager("Report.txt", 'w') as f:
...     f.write("General Kenobi")
...
>>> # "Report.txt" written directly at
>>> # './/_Reports/ConvModel - 2021-11-10 12_43_10.517134'
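The __call__-returns-a-context-manager pattern described for ReportManager and File can be reproduced in a few lines. MiniManager is a hypothetical, simplified stand-in. Its File takes a single folder path instead of the documented (filename, method, base_dir, dirname) arguments.

```python
import os


class MiniManager:
    """Hypothetical, simplified stand-in for the ReportManager/File pattern.

    Calling the manager returns a File object; the `with` statement then
    enters that File object, which opens the file inside the report folder.
    """

    def __init__(self, path: str):
        self.path = path
        self.files = set()  # names of the text reports written so far
        os.makedirs(path, exist_ok=True)

    class File:
        def __init__(self, filename: str, method: str, path: str):
            self.full_path = os.path.join(path, filename)
            self.method = method
            self.handle = None

        def __enter__(self):
            self.handle = open(self.full_path, self.method)
            return self.handle

        def __exit__(self, exc_type, exc, tb):
            self.handle.close()
            return False  # do not swallow exceptions

    def __call__(self, filename: str, method: str):
        # A set, so rewriting the same report does not duplicate its name.
        self.files.add(filename)
        return self.File(filename, method, self.path)
```

Because the manager itself is never entered, one manager can hand out any number of independent file contexts. Each context still writes into the same report folder.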