lr_grid_search
- lr_grid_search(data, model, loss_fun, validation_loss_fun, lr_grid=(0.1, 0.01, 0.005, 0.001), num_epochs=10, runs=1, verbose=True)
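Grid search over learning rate values: trains the model with each value in lr_grid and selects the best learning rate by validation loss.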
- Parameters:
data – input data
model – model to train
loss_fun – training loss function; takes (model, data) as input
validation_loss_fun – function that computes the validation loss; takes (model, data) as input
lr_grid – learning rate values to try
num_epochs – number of epochs for training
runs – number of training runs to average over when selecting the best learning rate
verbose – if True, output training progress
- Returns:
the best learning rate, and the validation losses for all runs
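To illustrate what a search with this signature does, here is a minimal, hypothetical sketch. It is not the library's implementation. The following are all assumptions: the name lr_grid_search_sketch, a model represented as a plain list of float parameters, full-batch gradient descent with finite-difference gradients standing in for the real (undocumented) training loop, a data dict with "train"/"val" keys, and a return value that maps each learning rate to its list of per-run validation losses.

```python
import copy
import statistics


def _finite_diff_grad(loss_fun, params, data, eps=1e-6):
    """Central finite-difference gradient of loss_fun w.r.t. a list of floats."""
    grad = []
    for i in range(len(params)):
        plus, minus = list(params), list(params)
        plus[i] += eps
        minus[i] -= eps
        grad.append((loss_fun(plus, data) - loss_fun(minus, data)) / (2 * eps))
    return grad


def lr_grid_search_sketch(data, model, loss_fun, validation_loss_fun,
                          lr_grid=(0.1, 0.01, 0.005, 0.001), num_epochs=10,
                          runs=1, verbose=True):
    """Hypothetical stand-in: `model` is assumed to be a list of float parameters."""
    val_losses = {}  # lr -> list of validation losses, one entry per run
    for lr in lr_grid:
        val_losses[lr] = []
        for run in range(runs):
            params = copy.deepcopy(model)  # each run starts from the initial model
            for _ in range(num_epochs):
                # Assumption: one full-batch gradient-descent step per epoch.
                grad = _finite_diff_grad(loss_fun, params, data)
                params = [p - lr * g for p, g in zip(params, grad)]
            val_loss = validation_loss_fun(params, data)
            val_losses[lr].append(val_loss)
            if verbose:
                print(f"lr={lr} run={run}: validation loss {val_loss:.4g}")
    # Select the learning rate with the lowest mean validation loss over runs.
    best_lr = min(lr_grid, key=lambda lr: statistics.mean(val_losses[lr]))
    return best_lr, val_losses


# Usage: fit y = w * x to data generated with w = 2.
def mse(params, pairs):
    w = params[0]
    return sum((w * x - y) ** 2 for x, y in pairs) / len(pairs)


data = {"train": [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)], "val": [(4.0, 8.0)]}
best_lr, losses = lr_grid_search_sketch(
    data, [0.0],
    loss_fun=lambda m, d: mse(m, d["train"]),
    validation_loss_fun=lambda m, d: mse(m, d["val"]),
    runs=2, verbose=False,
)
print(best_lr)
```

On this toy problem, 10 epochs at lr=0.1 bring w very close to 2, while the smaller rates are still far from converged. The search therefore returns 0.1. Because the sketch's training is deterministic, every run gives the same loss. In real training, runs usually differ through random initialization or data shuffling, and that variation is what averaging over `runs` is meant to smooth out.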