
flair_trainers() gives R users access to Flair's ModelTrainer Python class through the reticulate package. The ModelTrainer class offers the following main methods:

  • train: Trains the model that was passed to the ModelTrainer constructor together with the corpus (data split into training, development, and test sets). Parameters include an output directory for the model and logs, plus options that control the training process (e.g., learning rate, mini-batch size, maximum number of epochs).

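  • find_learning_rate: Uses the "learning rate finder" method, a short trial run that increases the learning rate step by step across a range, to help choose a learning rate for training. Parameters typically include an output directory, the mini-batch size, and the start and end of the learning-rate range to explore.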

  • final_test: After training, evaluates the model on the corpus's test set and reports the results.

  • save_checkpoint: Saves the current training state (including model parameters and training configurations) to resume later if interrupted.

  • load_checkpoint: Loads a previously saved checkpoint to resume training.

  • log_line: Utility method for logging. Writes a horizontal separator line to both the console and the log file.

  • log_section: Utility method for logging. Writes a section break to both the console and the log file.
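The learning rate finder mentioned under find_learning_rate is commonly implemented as a learning-rate range test: the learning rate grows exponentially from a small start value to a large end value, one mini-batch per step, and the loss is recorded at each step. The base-R sketch below illustrates only that general technique; it is not Flair's implementation, and the helpers lr_range_schedule() and suggest_lr() and their arguments are hypothetical names chosen for this example. The "steepest descent" rule in suggest_lr() is one common heuristic, assumed here for illustration.

```r
# Hypothetical helpers illustrating a learning-rate range test in base R.
# They are NOT part of flaiR or Flair.

# Learning rates swept exponentially from start_lr to end_lr.
lr_range_schedule <- function(start_lr = 1e-7, end_lr = 10, iterations = 100) {
  stopifnot(start_lr > 0, end_lr > start_lr, iterations >= 2)
  # Step i (0-based) uses start_lr * (end_lr / start_lr)^(i / (iterations - 1)),
  # so the first step uses start_lr and the last step uses end_lr.
  steps <- seq_len(iterations) - 1
  start_lr * (end_lr / start_lr)^(steps / (iterations - 1))
}

# Common heuristic (assumed): return the learning rate at the start of the
# interval where the loss falls fastest per decade of learning rate.
suggest_lr <- function(lrs, losses) {
  stopifnot(length(lrs) == length(losses), length(lrs) >= 2)
  slopes <- diff(losses) / diff(log10(lrs))
  lrs[which.min(slopes)]
}

# Usage with made-up losses (illustrative values, not measurements).
lrs <- lr_range_schedule(start_lr = 1e-5, end_lr = 1, iterations = 6)
losses <- c(2.3, 2.2, 1.6, 0.9, 1.2, 5.0)
suggest_lr(lrs, losses)  # returns approximately 0.001
```

In a real run, each entry of losses would come from one training step at the corresponding learning rate; the loss usually drops, flattens, and then diverges once the rate is too large.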

Usage

flair_trainers()

Value

A Python module object (flair.trainers) that gives R code access to Flair's trainers.

References

Flair GitHub

Python equivalent:

from flair.trainers import ModelTrainer

Examples

if (FALSE) { # \dontrun{
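# Import the flair.trainers Python module through reticulate
trainers <- flair_trainers()
# Access the ModelTrainer class from the module
model_trainer <- trainers$ModelTrainer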
} # }