train

UNet.train(*args, **kwargs)

Train the network.

This trains the network for the given number of epochs using the provided training and validation data.

If desired, training can be augmented with adversarial training. In this case, the network is additionally trained on a batch of adversarial examples in each training step.
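The adversarial batch could be produced roughly as follows. This is a minimal sketch that assumes the fast gradient sign method (FGSM), with `eps_adv` scaling the sign of the input gradient. This section does not say which attack `UNet.train` actually uses. The helper `fgsm_batch`, the MSE loss and the `nn.Conv2d` stand-in for the UNet are all hypothetical.

```python
import torch
import torch.nn as nn


def fgsm_batch(model: nn.Module, loss_fn, x: torch.Tensor, y: torch.Tensor,
               eps_adv: float) -> torch.Tensor:
    """Return an adversarial copy of x (FGSM; assumed, not the documented method)."""
    x_adv = x.detach().clone().requires_grad_(True)
    loss = loss_fn(model(x_adv), y)
    # Gradient of the loss with respect to the inputs only; parameter .grad is untouched.
    (grad,) = torch.autograd.grad(loss, x_adv)
    # Move every input element by eps_adv in the direction that increases the loss.
    return (x_adv + eps_adv * grad.sign()).detach()


# Toy usage: a single conv layer stands in for the UNet (hypothetical).
torch.manual_seed(0)
model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
x = torch.randn(4, 1, 8, 8)
y = torch.randn(4, 1, 8, 8)
x_adv = fgsm_batch(model, nn.MSELoss(), x, y, eps_adv=0.05)
```

Each element of `x_adv` differs from `x` by at most `eps_adv`. This is why a larger `eps_adv` yields stronger, but less realistic, perturbations.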

Parameters:
  • training_data – PyTorch DataLoader providing the training data

  • validation_data – PyTorch DataLoader providing the validation data

  • n_epochs – the number of epochs to train the network for

  • adversarial_training – whether or not to use adversarial training

  • eps_adv – the scaling factor to use for adversarial training
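The parameters above map onto a conventional epoch loop. Below is a hypothetical stand-in, `train_sketch`, that mirrors them. It is a sketch under stated assumptions, not the actual `UNet.train`. The optimizer (Adam), learning rate `lr`, MSE loss and FGSM-style perturbation are all assumptions, and a single `nn.Conv2d` layer stands in for the UNet.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(model, training_data, validation_data, n_epochs,
                 adversarial_training=False, eps_adv=0.01, lr=1e-3):
    """Hypothetical stand-in for UNet.train; returns mean validation loss per epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # assumed optimizer
    loss_fn = nn.MSELoss()  # assumed loss
    history = []
    for _ in range(n_epochs):
        model.train()
        for x, y in training_data:
            batches = [x]
            if adversarial_training:
                # Adversarial copy of the batch via FGSM (assumed attack).
                x_req = x.detach().clone().requires_grad_(True)
                (grad,) = torch.autograd.grad(loss_fn(model(x_req), y), x_req)
                batches.append((x + eps_adv * grad.sign()).detach())
            optimizer.zero_grad()
            # Clean loss plus, optionally, adversarial loss in the same step.
            loss = sum(loss_fn(model(xb), y) for xb in batches)
            loss.backward()
            optimizer.step()
        # Evaluate on the validation data without tracking gradients.
        model.eval()
        with torch.no_grad():
            val = [loss_fn(model(x), y).item() for x, y in validation_data]
        history.append(sum(val) / len(val))
    return history


# Toy usage with random tensors (hypothetical data and model).
torch.manual_seed(0)
data = TensorDataset(torch.randn(32, 1, 8, 8), torch.randn(32, 1, 8, 8))
model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
history = train_sketch(model, DataLoader(data, batch_size=8, shuffle=True),
                       DataLoader(data, batch_size=8), n_epochs=3,
                       adversarial_training=True, eps_adv=0.05)
```

Summing the clean and adversarial losses before a single optimizer step is one reading of "additionally trained with an adversarial batch in each step". Alternating separate steps would be another.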