Base Classes and Functions

training.train_model(model: NetBase, nb_epochs: int, learning_rate: float, loss_fn: nn.Loss, loader_tuple: tuple[DataLoader, DataLoader], params: GeneratorType | str = None, print_bool: bool = True, remove_bool: bool = True, comment: str = None)[source]

Function for training a PyTorch NetBase model. For training PyTorch Lightning models, see lightning_objects.

This function creates a folder with reports about the training and the model, according to the information contained in the model’s report manager (see ReportManager for more).

The training can be stopped at any point by the user through a KeyboardInterrupt, which will conclude the training correctly, without losing progress.

Parameters:
  • model – The actual network to be trained.

  • nb_epochs – Number of epochs for training.

  • learning_rate – The learning rate.

  • loss_fn – The Loss Function to use.

  • loader_tuple – (training, validation) DataLoader objects.

  • params – Either the specific parameters that should be trained, if different from the whole network, or the params argument for the model’s get_params() method. The latter is recommended, since in this case the params information can be integrated into the training report.

  • print_bool – Whether to print training and validation errors during iterations.

  • remove_bool – Whether to delete the Epoch files after the training is complete.

  • comment – A comment to add to the training report.
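The interruption behaviour described above can be pictured with a plain Python loop. This is a minimal sketch, assuming the training loop catches KeyboardInterrupt around the epoch loop and records a comment before finishing; train_with_interrupt, run_epoch and finish are hypothetical stand-ins, not part of this module.

```python
from typing import Callable, List, Optional


def train_with_interrupt(
    nb_epochs: int,
    run_epoch: Callable[[int], float],
    finish: Callable[[List[float], Optional[str]], None],
) -> List[float]:
    """Hypothetical sketch: run epochs until done or until the user presses Ctrl+C."""
    errors: List[float] = []
    comment: Optional[str] = None
    try:
        for epoch in range(nb_epochs):
            errors.append(run_epoch(epoch))
    except KeyboardInterrupt:
        # Keep every completed epoch and note the interruption for the report
        comment = f"Training interrupted by the user after {len(errors)} epoch(s)."
    # Finishing runs in both cases, so no completed epoch is lost
    finish(errors, comment)
    return errors


def fake_epoch(epoch: int) -> float:
    # Simulate the user interrupting during the fourth epoch
    if epoch == 3:
        raise KeyboardInterrupt
    return 1.0 / (epoch + 1)


summary = {}
errors = train_with_interrupt(
    10, fake_epoch, lambda errs, c: summary.update(n=len(errs), comment=c)
)
print(summary["n"])  # 3
```

Catching the interrupt outside the loop (rather than inside each epoch) means a partially run epoch is simply discarded, while all finished epochs reach the finishing step.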

training.report(cls_instance) str[source]

Function for creating a report about an object.

Returns a string that should describe all of the information needed to initialize an identical instance. It should also contain complementary information that may be useful to the user. This function only calls the class’s report method; if none is defined, it returns a string with information about the object.

The class’s report method should not take any arguments.

Parameters:

cls_instance – The class instance we want to make a report of.

Returns:

A string with the report.
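The delegation described above can be sketched as follows. The fallback format is a hypothetical choice: the exact string returned when a class defines no report method is not specified here.

```python
def report(cls_instance) -> str:
    """Sketch: delegate to the instance's own report() if it has one."""
    method = getattr(cls_instance, "report", None)
    if callable(method):
        # The class's report method takes no arguments
        return method()
    # Hypothetical fallback: describe the object by its type and attributes
    attrs = getattr(cls_instance, "__dict__", {})
    return f"{type(cls_instance).__name__}: {attrs!r}"


class WithReport:
    def report(self) -> str:
        return "WithReport(lr=0.01)"


class WithoutReport:
    def __init__(self):
        self.lr = 0.01


print(report(WithReport()))     # WithReport(lr=0.01)
print(report(WithoutReport()))  # WithoutReport: {'lr': 0.01}
```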

class training.NetBase(*args, **kwargs)[source]

Base class for neural networks.

This class’s definition includes general methods that should be inherited or redefined by the neural networks. It is also recommended that child classes pass their *args and **kwargs when calling super().__init__().

args

The model’s initialization *args.

kwargs

The model’s initialization **kwargs.

is_classifier

A bool stating whether or not the model is a classifier. This is useful for determining whether or not to calculate the model’s accuracy and a few other things during training. It defaults to False to avoid raising errors in train_model(). It can, of course, be overridden either by the child class or by passing is_classifier=True as a keyword argument.

trainclass

The model’s associated TrainingClass, which stores information about its training and report manager.

manager

A shortcut for self.trainclass.manager, the model’s ReportManager, which stores information about the report folder where the model’s information is saved.

Note

Why define the model’s manager inside of its TrainingClass?

This is because a few of TrainingClass’s methods use paths that are stored in the model’s ReportManager. If the ReportManager were outside it, these paths would have to be redefined inside the training class, which is redundant and may cause problems if the user ever decides, or needs, to change the report folder paths.
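The argument bookkeeping described above can be sketched without PyTorch. ToyNetBase, ToyClassifier and ToyRegressor are hypothetical stand-ins that only show why forwarding *args and **kwargs to super().__init__() matters, and the two ways is_classifier can be overridden; the real NetBase also subclasses nn.Module and builds its TrainingClass.

```python
class ToyNetBase:
    """Stand-in for NetBase (no PyTorch), showing only argument bookkeeping."""

    is_classifier = False  # class-level default, a subclass may override it

    def __init__(self, *args, **kwargs):
        # Stored so a report can later describe how to rebuild the model
        self.args = args
        self.kwargs = kwargs
        # A keyword argument takes precedence over the class-level value
        if "is_classifier" in kwargs:
            self.is_classifier = kwargs["is_classifier"]


class ToyClassifier(ToyNetBase):
    is_classifier = True  # override in the child class

    def __init__(self, n_classes, **kwargs):
        # Forward the arguments so the base class records them
        super().__init__(n_classes, **kwargs)
        self.n_classes = n_classes


class ToyRegressor(ToyNetBase):
    def __init__(self, n_outputs, **kwargs):
        super().__init__(n_outputs, **kwargs)


a = ToyClassifier(10)
b = ToyRegressor(1)
c = ToyRegressor(1, is_classifier=True)
print(a.is_classifier, a.args)    # True (10,)
print(b.is_classifier)            # False
print(c.is_classifier, c.kwargs)  # True {'is_classifier': True}
```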

report() str[source]

Creates a short report about the network.

Returns a string that should describe all of the information needed to reproduce the neural network’s initialization, as well as its representation (.__repr__) which is defined by PyTorch’s nn.Module.

This method takes no arguments and can also be called through the report() function.

reset()[source]

Resets the neural network’s parameters. This method can be useful when training the same network multiple times.

forward(inputs: Tensor) Tensor[source]

The model’s forward pass.

get_params(params: str) GeneratorType[source]

Placeholder method for fetching a subset of the net’s parameters.

It should take an argument that specifies which parameters to select and return a GeneratorType object. See the example below for a possible implementation.

Examples

from types import GeneratorType

from torch import nn

from training import NetBase


class ConvModel(NetBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # net 1 (32x32 input -> 30x30 after conv -> 15x15 after pooling)
        self.conv_net_1 = nn.Sequential(
            nn.Conv2d(1, 32, (3, 3)),
            nn.MaxPool2d((2, 2), (2, 2))
        )
        self.lin_net_1 = nn.Sequential(
            nn.Linear(32 * 15 * 15, 369),
            nn.Softmax(dim=1)
        )

        # net 2
        self.conv_net_2 = nn.Sequential(
            nn.Conv2d(1, 64, (3, 3)),
            nn.MaxPool2d((2, 2), (2, 2))
        )
        self.lin_net_2 = nn.Sequential(
            nn.Linear(64 * 15 * 15, 369),
            nn.Softmax(dim=1)
        )

    def get_params(self, params: str = None) -> GeneratorType:
        # This is a generator function, so every branch must yield
        if params == '1':
            # Parameters of the first sub-network only
            yield from self.conv_net_1.parameters()
            yield from self.lin_net_1.parameters()
        elif params == '2':
            # Parameters of the second sub-network only
            yield from self.conv_net_2.parameters()
            yield from self.lin_net_2.parameters()
        else:
            # All parameters; `return self.parameters()` would yield nothing here
            yield from self.parameters()

Parameters:

params – A value that can help specify which parameters to select.

Returns:

A GeneratorType object with the parameters.
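Because a get_params() implementation that contains yield statements is a generator function, every call returns a generator, including the fallback branch. The sketch below uses a hypothetical ToyModel with string placeholders instead of tensors to show the selection logic and why the fallback must use `yield from` rather than `return`.

```python
from types import GeneratorType


class ToyModel:
    """Hypothetical stand-in: parameters are plain strings instead of tensors."""

    def __init__(self):
        self.net_1 = ["conv1.weight", "conv1.bias"]
        self.net_2 = ["conv2.weight", "conv2.bias"]

    def parameters(self):
        yield from self.net_1
        yield from self.net_2

    def get_params(self, params: str = None) -> GeneratorType:
        if params == "1":
            yield from self.net_1
        elif params == "2":
            yield from self.net_2
        else:
            # `return self.parameters()` here would yield nothing, because
            # this function is a generator; `yield from` forwards the items
            yield from self.parameters()


model = ToyModel()
print(isinstance(model.get_params("1"), GeneratorType))  # True
print(list(model.get_params("2")))  # ['conv2.weight', 'conv2.bias']
print(len(list(model.get_params())))  # 4
```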

class training.TrainingClass(delete_state_dicts: int = 1, **report_manager_kwargs)[source]

Class for storing training information.

This class was created to store information about multiple trainings in case the model is fine-tuned (and thus trained multiple times, possibly with different data and training conditions).

train_list

List of TrainingClass.TrainData objects with the data for each individual training.

nb_epochs

Total number of epochs completed (int).

train_error

Training Error for each epoch (Tensor).

valid_error

Validation Error for each epoch (Tensor).

train_accur

Training Accuracy for each epoch (Tensor).

valid_accur

Validation Accuracy for each epoch (Tensor).

train_epoch_idx

The index of the epochs reported in the train_error and train_accur tensors. It is needed because the PyTorch Lightning implementation may produce float epoch values (at least for valid_epoch_idx).

valid_epoch_idx

The index of the epochs reported in the valid_error and valid_accur tensors. It is needed because the PyTorch Lightning implementation may produce float epoch values.

best_epoch

The number of the best epoch (int).

best_state

The best state_dict out of all training epochs (collections.OrderedDict).

delete_state_dicts

Controls whether state_dicts from previous training loops are kept. Its value indicates how many trainings back the deletion happens. For example: 0 means none are deleted; 1 means the most recent training’s state_dicts are deleted (so none are kept); 2 means the second most recent are deleted (the most recent are always kept); 3 means the third most recent are deleted (the two most recent are kept), and so on. This is defined by the following lines of code, from this class’s finish_training() method:

if del_dicts:
    try:
        self.train_list[-del_dicts].state_dict_list = []
    except IndexError:
        pass

Attention

Note that this does not delete the TrainData object for the respective training loop; all of its other information is preserved. Keep in mind that picking a value \(x\) greater than one means that all epochs from the \(x-1\) most recent training loops will be kept, not only the last \(x-1\) epochs from the most recent training loop.
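The rule can be checked with a small simulation. This sketch replaces TrainData with a hypothetical ToyTrainData holding placeholder strings instead of state_dicts, and applies the deletion snippet quoted above after each training, as finish_training() would.

```python
from typing import List


class ToyTrainData:
    """Hypothetical stand-in for TrainData, keeping only state_dict_list."""

    def __init__(self, nb_epochs: int):
        self.state_dict_list: List[str] = [f"epoch_{i}" for i in range(nb_epochs)]


def finish(train_list: List[ToyTrainData], del_dicts: int) -> None:
    # Same rule as the finish_training() snippet
    if del_dicts:
        try:
            train_list[-del_dicts].state_dict_list = []
        except IndexError:
            pass


def kept_after(n_trainings: int, del_dicts: int) -> List[int]:
    """Run n_trainings trainings and count the state_dicts each one keeps."""
    train_list: List[ToyTrainData] = []
    for _ in range(n_trainings):
        train_list.append(ToyTrainData(nb_epochs=3))
        finish(train_list, del_dicts)
    return [len(t.state_dict_list) for t in train_list]


print(kept_after(4, 0))  # [3, 3, 3, 3]
print(kept_after(4, 1))  # [0, 0, 0, 0]
print(kept_after(4, 2))  # [0, 0, 0, 3]
print(kept_after(4, 3))  # [0, 0, 3, 3]
```

With 3, the two most recent trainings keep all three of their epochs each, which is the point made in the note above.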

manager

The ReportManager class that manages the model’s report folder.

add_training(nb_epochs: int, learning_rate: float, loss_fn: nn.Loss, file_name: str, loader_tuple: tuple[DataLoader], comment: str = None)[source]

Adds a TrainingClass.TrainData object to the train_list attribute.

Parameters:
  • nb_epochs – Number of epochs for training.

  • learning_rate – Learning Rate.

  • loss_fn – The Loss Function used.

  • file_name – The file_name picked for the training. Warning: file_name might become deprecated in the future.

  • loader_tuple – (training, validation) DataLoader objects.

finish_training(remove_bool: bool = False, plot_accuracy: bool = False)[source]

Method for adapting the class’s attributes after training.

This adds the last training’s data to the rest of the training data, updating the best epoch and state_dict.

plot(save_bool: bool = False, block: bool = False, var: str = 'error')[source]

Plots error or accuracy graphs for all trainings.

Parameters:
  • save_bool – Whether to save the graph as a file.

  • block – Whether plotting the graph should block further code execution.

  • var – Which graph to plot (“error” for the error or “accur” for the accuracy).

report() str[source]

Creates a report that describes overall training.

Returns:

A string with a short report about the trainings.

class TrainData(nb_epochs: int, learning_rate: float, loss_fn: nn.Loss, file_name: str)[source]

An object for storing information about one particular training loop.

Note

Some of the attributes listed below are defined during initialization as lists, but become Tensors when the finish_training() method is called.

Parameters:
  • nb_epochs – Number of epochs for training.

  • learning_rate – Learning Rate.

  • loss_fn – The Loss Function used.

  • file_name – The file_name picked for the training. Warning: file_name might become deprecated in the future.

train_error

Training Error for each epoch (Tensor).

valid_error

Validation Error for each epoch (Tensor).

train_accur

Training Accuracy for each epoch (Tensor).

valid_accur

Validation Accuracy for each epoch (Tensor).

train_epoch_idx

The index of the epochs reported in the train_error and train_accur tensors. It is needed because the PyTorch Lightning implementation may produce float epoch values (at least for valid_epoch_idx).

valid_epoch_idx

The index of the epochs reported in the valid_error and valid_accur tensors. It is needed because the PyTorch Lightning implementation may produce float epoch values.

state_dict_list

A list with the state_dict of each epoch. It may be reset after training if the model’s associated TrainingClass has a delete_state_dicts attribute with a value different from 0. Check the class’s documentation for more information on how this works.

comment

An optional comment to be added to the training report (through the add_comment() method), for instance when the training is interrupted by the user through a KeyboardInterrupt.

best_epoch

The number of the best epoch (int).

best_state

The best state_dict (collections.OrderedDict).

add_valid_epoch(error: Tensor, accur: Tensor, comment: str = None, epoch_nb: int = None) None[source]

Method used to add validation epoch loss information.

This is called by the PyTorch Lightning custom model’s lightning_objects.LitConvNet.on_validation_epoch_end().

Parameters:
  • error – validation error.

  • accur – validation accuracy.

  • comment – A string with a comment to be added to the training report.

  • epoch_nb – The epoch number, which may be a float if the Lightning Trainer’s val_check_interval flag is a float.
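To see why epoch_nb can be fractional, the sketch below lists the positions (in epochs) at which validation would run when val_check_interval is a float such as 0.25. It is a hypothetical helper, assuming validation runs every val_check_interval fraction of an epoch, that 1 / val_check_interval is an integer, and that positions are counted from the start of training; whether the library numbers validation epochs exactly this way is not confirmed here.

```python
from typing import List


def valid_epoch_positions(nb_epochs: int, val_check_interval: float) -> List[float]:
    """Hypothetical helper: fractional epoch numbers at which validation runs.

    Assumes validation runs every `val_check_interval` fraction of a training
    epoch and that 1 / val_check_interval is a whole number.
    """
    checks_per_epoch = round(1 / val_check_interval)
    return [
        epoch + (k + 1) / checks_per_epoch
        for epoch in range(nb_epochs)
        for k in range(checks_per_epoch)
    ]


print(valid_epoch_positions(2, 0.25))
# [0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
```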

add_epoch(error_tuple: tuple[Tensor, Tensor], accur_tuple: tuple[Tensor, Tensor], state_dict: OrderedDict, comment: str = None) None[source]

Method for adding a training + validation epoch’s information to the class.

This is used by the train_model() function for training PyTorch Module models. For PyTorch Lightning models, the add_train_epoch() and add_valid_epoch() methods are called directly.

Parameters:
  • error_tuple – training and validation errors.

  • accur_tuple – training and validation accuracies.

  • state_dict – The model’s current state_dict.

  • comment – A string with a comment to be added to the training report.

finish_training()[source]

Method for adapting the class’s attributes after training.

This turns the training and validation errors and accuracies (originally lists) into tensors, defines the best epoch of this training and stores its state_dict separately.
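The selection step can be sketched as follows. This assumes the best epoch is the one with the lowest validation error and that epochs are numbered from 1; the exact criterion is not stated in this section, and the real method works on tensors rather than plain lists. finish_training_sketch is a hypothetical name.

```python
from typing import Dict, List, Tuple


def finish_training_sketch(
    valid_error: List[float], state_dict_list: List[Dict[str, float]]
) -> Tuple[int, Dict[str, float]]:
    """Hypothetical sketch: pick the epoch with the lowest validation error.

    Plain lists stand in for tensors; epoch numbers are assumed to start at 1.
    """
    # Index of the smallest validation error
    best_idx = min(range(len(valid_error)), key=valid_error.__getitem__)
    return best_idx + 1, state_dict_list[best_idx]


errors = [0.9, 0.4, 0.5]
states = [{"w": 0.1}, {"w": 0.2}, {"w": 0.3}]
best_epoch, best_state = finish_training_sketch(errors, states)
print(best_epoch, best_state)  # 2 {'w': 0.2}
```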

report() str[source]

Creates a report that describes the training.

Returns:

A string with a short report about the training.

add_comment(comment: str)[source]

Adds a comment to the class’s comment attribute.

Parameters:

comment – The comment we want to add.

class training.ReportManager(dirname: str = None, report_dir: str = '_Reports', complete_path: str = None)[source]

Class for managing a model’s report folder and report files.

This class is instantiated when a NetBase class is initialized. It is stored in the model’s TrainingClass (trainclass attribute), although it can be accessed through the model’s manager property.

The class creates the model’s report folder, where information about its training will be stored (such as its epochs’ state dictionaries, images, reports and the best iteration’s model), and manages the actual text reports.

Parameters:
  • dirname

    The name of the model’s report directory. Defaults to the current date and time (datetime.today()) if nothing is passed.

    If the specified dirname already exists, the current date and time will be added to its end to distinguish the reports. The class will print a note on that.

  • report_dir – The name of the Reports folder (where the individual model report directories are stored). Default is "_Reports".

  • complete_path

    If the user wants to use a base folder different from the current one, they can specify a path and the report manager will enter it. The model’s reports will be stored in complete_path/report_dir/dirname. If nothing is specified, the current directory is used as the complete path.

    Warning

    Picking a different complete_path will result in a change of the current directory (to the specified complete_path) which is not reverted.
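The naming rule for dirname can be sketched as below. The timestamp format (str(datetime.today()) with colons replaced by underscores) is inferred from the sample output in the ReportManager.File examples and is an assumption, as is the helper name pick_dirname.

```python
import os
import tempfile
from datetime import datetime
from typing import Optional


def pick_dirname(base_path: str, report_dir: str = "_Reports",
                 dirname: Optional[str] = None) -> str:
    """Hypothetical sketch of choosing a unique report folder name."""
    # Colons are not allowed in Windows paths, so they become underscores
    stamp = str(datetime.today()).replace(":", "_")
    if dirname is None:
        return stamp
    if os.path.exists(os.path.join(base_path, report_dir, dirname)):
        new_name = f"{dirname} - {stamp}"
        print(f'Report "{dirname}" already exists, creating a new name: {new_name}')
        return new_name
    return dirname


with tempfile.TemporaryDirectory() as base:
    first = pick_dirname(base, dirname="ConvModel")
    # Create the folder so that the second request collides with it
    os.makedirs(os.path.join(base, "_Reports", first))
    second = pick_dirname(base, dirname="ConvModel")

print(first)   # ConvModel
print(second)  # e.g. ConvModel - 2021-11-10 12_43_10.517134
```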

report_dir

The specified report_dir.

base_path

The specified complete_path (if none is specified, then it’s set as the current directory).

dirname

The final dirname picked for the model’s report folder.

path

f".//{report_dir}/{dirname}"

files

A set containing the names of the text report files. It is a set because, during multiple trainings, the same files may be edited several times, which would otherwise produce multiple entries with the same name.

chdir(path: str)[source]

Changes the report directory according to the given final path.

This method may create the necessary directories to reach the path given by the user. If the number of directories that need to be created is greater than two, then the user will be prompted for confirmation.

Parameters:

path – The path to the final model report folder.
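The confirmation rule can be sketched as follows: count how many directory levels are missing and only ask the user when more than two must be created. ensure_path and ask_user are hypothetical names, and this sketch only creates the directories; it does not change the working directory.

```python
import os
import tempfile
from typing import Callable


def ask_user(message: str) -> bool:
    """Default confirmation: ask on the terminal."""
    return input(message).strip().lower() in ("y", "yes")


def ensure_path(path: str, confirm: Callable[[str], bool] = ask_user) -> bool:
    """Hypothetical sketch: create missing directories, asking first if more than two are missing."""
    missing = []
    current = os.path.abspath(path)
    # Walk up until an existing directory is found, collecting missing levels
    while not os.path.exists(current):
        missing.append(current)
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root
            break
        current = parent
    if len(missing) > 2 and not confirm(
        f"{len(missing)} directories will be created. Continue? [y/n] "
    ):
        return False
    os.makedirs(path, exist_ok=True)
    return True


with tempfile.TemporaryDirectory() as base:
    # Two missing levels: created without asking
    shallow = ensure_path(os.path.join(base, "_Reports", "ConvModel"), confirm=lambda m: False)
    # Three missing levels: the (refusing) confirmation is consulted
    deep = ensure_path(os.path.join(base, "a", "b", "c"), confirm=lambda m: False)

print(shallow, deep)  # True False
```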

remove_epochs()[source]

Removes the epoch state-dicts in the Epochs folder after training.

class File(filename, method, base_dir, dirname)[source]

Class for creating text report files when entered.

This class is sneakily entered when ReportManager is called.

>>> manager = ReportManager(dirname='ConvModel')
>>> with manager("Report.txt", 'w') as f:
...     # It looks like we are using manager.__enter__ because
...     # of the 'with' statement. However, ReportManager doesn't
...     # have an .__enter__ method defined.
...     # manager("Report.txt", 'w') is actually
...     # manager.__call__("Report.txt", 'w'), which returns a
...     # ReportManager.File object
...     # which is then entered because of the 'with' statement
...     f.write("Hello There")
...

The File object then creates the desired text report file directly in the model’s report directory. Check the examples below.

Examples

>>> manager = ReportManager(dirname='ConvModel')
>>> with manager("Report.txt", 'w') as f:
...     f.write("Hello There")
...
>>> # The "Report.txt" file was written directly in
>>> # './/_Reports/ConvModel'
>>> # Another example:
>>> model = NetBase(dirname='ConvModel')  # Toy model
Report "ConvModel" already exists, creating a new name: ConvModel - 2021-11-10 12_43_10.517134
>>> with model.manager("Report.txt", 'w') as f:
...     f.write("General Kenobi")
...
>>> # "Report.txt" written directly at
>>> # './/_Reports/ConvModel - 2021-11-10 12_43_10.517134'
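The pattern described for ReportManager.File (calling the manager returns a File object, which the with statement then enters) can be sketched with a simplified stand-in. ToyReportManager is hypothetical; the real File class also receives a dirname argument, which is omitted here.

```python
import os
import tempfile


class ToyReportManager:
    """Hypothetical stand-in showing the __call__ -> File -> with pattern."""

    def __init__(self, path: str):
        self.path = path
        self.files = set()  # report file names, without duplicates

    class File:
        def __init__(self, filename: str, method: str, base_dir: str):
            self.full_path = os.path.join(base_dir, filename)
            self.method = method
            self._handle = None

        def __enter__(self):
            # Open the file inside the report directory
            self._handle = open(self.full_path, self.method)
            return self._handle

        def __exit__(self, exc_type, exc, tb):
            self._handle.close()
            return False  # do not swallow exceptions

    def __call__(self, filename: str, method: str = "w") -> "ToyReportManager.File":
        # Calling the manager only builds a File; `with` is what enters it
        self.files.add(filename)
        return ToyReportManager.File(filename, method, self.path)


with tempfile.TemporaryDirectory() as base:
    manager = ToyReportManager(base)
    with manager("Report.txt", "w") as f:
        f.write("Hello There")
    with open(os.path.join(base, "Report.txt")) as f:
        content = f.read()

print(content)        # Hello There
print(manager.files)  # {'Report.txt'}
```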

test_path()[source]

Tests if the model’s report directory can be accessed.

If it can’t, it will prompt the user with a few questions in order to find out whether or not the problem can be fixed.