Generic Training

Deep.Net contains a powerful, generic function to train your model. Together with the dataset handler it provides the following functionality:

• initialization of the model's parameters
• mini-batch training
• logging of losses on the training, validation and test sets
• automatic scheduling of the learning rate
• termination of training when
  • a desired validation loss is reached
  • a set number of iterations has been performed
  • there is no loss improvement on the validation set within a set number of iterations
• checkpointing, which allows the training state to be saved to disk and training to be resumed later (useful when running on unreliable hardware or on a compute cluster that pauses jobs or moves them between nodes)

Example model

To demonstrate its use we return to our two-layer neural network model for classifying MNIST digits.

    open SymTensor
    open SymTensor.Compiler.Cuda
    open Models
    open Datasets
    open Optimizers

We load the MNIST dataset with the Mnist.load function, holding out a fraction of 0.1 of the training samples as the validation set.

    let mnist = Mnist.load (__SOURCE_DIRECTORY__ + "../../../Data/MNIST") 0.1
                |> TrnValTst.ToCuda

Next, we define and instantiate a model using the MLP (multi-layer perceptron, i.e. multi-layer neural network) component.

    let mb = ModelBuilder<single> "NeuralNetModel"

    // define symbolic sizes
    let nBatch  = mb.Size "nBatch"
    let nInput  = mb.Size "nInput"
    let nClass  = mb.Size "nClass"
    let nHidden = mb.Size "nHidden"

    // define model parameters
    let mlp =
        MLP.pars (mb.Module "MLP")
            { Layers = [{NInput=nInput; NOutput=nHidden; TransferFunc=NeuralLayer.Tanh}
                        {NInput=nHidden; NOutput=nClass; TransferFunc=NeuralLayer.SoftMax}]
              LossMeasure = LossLayer.CrossEntropy }

    // define variables
    let input  : ExprT<single> = mb.Var "Input"  [nBatch; nInput]
    let target : ExprT<single> = mb.Var "Target" [nBatch; nClass]

    // instantiate model
    mb.SetSize nInput mnist.Trn.[0].Img.Shape.[0]
    mb.SetSize nClass mnist.Trn.[0].Lbl.Shape.[0]
    mb.SetSize nHidden 100
    let mi = mb.Instantiate DevCuda

    // loss expression
    let loss = MLP.loss mlp input.T target.T

Note that the input and target matrices must be transposed, since the neural network model expects each sample to be a column in the matrix while the dataset provides a matrix where each row is a sample.

We instantiate the Adam optimizer to minimize the loss and use its default configuration.

    // optimizer
    let opt = Adam (loss, mi.ParameterVector, DevCuda)
    let optCfg = opt.DefaultCfg

In the previous example we wrote a simple optimization loop by hand. Here, instead, we will employ the generic training function provided by Deep.Net.

Defining a Trainable

The generic training function works on any object that implements the Train.ITrainable<'Smpl, 'T> interface where 'Smpl is a sample record type (see dataset handling) and 'T is the data type of the model parameters, e.g. single. The easiest way to create an ITrainable from a symbolic loss expression is to use the Train.trainableFromLossExpr function. This function has the signature

    val trainableFromLossExpr : modelInstance:ModelInstance<'T> ->
                                loss:ExprT<'T> ->
                                varEnvBuilder:('Smpl -> VarEnvT) ->
                                optimizer:IOptimizer<'T,'OptCfg,'OptState> ->
                                optCfg:'OptCfg ->
                                ITrainable<'Smpl,'T>

The arguments have the following meaning.

• modelInstance is the model instance containing the parameters of the model to be trained.
• loss is the loss expression to be minimized.
• varEnvBuilder is a user-provided function that takes an instance of the user-provided type 'Smpl and returns a variable environment for evaluating the loss expression on that sample (or mini-batch of samples). The example below shows how to build a variable environment from a sample.
• optimizer is an instance of an optimizer. All optimizers in Deep.Net implement the IOptimizer interface.
• optCfg is the optimizer configuration to use. The learning rate in the specified optimizer configuration will be overwritten.

Let us build a trainable for our model. First, we need to define a function that creates a variable environment from a sample.

    let smplVarEnv (smpl: MnistT) =
        VarEnv.empty
        |> VarEnv.add input smpl.Img
        |> VarEnv.add target smpl.Lbl

The value of the symbolic variable input is set to the image of the MNIST sample and the symbolic variable target is set to the label in one-hot encoding.
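For readers unfamiliar with the term: one-hot encoding represents class k out of n classes as a vector of length n that is 1 at index k and 0 everywhere else. The following is a minimal, library-independent sketch of the idea. The helper oneHot is hypothetical and not part of Deep.Net; in this tutorial the labels in smpl.Lbl are already provided in this form.

```fsharp
/// Hypothetical helper (not a Deep.Net function): encodes the class index
/// `label` as a vector of length `nClass` with a single 1 at that index.
let oneHot (nClass: int) (label: int) : single[] =
    Array.init nClass (fun i -> if i = label then 1.0f else 0.0f)

// Example: the digit 3 among the ten MNIST classes.
let three = oneHot 10 3
printfn "%A" three
```

The encoded vector sums to 1, which is what a SoftMax output layer combined with a cross-entropy loss expects as target.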

We are now ready to construct the trainable.

    let trainable =
        Train.trainableFromLossExpr mi loss smplVarEnv opt optCfg

Training configuration

Next, we need to specify the training configuration using the Train.Cfg record type. For illustration purposes we write down the whole record instance; in practice you would copy Train.defaultCfg and change fields as necessary.

    let trainCfg : Train.Cfg = {
        Seed               = 100
        BatchSize          = 10000
        LossRecordInterval = 10
        Termination        = Train.ItersWithoutImprovement 100
        MinImprovement     = 1e-7
        TargetLoss         = None
        MinIters           = Some 100
        MaxIters           = None
        LearningRates      = [1e-3; 1e-4; 1e-5]
        CheckpointDir      = None
        DiscardCheckpoint  = false
    }
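In practice, as noted above, the configuration would be derived from Train.defaultCfg using F#'s copy-and-update record syntax. The sketch below illustrates that pattern. So that it runs without Deep.Net, it uses local stand-in definitions that mirror the field names of the record above; the field types and all default values are assumptions, and the actual contents of Train.defaultCfg may differ.

```fsharp
// Local stand-ins mirroring the record above. These are NOT the library's
// definitions; with Deep.Net you would use Train.Cfg and Train.defaultCfg.
type Termination =
    | ItersWithoutImprovement of int
    | IterGain of float
    | Forever

type Cfg =
    { Seed: int
      BatchSize: int
      LossRecordInterval: int
      Termination: Termination
      MinImprovement: float
      TargetLoss: float option
      MinIters: int option
      MaxIters: int option
      LearningRates: float list
      CheckpointDir: string option
      DiscardCheckpoint: bool }

// Hypothetical default values, chosen only for this illustration.
let defaultCfg =
    { Seed = 1
      BatchSize = 10000
      LossRecordInterval = 10
      Termination = IterGain 2.0
      MinImprovement = 1e-7
      TargetLoss = None
      MinIters = None
      MaxIters = None
      LearningRates = [1e-3]
      CheckpointDir = None
      DiscardCheckpoint = false }

// Copy the defaults and override only the fields that should differ.
// With Deep.Net this would read: { Train.defaultCfg with ... }
let customCfg =
    { defaultCfg with
        Seed = 100
        Termination = ItersWithoutImprovement 100
        MinIters = Some 100
        LearningRates = [1e-3; 1e-4; 1e-5] }
```

All fields not mentioned after the with keyword keep the values of the record being copied, so a configuration written this way stays short and does not need to be touched when new fields are added to the record.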

The meaning of the fields is as follows.

• Seed is the random seed for model parameter initialization.
• BatchSize is the size of mini-batches used for training and evaluating the losses.
• LossRecordInterval is the number of iterations to perform between evaluating the loss on the validation and test sets.
• Termination is the termination criterion and can have the following values:
  • Train.ItersWithoutImprovement cnt to stop training after cnt iterations without improvement.
  • Train.IterGain gain to train for $$\mathrm{gain} \cdot \mathrm{bestIter}$$ iterations, where $$\mathrm{bestIter}$$ is the iteration with the best validation loss. Usually one would use $$\mathrm{gain} \approx 2.0$$.
  • Train.Forever disables the termination criterion.
• MinImprovement is the minimum loss change to count as improvement and should be a small number.
• TargetLoss can be used to specify a target validation loss that stops training when achieved. Use Some loss or None.
• MinIters is the minimum number of training iterations to perform, given as Some iters, or None for no minimum.
• MaxIters is a hard limit on the number of training iterations, given as Some iters, or None for no limit.
• LearningRates is a list of learning rates to use. Training starts with the first element and moves to the next one when the termination criterion (specified by the field Termination) is triggered.
• CheckpointDir optionally specifies a checkpoint directory in the form Some dir; None disables checkpointing (see the Checkpointing section for details).
• DiscardCheckpoint, if true, prevents an existing checkpoint from being loaded.
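To make the termination criteria more concrete, here is a minimal sketch of how such a check could be written. It is not Deep.Net's implementation: the union type is a local stand-in that reuses the case names from the list above, the function shouldStop is hypothetical, and the exact comparisons (strict versus non-strict) are assumptions. The strict comparison for ItersWithoutImprovement was chosen to agree with the sample log in this document, where a limit of 100 leads to a stop after 110 iterations without improvement.

```fsharp
// Local stand-in for the termination criteria (not the library's type).
type Termination =
    | ItersWithoutImprovement of int
    | IterGain of float
    | Forever

/// Hypothetical check: should training on the current learning rate end?
/// `iter` is the current iteration, `bestIter` the iteration at which the
/// best validation loss so far was observed.
let shouldStop (termination: Termination) (iter: int) (bestIter: int) : bool =
    match termination with
    // Stop once more than `cnt` iterations have passed since the best one.
    | ItersWithoutImprovement cnt -> iter - bestIter > cnt
    // Train for gain * bestIter iterations in total.
    | IterGain gain -> float iter > gain * float bestIter
    // Never stop because of this criterion.
    | Forever -> false

// With the best loss at iteration 400 and losses evaluated every 10 iterations:
let stopAt500 = shouldStop (ItersWithoutImprovement 100) 500 400   // false
let stopAt510 = shouldStop (ItersWithoutImprovement 100) 510 400   // true
```

Note that with IterGain the allowed number of iterations grows whenever a later iteration becomes the best one, so a model that keeps improving is given proportionally more time.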

Performing the training

Now training can be performed by calling the Train.train function. It takes three arguments: a trainable, the dataset to use and the training configuration. The dataset was already loaded above.

    let result = Train.train trainable mnist trainCfg

This will produce output similar to

    Initializing model parameters for training
    Training with Dataset (54000 training, 6000 validation, 10000 test Datasets.MnistTs)
    Using learning rate 0.001
       10: trn= 0.5739 val= 0.4652 tst= 0.5173
       20: trn= 0.3686 val= 0.3210 tst= 0.3583
    ...
      380: trn= 0.0155 val= 0.1083 tst= 0.1114
      390: trn= 0.0146 val= 0.1082 tst= 0.1113
      400: trn= 0.0137 val= 0.1082 tst= 0.1112
      410: trn= 0.0129 val= 0.1082 tst= 0.1112
      420: trn= 0.0121 val= 0.1083 tst= 0.1113
      430: trn= 0.0114 val= 0.1083 tst= 0.1113
      440: trn= 0.0108 val= 0.1084 tst= 0.1114
      450: trn= 0.0102 val= 0.1085 tst= 0.1115
      460: trn= 0.0096 val= 0.1086 tst= 0.1116
      470: trn= 0.0091 val= 0.1087 tst= 0.1118
      480: trn= 0.0086 val= 0.1089 tst= 0.1120
      490: trn= 0.0081 val= 0.1090 tst= 0.1121
      500: trn= 0.0077 val= 0.1092 tst= 0.1123
      510: trn= 0.0073 val= 0.1093 tst= 0.1125
    Trained for 110 iterations without improvement
    Using learning rate 0.0001
      410: trn= 0.0135 val= 0.1082 tst= 0.1112
      420: trn= 0.0134 val= 0.1082 tst= 0.1112
    ...
      510: trn= 0.0123 val= 0.1083 tst= 0.1113
    Trained for 110 iterations without improvement
    Using learning rate 1e-05
      410: trn= 0.0136 val= 0.1082 tst= 0.1112
      420: trn= 0.0136 val= 0.1082 tst= 0.1112
    ...
      510: trn= 0.0134 val= 0.1082 tst= 0.1112
    Trained for 110 iterations without improvement
    Training completed after 400 iterations in 00:30:07.6179551 because NoImprovement

While training is running, you can press the q key to stop it immediately or the d key to switch to the next learning rate specified in the configuration.

During training, the parameters that produce the best validation loss are saved each time the losses are evaluated (as set by the LossRecordInterval field in the training configuration). When the validation loss does not improve for the set number of iterations (field Termination in the training configuration), the best parameters are restored and the next learning rate (field LearningRates) from the configuration is used. This explains why the iteration count in the log jumps back to the best iteration (here from 510 to 400) each time the loss stops improving.
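The interplay of best-parameter tracking, termination and the learning rate list can be illustrated with a small simulation. This is a sketch under stated assumptions, not Deep.Net's training loop: the validation loss curve toyValLoss is made up (minimum at iteration 400, independent of the learning rate), no actual training takes place, and the function simulate as well as its parameter names are hypothetical.

```fsharp
/// Made-up validation loss: best at iteration 400, worse on either side.
let toyValLoss (iter: int) : float =
    let d = float (iter - 400)
    0.1 + 1e-6 * d * d

/// Walks through the learning rates. For each rate it records the rate, the
/// iteration at which training with that rate stopped, and the best iteration.
let simulate (valLoss: int -> float) (rates: float list)
             (patience: int) (interval: int) (minImprovement: float) =
    let results = ResizeArray<float * int * int>()
    let mutable bestIter = 0
    let mutable bestLoss = infinity
    for rate in rates do
        // "Restoring the best parameters" means continuing from the best iteration.
        let mutable iter = bestIter
        let mutable stop = false
        while not stop do
            // Train until the next loss evaluation.
            iter <- iter + interval
            let loss = valLoss iter
            if loss <= bestLoss - minImprovement then
                // Sufficient improvement: remember this iteration as the best.
                bestLoss <- loss
                bestIter <- iter
            elif iter - bestIter > patience then
                // No improvement for more than `patience` iterations.
                stop <- true
        results.Add((rate, iter, bestIter))
    List.ofSeq results

// Settings as in the training configuration of this tutorial.
let schedule = simulate toyValLoss [1e-3; 1e-4; 1e-5] 100 10 1e-7
for (rate, stoppedAt, best) in schedule do
    printfn "rate %g: stopped at iteration %d, best iteration %d" rate stoppedAt best
```

With these settings every learning rate stops at iteration 510 while the best iteration stays at 400, which is the same pattern as in the log: each new learning rate starts again at iteration 410.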

The best validation loss is achieved around iteration 400; after that the model starts to overfit. Decreasing the learning rate does not help in this case, so training is terminated after exhausting the list of learning rates.

Training result and log

The return value of Train.train is a record of type TrainingResult that contains the training results and the training log.

    printfn "Termination reason is %A after %A" result.TerminationReason result.Duration
    printfn "The best iteration is \n%A" result.Best
    printfn "The training log consists of %d entries." (List.length result.History)

This prints

    Termination reason is NoImprovement after 00:29:28.1679299
    The best iteration is
    {Iter = 400;
     TrnLoss = 0.01370835087;
     ValLoss = 0.1082176194;
     TstLoss = 0.1112449616;
     LearningRate = 0.001;}
    The training log consists of 51 entries.

It is possible to save the training result as a JSON file by calling result.Save. This is useful when you use software or scripts to gather and analyze the results of multiple experiments.

Checkpointing

Checkpointing allows the training process to be interrupted and resumed later. To enable checkpoint support, set the CheckpointDir field of the configuration record to Some dir, where dir is a suitable directory. This directory has to be unique for each process.

When checkpoint support is enabled, the training function traps the CTRL+C and CTRL+BREAK signals. When such a signal is received, the training state (including the model parameters) is stored in the specified directory and the process is terminated with exit code 10. In this case, the training function does not return to the user code.

When the program is executed again and the training function is called, it checks for a valid checkpoint. If one is found, it is loaded and training resumes where it was interrupted.

To discard an existing checkpoint (for example, if training or model parameters were changed), set DiscardCheckpoint to true. This will delete any existing checkpoints from disk and restart training from the beginning.
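The resume-or-restart decision described in this section can be sketched as follows. This is only an illustration of the general pattern, not how Deep.Net stores its checkpoints: the State record, the file name checkpoint.txt, the text format and the functions save and tryLoad are all made up for this example, and a real training state would also contain the model parameters.

```fsharp
open System
open System.Globalization
open System.IO

/// Hypothetical, heavily reduced training state.
type State = { Iter: int; BestValLoss: float }

let checkpointPath (dir: string) = Path.Combine(dir, "checkpoint.txt")

/// Stores the state as "iter loss" in a text file (made-up format).
let save (dir: string) (state: State) =
    Directory.CreateDirectory(dir) |> ignore
    let loss = state.BestValLoss.ToString("R", CultureInfo.InvariantCulture)
    File.WriteAllText(checkpointPath dir, sprintf "%d %s" state.Iter loss)

/// Returns the stored state, or None if there is no usable checkpoint.
/// With discard = true an existing checkpoint is deleted first.
let tryLoad (dir: string) (discard: bool) : State option =
    let path = checkpointPath dir
    if discard && File.Exists(path) then File.Delete(path)
    if File.Exists(path) then
        match File.ReadAllText(path).Split(' ') with
        | [| iter; loss |] ->
            Some { Iter = Int32.Parse(iter, CultureInfo.InvariantCulture)
                   BestValLoss = Double.Parse(loss, CultureInfo.InvariantCulture) }
        | _ -> None
    else
        None

// Usage: interrupt (save), resume (load), then discard.
let dir = Path.Combine(Path.GetTempPath(), "checkpoint-sketch")
save dir { Iter = 400; BestValLoss = 0.1082 }
let resumed = tryLoad dir false   // checkpoint found: training would resume
let fresh = tryLoad dir true      // checkpoint deleted: training would restart
```

Because the state lives in a single location per directory, two processes sharing the directory would overwrite each other's checkpoints, which is why the directory has to be unique for each process.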

Summary

With the generic training function you can train any model that has a loss expression. The main effort is to write a small wrapper function that maps a training sample to a variable environment. Various termination criteria, common in machine learning, are implemented.
