local2global_embedding
1.0

lr_grid_search

local2global_embedding.embedding.train.lr_grid_search(data, model, loss_fun, validation_loss_fun, lr_grid=(0.1, 0.01, 0.005, 0.001), num_epochs=10, runs=1, verbose=True)

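Grid search over learning rate values: train the model with each candidate learning rate (repeated for the given number of runs) and select the learning rate with the lowest validation loss averaged over runs.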

Parameters:
  • data – input data

  • model – model to train

  • loss_fun – training loss function; takes (model, data) as input

  • validation_loss_fun – function computing the validation loss; takes (model, data) as input

  • lr_grid – learning rate values to try

  • num_epochs – number of training epochs per run

  • runs – number of training runs per learning rate; the validation loss is averaged over these runs to select the best learning rate

  • verbose – if True, output training progress

Returns:

The best learning rate and the validation losses for all runs.
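The selection procedure described above can be sketched in plain Python. This is a minimal, self-contained illustration, not the library's implementation: the toy one-parameter "model", the hypothetical make_model factory (assumed here to give every run a freshly initialised model) and the hypothetical train_step (plain gradient descent on the loss (w - 3)^2) stand in for the real model, loss_fun and optimiser, whose details are not documented on this page. Treating "best" as the lowest mean validation loss over runs is also an assumption.

```python
from statistics import mean
from typing import Callable, Dict, List, Sequence, Tuple

Model = Dict[str, float]  # toy "model": a single trainable parameter w


def grid_search_lr(
    make_model: Callable[[], Model],
    train_step: Callable[[Model, float], None],
    validation_loss_fun: Callable[[Model], float],
    lr_grid: Sequence[float] = (0.1, 0.01, 0.005, 0.001),
    num_epochs: int = 10,
    runs: int = 1,
) -> Tuple[float, List[List[float]]]:
    """Return the learning rate with the lowest mean validation loss and
    the validation losses of every run, grouped by learning rate."""
    all_losses: List[List[float]] = []
    for lr in lr_grid:
        run_losses = []
        for _ in range(runs):
            model = make_model()  # hypothetical: fresh model for each run
            for _ in range(num_epochs):
                train_step(model, lr)
            run_losses.append(validation_loss_fun(model))
        all_losses.append(run_losses)
    # Average each learning rate's losses over runs and keep the smallest.
    best = min(range(len(lr_grid)), key=lambda i: mean(all_losses[i]))
    return lr_grid[best], all_losses


# Toy problem (hypothetical): minimise (w - 3)^2 with plain gradient descent.
def make_model() -> Model:
    return {"w": 0.0}


def train_step(model: Model, lr: float) -> None:
    grad = 2.0 * (model["w"] - 3.0)  # derivative of (w - 3)^2
    model["w"] -= lr * grad


def validation_loss(model: Model) -> float:
    return (model["w"] - 3.0) ** 2


best_lr, losses = grid_search_lr(make_model, train_step, validation_loss, runs=2)
print(best_lr)  # 0.1
```

With num_epochs=10, the largest learning rate converges fastest on this convex toy problem, so 0.1 is selected. On a real model a learning rate that is too large can instead make training diverge, which is the kind of failure a grid search over several values helps avoid.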


© Copyright 2021, Lucas G. S. Jeub.
