
The Overview

Flair NLP is an open-source library for Natural Language Processing (NLP) developed by Zalando Research. Known for its state-of-the-art solutions, such as contextual string embeddings for NLP tasks like Named Entity Recognition (NER), Part-of-Speech tagging (POS), and more, it has garnered attention in the NLP community for its ease of use and powerful functionalities.

In addition, Flair NLP offers pre-trained models for various languages and tasks, and is compatible with fine-tuned transformers hosted on Hugging Face.

 

Sentence and Token

Sentence and Token are fundamental classes.

Sentence

A Sentence in Flair is an object that contains a sequence of Token objects, and it can be annotated with labels, such as named entities, part-of-speech tags, and more. It can also store embeddings for the sentence as a whole, as well as different kinds of linguistic annotations.

Here’s a simple example of how you create a Sentence:

# Creating a Sentence object
library(flaiR)
string <- "What I see in UCD today, what I have seen of UCD in its impact on my own life and the life of Ireland."
Sentence <- flair_data()$Sentence
sentence <- Sentence(string)

Sentence[26] means that there are a total of 26 tokens in the sentence.

print(sentence)
#> Sentence[26]: "What I see in UCD today, what I have seen of UCD in its impact on my own life and the life of Ireland."

Token

When you use Flair to handle text data, Sentence and Token objects often play central roles in many use cases. When you create a Sentence object, it usually decomposes the raw text into multiple Token objects automatically. In other words, the Sentence object handles the tokenization work, so you usually don’t need to create Token objects manually.

Unlike R, which indexes from 1, Python indexes from 0. A for loop over the tokens therefore runs over seq_along(sentence) - 1. The output should be something like:

# The Sentence object has automatically created and contains multiple Token objects
# We can iterate through the Sentence object to view each Token.

for (i in seq_along(sentence)-1) {
  print(sentence[[i]])
}
#> Token[0]: "What"
#> Token[1]: "I"
#> Token[2]: "see"
#> Token[3]: "in"
#> Token[4]: "UCD"
#> Token[5]: "today"
#> Token[6]: ","
#> Token[7]: "what"
#> Token[8]: "I"
#> Token[9]: "have"
#> Token[10]: "seen"
#> Token[11]: "of"
#> Token[12]: "UCD"
#> Token[13]: "in"
#> Token[14]: "its"
#> Token[15]: "impact"
#> Token[16]: "on"
#> Token[17]: "my"
#> Token[18]: "own"
#> Token[19]: "life"
#> Token[20]: "and"
#> Token[21]: "the"
#> Token[22]: "life"
#> Token[23]: "of"
#> Token[24]: "Ireland"
#> Token[25]: "."

Or you can use the $tokens attribute directly to print all tokens.

print(sentence$tokens)
#> [[1]]
#> Token[0]: "What"
#> 
#> [[2]]
#> Token[1]: "I"
#> 
#> [[3]]
#> Token[2]: "see"
#> 
#> [[4]]
#> Token[3]: "in"
#> 
#> [[5]]
#> Token[4]: "UCD"
#> 
#> [[6]]
#> Token[5]: "today"
#> 
#> [[7]]
#> Token[6]: ","
#> 
#> [[8]]
#> Token[7]: "what"
#> 
#> [[9]]
#> Token[8]: "I"
#> 
#> [[10]]
#> Token[9]: "have"
#> 
#> [[11]]
#> Token[10]: "seen"
#> 
#> [[12]]
#> Token[11]: "of"
#> 
#> [[13]]
#> Token[12]: "UCD"
#> 
#> [[14]]
#> Token[13]: "in"
#> 
#> [[15]]
#> Token[14]: "its"
#> 
#> [[16]]
#> Token[15]: "impact"
#> 
#> [[17]]
#> Token[16]: "on"
#> 
#> [[18]]
#> Token[17]: "my"
#> 
#> [[19]]
#> Token[18]: "own"
#> 
#> [[20]]
#> Token[19]: "life"
#> 
#> [[21]]
#> Token[20]: "and"
#> 
#> [[22]]
#> Token[21]: "the"
#> 
#> [[23]]
#> Token[22]: "life"
#> 
#> [[24]]
#> Token[23]: "of"
#> 
#> [[25]]
#> Token[24]: "Ireland"
#> 
#> [[26]]
#> Token[25]: "."

Retrieve the Token

Tagging at least one token is enough to understand the string representation of the Sentence object. The get_token(n) method, a Python method, retrieves the Token object at a particular position, and [] can also be used to index a specific token. Note the difference in numbering: get_token(n) counts tokens from 1, whereas [] is passed through to Python and counts from 0, so get_token(5) and sentence[4] both return the fifth token, “UCD”.

# Python method: the argument is the token's position, counted from 1
sentence$get_token(5)
#> Token[4]: "UCD"
# [] indexing: passed through to Python, counted from 0
sentence[4]
#> Token[4]: "UCD"
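
The two numbering conventions can be mimicked in base R without Flair. The sketch below uses a plain character vector as a stand-in for the tokens; `words`, `token_by_id`, and `token_by_index` are hypothetical names used only for illustration and are not part of flaiR.

```r
# Stand-in for the first tokens of the sentence (not a Flair object)
words <- c("What", "I", "see", "in", "UCD", "today")

# Position counted from 1, as R vectors do and as get_token(5) did above
token_by_id <- function(x, id) x[id]

# Index counted from 0, as Python does: shift by one before indexing in R
token_by_index <- function(x, index) x[index + 1L]

token_by_id(words, 5L)     # "UCD"
token_by_index(words, 4L)  # "UCD"

# 0-based indices for a loop, the same idea as seq_along(sentence) - 1
zero_based <- seq_along(words) - 1L
print(zero_based)
```

Both helpers return the same element; only the starting number differs.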

Each word (and punctuation) in the sentence is treated as an individual Token object. These Token objects store text information and other possible linguistic information (such as part-of-speech tags or named entity tags) and embeddings (if you used a model to generate them).

Even though in most cases you do not need to create Token objects manually, understanding how to manage these objects manually is still useful in some situations, such as when you want fine-grained control over the tokenization process. For example, you can control the exactness of tokenization by adding manually created Token objects to a Sentence object.

This design pattern in Flair allows users to handle text data in a very flexible way. Users can use the automatic tokenization feature for rapid development, and also perform finer-grained control to accommodate more use cases.

Annotate POS tag and NER tag

The add_label(label_type, value) method assigns a label to a token. In this preliminary tutorial we add the tags manually. In the Universal POS tag set, sentence[10], ‘seen’, is tagged as VERB; it is the past participle form of the verb ‘see’.

sentence[10]$add_label('manual-pos', 'VERB')
print(sentence[10])
#> Token[10]: "seen" → VERB (1.0)

We can also add a NER (Named Entity Recognition) tag to sentence[4], “UCD”, identifying it as a university in Dublin.

sentence[4]$add_label('ner', 'ORG')
print(sentence[4])
#> Token[4]: "UCD" → ORG (1.0)

If we print the sentence object, Sentence[26] shows that it holds 26 tokens, and → [“UCD”/ORG, “seen”/VERB] displays the two pieces of tagging information.

print(sentence)
#> Sentence[26]: "What I see in UCD today, what I have seen of UCD in its impact on my own life and the life of Ireland." → ["UCD"/ORG, "seen"/VERB]

Corpus

The Corpus object in Flair is a fundamental data structure that represents a dataset containing text samples, usually comprising a training set, a development set (or validation set), and a test set. It’s designed to work smoothly with Flair’s models for tasks like named entity recognition, text classification, and more.

Attributes:

  • train: A list of sentences (List[Sentence]) that form the training dataset.
  • dev (or development): A list of sentences (List[Sentence]) that form the development (or validation) dataset.
  • test: A list of sentences (List[Sentence]) that form the test dataset.

Important Methods:

  • downsample: This method allows you to downsample (reduce) the number of sentences in the train, dev, and test splits.
  • obtain_statistics: This method gives a quick overview of the statistics of the corpus, including the number of sentences and the distribution of labels.
  • make_vocab_dictionary: Used to create a vocabulary dictionary from the corpus.

library(flaiR)
Corpus <- flair_data()$Corpus
Sentence <- flair_data()$Sentence
# Create some example sentences
train <- list(Sentence('This is a training example.'))
dev <-  list(Sentence('This is a validation example.'))
test <- list(Sentence('This is a test example.'))

# Create a corpus using the custom data splits
corpus <-  Corpus(train = train, dev = dev, test = test)

The $obtain_statistics() method of the Corpus object gives an overview of the dataset statistics, with details about the training, development (validation), and test datasets that make up the corpus. By default it returns these statistics as a JSON-formatted string, so in R you can parse and pretty-print the result with the jsonlite package.

library(jsonlite)
data <- fromJSON(corpus$obtain_statistics())
formatted_str <- toJSON(data, pretty=TRUE)
print(formatted_str)
#> {
#>   "TRAIN": {
#>     "dataset": ["TRAIN"],
#>     "total_number_of_documents": [1],
#>     "number_of_documents_per_class": {},
#>     "number_of_tokens_per_tag": {},
#>     "number_of_tokens": {
#>       "total": [6],
#>       "min": [6],
#>       "max": [6],
#>       "avg": [6]
#>     }
#>   },
#>   "TEST": {
#>     "dataset": ["TEST"],
#>     "total_number_of_documents": [1],
#>     "number_of_documents_per_class": {},
#>     "number_of_tokens_per_tag": {},
#>     "number_of_tokens": {
#>       "total": [6],
#>       "min": [6],
#>       "max": [6],
#>       "avg": [6]
#>     }
#>   },
#>   "DEV": {
#>     "dataset": ["DEV"],
#>     "total_number_of_documents": [1],
#>     "number_of_documents_per_class": {},
#>     "number_of_tokens_per_tag": {},
#>     "number_of_tokens": {
#>       "total": [6],
#>       "min": [6],
#>       "max": [6],
#>       "avg": [6]
#>     }
#>   }
#> }
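
To see what the number_of_tokens fields summarise, the same four figures can be recomputed in base R. This is only a sketch: `count_tokens` is a hypothetical helper with a rough regular-expression tokenizer (runs of letters or digits, or single punctuation marks), which approximates Flair's tokenizer for simple sentences like these but is not the same thing.

```r
# The three example sentences used for the splits above
splits <- list(
  TRAIN = "This is a training example.",
  DEV   = "This is a validation example.",
  TEST  = "This is a test example."
)

# Rough tokenizer: words and punctuation marks are counted separately
count_tokens <- function(texts) {
  lengths(regmatches(texts, gregexpr("[[:alnum:]]+|[[:punct:]]", texts)))
}

# total / min / max / avg number of tokens per split
token_stats <- lapply(splits, function(texts) {
  n <- count_tokens(texts)
  list(total = sum(n), min = min(n), max = max(n), avg = mean(n))
})

str(token_stats$TRAIN)
```

Each split holds one sentence of six tokens (five words plus the full stop), which matches the values reported by obtain_statistics above.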

In R

Below, we use data from the article The Temporal Focus of Campaign Communication by Stefan Muller, published in the Journal of Politics in 2020, as an example.

First, I apply the Sentence function to every element of cc_muller$text with lapply, which yields a list of Sentence objects. Then, we reformat cc_muller$class_pro_retro as a factor. Note that R handles numerical values differently from Python: numbers in R are stored as floating-point values by default, so it is advisable to convert the codes into factors or strings. Lastly, I use the map2 function from the purrr package to assign a label to each Sentence with the $add_label method.

library(purrr)
#> 
#> Attaching package: 'purrr'
#> The following object is masked from 'package:jsonlite':
#> 
#>     flatten
data(cc_muller)
# The `Sentence` object tokenizes each text
text <- lapply(cc_muller$text, Sentence)
# Convert the coded classes to a factor
labels <- as.factor(cc_muller$class_pro_retro)
# The `$add_label` method assigns the corresponding coded class to each Sentence.
text <- map2(text, labels, ~ .x$add_label("classification", .y), .progress = TRUE)

To perform a train-test split using base R, we can follow these steps:

set.seed(2046)
sample <- sample(c(TRUE, FALSE), length(text), replace=TRUE, prob=c(0.8, 0.2))
train  <- text[sample]
test   <- text[!sample]
sprintf("Corpus object sizes - Train: %d |  Test: %d", length(train), length(test))
#> [1] "Corpus object sizes - Train: 4710 |  Test: 1148"

If you don’t provide a dev set, Flair won’t force you to carve out a portion of your test set to serve as a dev set. However, in some cases when only the train and test sets are provided without a dev set, Flair might automatically take a fraction of the train set (e.g., 10%) to use as a dev set (#2259). This is to offer a mechanism for model selection and early stopping to prevent the model from overfitting on the train set.

The Corpus function selects the dev split at random. To ensure reproducibility, we need to set the seed in the Flair framework. We can do this by importing the top-level flair module through {flaiR} and calling $set_seed(1964L).

flair <- import_flair()
flair$set_seed(1964L)
corpus <- Corpus(train=train, 
                 # dev=test,
                 test=test)
#> 2024-05-09 08:20:00,575 No dev split found. Using 0% (i.e. 471 samples) of the train split as dev data
sprintf("Corpus object sizes - Train: %d | Test: %d | Dev: %d", 
        length(corpus$train), 
        length(corpus$test),
        length(corpus$dev))
#> [1] "Corpus object sizes - Train: 4239 | Test: 1148 | Dev: 471"
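
If you would rather control the dev split yourself than rely on the automatic one, a three-way split can be made in base R. This is a minimal sketch: `docs` is a hypothetical stand-in for the list of Sentence objects, and the 80/10/10 proportions are an assumption chosen for illustration.

```r
set.seed(2046)

# Stand-in for the list of Sentence objects
docs <- paste("document", 1:100)
n <- length(docs)

# Shuffle the positions once, then cut them into three disjoint pieces
shuffled <- sample(n)
n_train <- round(0.8 * n)
n_dev <- round(0.1 * n)

train <- docs[shuffled[seq_len(n_train)]]
dev <- docs[shuffled[n_train + seq_len(n_dev)]]
test <- docs[shuffled[(n_train + n_dev + 1):n]]

sprintf("Train: %d | Dev: %d | Test: %d", length(train), length(dev), length(test))
```

The three pieces could then be passed as Corpus(train = train, dev = dev, test = test), as in the earlier example with custom data splits.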

In the later sections, there will be more similar processing using the Corpus. Following that, we will focus on advanced NLP applications.

 


Sequence Tagging

Tag Entities in Text

Let’s run named entity recognition over the following example sentence: “I love Berlin and New York”. To do this, all you need is to make a Sentence for this text, load a pre-trained model and use it to predict tags for the sentence object.

# attach flaiR in R
library(flaiR)

# make a sentence

Sentence <- flair_data()$Sentence
sentence <- Sentence('I love Berlin and New York.')

# load the NER tagger
Classifier <- flair_nn()$Classifier
tagger <- Classifier$load('ner')
#> 2024-05-09 08:20:02,077 SequenceTagger predicts: Dictionary with 20 tags: <unk>, O, S-ORG, S-MISC, B-PER, E-PER, S-LOC, B-ORG, E-ORG, I-PER, S-PER, B-MISC, I-MISC, E-MISC, I-ORG, B-LOC, E-LOC, I-LOC, <START>, <STOP>

# run NER over sentence
tagger$predict(sentence)

This should print:

# print the sentence with all annotations
print(sentence)
#> Sentence[7]: "I love Berlin and New York." → ["Berlin"/LOC, "New York"/LOC]

Use a for loop to print each entity label. In R, sentence$get_labels() comes back as an R list, which is indexed from 1, so seq_along(sentence$get_labels()) can be used directly, without subtracting 1.

for (i in seq_along(sentence$get_labels())) {
      print(sentence$get_labels()[[i]])
  }
#> 'Span[2:3]: "Berlin"'/'LOC' (0.9812)
#> 'Span[4:6]: "New York"'/'LOC' (0.9957)

 

Tag Part-of-Speech in Text

For POS tagging, we use flair/pos-english, one of the standard models hosted on Hugging Face.

# attach flaiR in R
library(flaiR)

# make a sentence
Sentence <- flair_data()$Sentence
sentence <- Sentence('I love Berlin and New York.')

# load the POS tagger
Classifier <- flair_nn()$Classifier
tagger <- Classifier$load('pos')
#> 2024-05-09 08:20:03,018 SequenceTagger predicts: Dictionary with 53 tags: <unk>, O, UH, ,, VBD, PRP, VB, PRP$, NN, RB, ., DT, JJ, VBP, VBG, IN, CD, NNS, NNP, WRB, VBZ, WDT, CC, TO, MD, VBN, WP, :, RP, EX, JJR, FW, XX, HYPH, POS, RBR, JJS, PDT, NNPS, RBS, AFX, WP$, -LRB-, -RRB-, ``, '', LS, $, SYM, ADD

# run POS tagging over sentence
tagger$predict(sentence)

This should print:

# print the sentence with all annotations
print(sentence)
#> Sentence[7]: "I love Berlin and New York." → ["I"/PRP, "love"/VBP, "Berlin"/NNP, "and"/CC, "New"/NNP, "York"/NNP, "."/.]

Use a for loop to print out each POS tag.

for (i in seq_along(sentence$get_labels())) {
      print(sentence$get_labels()[[i]])
  }
#> 'Token[0]: "I"'/'PRP' (1.0)
#> 'Token[1]: "love"'/'VBP' (1.0)
#> 'Token[2]: "Berlin"'/'NNP' (0.9999)
#> 'Token[3]: "and"'/'CC' (1.0)
#> 'Token[4]: "New"'/'NNP' (1.0)
#> 'Token[5]: "York"'/'NNP' (1.0)
#> 'Token[6]: "."'/'.' (1.0)

 

Detect Sentiment

Let’s run sentiment analysis over the same sentence to determine whether it is POSITIVE or NEGATIVE.

You can do this with essentially the same code as above. Just instead of loading the ‘ner’ model, you now load the ‘sentiment’ model:

# attach flaiR in R
library(flaiR)

# make a sentence
Sentence <- flair_data()$Sentence
sentence <- Sentence('I love Berlin and New York.')

# load the Classifier tagger from flair.nn module
Classifier <- flair_nn()$Classifier
tagger <- Classifier$load('sentiment')

# run sentiment analysis over sentence
tagger$predict(sentence)
# print the sentence with all annotations
print(sentence)
#> Sentence[7]: "I love Berlin and New York." → POSITIVE (0.9982)

 


 

Tagging Parts-of-Speech with Flair Models

You can load the pre-trained model "pos-fast". For more pre-trained models, see https://flairnlp.github.io/docs/tutorial-basics/part-of-speech-tagging#-in-english.

texts <- c("UCD is one of the best universities in Ireland.",
           "UCD has a good campus but is very far from my apartment in Dublin.",
           "Essex is famous for social science research.",
           "Essex is not in the Russell Group, but it is famous for political science research and in 1994 Group.",
           "TCD is the oldest university in Ireland.",
           "TCD is similar to Oxford.")

doc_ids <- c("doc1", "doc2", "doc3", "doc4", "doc5", "doc6")
library(flaiR)
tagger_pos <- load_tagger_pos("pos-fast")
#> 2024-05-09 08:20:05,636 SequenceTagger predicts: Dictionary with 53 tags: <unk>, O, UH, ,, VBD, PRP, VB, PRP$, NN, RB, ., DT, JJ, VBP, VBG, IN, CD, NNS, NNP, WRB, VBZ, WDT, CC, TO, MD, VBN, WP, :, RP, EX, JJR, FW, XX, HYPH, POS, RBR, JJS, PDT, NNPS, RBS, AFX, WP$, -LRB-, -RRB-, ``, '', LS, $, SYM, ADD
results <- get_pos(texts, doc_ids, tagger_pos)
head(results, n = 10)
#>     doc_id token_id text_id        token    tag precision
#>     <char>    <num>  <lgcl>       <char> <char>     <num>
#>  1:   doc1        0      NA          UCD    NNP    0.9967
#>  2:   doc1        1      NA           is    VBZ    1.0000
#>  3:   doc1        2      NA          one     CD    0.9993
#>  4:   doc1        3      NA           of     IN    1.0000
#>  5:   doc1        4      NA          the     DT    1.0000
#>  6:   doc1        5      NA         best    JJS    0.9988
#>  7:   doc1        6      NA universities    NNS    0.9997
#>  8:   doc1        7      NA           in     IN    1.0000
#>  9:   doc1        8      NA      Ireland    NNP    1.0000
#> 10:   doc1        9      NA            .      .    0.9998

 

Tagging Entities with Flair Models

Load the pretrained model ner. For more pretrained models, see https://flairnlp.github.io/docs/tutorial-basics/tagging-entities.

library(flaiR)
tagger_ner <- load_tagger_ner("ner")
#> 2024-05-09 08:20:06,828 SequenceTagger predicts: Dictionary with 20 tags: <unk>, O, S-ORG, S-MISC, B-PER, E-PER, S-LOC, B-ORG, E-ORG, I-PER, S-PER, B-MISC, I-MISC, E-MISC, I-ORG, B-LOC, E-LOC, I-LOC, <START>, <STOP>
results <- get_entities(texts, doc_ids, tagger_ner)
head(results, n = 10)
#>     doc_id        entity    tag
#>     <char>        <char> <char>
#>  1:   doc1           UCD    ORG
#>  2:   doc1       Ireland    LOC
#>  3:   doc2           UCD    ORG
#>  4:   doc2        Dublin    LOC
#>  5:   doc3         Essex    ORG
#>  6:   doc4         Essex    ORG
#>  7:   doc4 Russell Group    ORG
#>  8:   doc5           TCD    ORG
#>  9:   doc5       Ireland    LOC
#> 10:   doc6           TCD    ORG

 

Tagging Sentiment

Load the pretrained model “sentiment”. The pre-trained models of “sentiment”, “sentiment-fast”, and “de-offensive-language” are currently available. For more pre-trained models, see https://flairnlp.github.io/docs/tutorial-basics/tagging-sentiment.

library(flaiR)
tagger_sent <- load_tagger_sentiments("sentiment")
results <- get_sentiments(texts, doc_ids, tagger_sent)
head(results, n = 10)
#>    doc_id sentiment     score
#>    <char>    <char>     <num>
#> 1:   doc1  POSITIVE 0.9970598
#> 2:   doc2  NEGATIVE 0.8472336
#> 3:   doc3  POSITIVE 0.9928006
#> 4:   doc4  POSITIVE 0.9901404
#> 5:   doc5  POSITIVE 0.9952670
#> 6:   doc6  POSITIVE 0.9291797

 


 


Flair Embedding

Flair is a very popular natural language processing library that provides a variety of embedding methods for text representation. Flair Embeddings is a word embedding framework for Natural Language Processing, developed by Zalando Research. It focuses on word-level representation and can capture contextual information, meaning that the same word can have different embeddings in different contexts. Unlike traditional word embeddings (such as Word2Vec or GloVe), Flair generates word embeddings dynamically based on context and has achieved excellent results in various NLP tasks. Below are some key points about Flair Embeddings:

Context-Aware

Flair is a dynamic word embedding technique: it takes the context of a word in a sentence into account and generates the word’s embedding from that context. In contrast, static word embeddings, such as Word2Vec or GloVe, provide a fixed embedding for each word without considering its context in a sentence.

Therefore, context-sensitive embedding techniques, such as Flair, can capture the meaning of words in specific sentences more accurately, thus enhancing the performance of language models in various tasks.

Example:

Consider the following two English sentences:

  • “I am interested in the bank of the river.”
  • “I need to go to the bank to withdraw money.”

Here, the word “bank” has two different meanings. In the first sentence, it refers to the edge or shore of a river. In the second sentence, it refers to a financial institution.

For static embeddings, the word “bank” might have an embedding that lies somewhere between these two meanings because it doesn’t consider context. But for dynamic embeddings like Flair, “bank” in the first sentence will have an embedding related to rivers, and in the second sentence, it will have an embedding related to finance.

FlairEmbeddings <- flair_embeddings()$FlairEmbeddings
Sentence <- flair_data()$Sentence


# Initialize Flair embeddings
flair_embedding_forward <- FlairEmbeddings('news-forward')

# Define the two sentences
sentence1 <-  Sentence("I am interested in the bank of the river.")
sentence2 <-  Sentence("I need to go to the bank to withdraw money.")

# Get the embeddings

flair_embedding_forward$embed(sentence1)
#> [[1]]
#> Sentence[10]: "I am interested in the bank of the river."
flair_embedding_forward$embed(sentence2)
#> [[1]]
#> Sentence[11]: "I need to go to the bank to withdraw money."

# Extract the embedding for "bank" from the sentences
bank_embedding_sentence1 <- sentence1[5]$embedding  # "bank" is the sixth word (index 5)
bank_embedding_sentence2 <- sentence2[6]$embedding  # "bank" is the seventh word (index 6)

The word is the same and the two vectors are similar, but they are not identical: the dynamic embeddings for “bank” differ with the context of each sentence. The embeddings are high-dimensional vectors, so printing them would show a long run of numbers. For a more intuitive view of the difference, you can compute the cosine similarity, or another metric, between the two embeddings.

This is just a simple demonstration. In practice, you can also combine multiple embedding techniques, such as WordEmbeddings and FlairEmbeddings, to get richer word vectors.

library(lsa)
#> Loading required package: SnowballC
cosine(as.numeric( bank_embedding_sentence1$numpy()), 
       as.numeric( bank_embedding_sentence2$numpy()))
#>           [,1]
#> [1,] 0.7329551

Character-Based

Flair uses a character-level language model, meaning it can generate embeddings for rare words or even misspelled words. This is an important feature because it allows the model to understand and process words that have never appeared in the training data. Flair uses a bidirectional LSTM (Long Short-Term Memory) network that operates at a character level. This means that it feeds individual characters into the LSTM instead of words.

Multilingual Support

Flair provides various pre-trained character-level language models, supporting contextual word embeddings for multiple languages. Flair allows you to easily combine different word embeddings (e.g., Flair Embeddings, Word2Vec, GloVe, etc.) to create powerful stacked embeddings.
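
Conceptually, stacking amounts to concatenating the vectors that each embedding produces for a token. The sketch below illustrates only that idea with made-up numbers in base R; the vectors and their lengths are assumptions for illustration, and no Flair model is involved.

```r
# Made-up vectors standing in for one token's embeddings from two models
static_part <- c(0.10, -0.25, 0.40)            # e.g. a 3-dimensional static vector
contextual_part <- c(0.05, 0.30, -0.15, 0.20)  # e.g. a 4-dimensional contextual vector

# Stacking: place one vector after the other
stacked <- c(static_part, contextual_part)

length(stacked)  # 7, the sum of the two lengths
```

The stacked vector keeps both sources of information side by side, at the cost of a larger dimension.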

Classic Word Embeddings

In Flair, the simplest form of embedding that still contains semantic information about the word is called classic word embeddings. These embeddings are pre-trained and non-contextual. Let’s retrieve a few word embeddings, using the FastText embeddings ('crawl') in the following code. To use them, we instantiate the WordEmbeddings class by passing in the ID of the embedding of our choice. Then, we wrap our text into a Sentence object and call the embed(sentence) method on our WordEmbeddings instance.

WordEmbeddings <- flair_embeddings()$WordEmbeddings
Sentence <- flair_data()$Sentence
embedding <- WordEmbeddings('crawl')
sentence <- Sentence("one two three one") 
embedding$embed(sentence) 
#> [[1]]
#> Sentence[4]: "one two three one"

for (i in seq_along(sentence$tokens)) {
  # head() keeps the first six values of each embedding
  print(head(sentence$tokens[[i]]$embedding, n = 6))
}
#> tensor([-0.0535, -0.0368, -0.2851, -0.0381, -0.0486,  0.2383])
#> tensor([ 0.0282, -0.0786, -0.1236,  0.1756, -0.1199,  0.0964])
#> tensor([-0.0920, -0.0690, -0.1475,  0.2313, -0.0872,  0.0799])
#> tensor([-0.0535, -0.0368, -0.2851, -0.0381, -0.0486,  0.2383])

Flair supports a range of classic word embeddings, each offering unique features and application scopes. Below is an overview, detailing the ID required to load each embedding and its corresponding language.

Embedding Type               ID       Language
GloVe                        glove    English
Komninos                     extvec   English
Twitter                      twitter  English
Turian (small)               turian   English
FastText (crawl)             crawl    English
FastText (news & Wikipedia)  ar       Arabic
FastText (news & Wikipedia)  bg       Bulgarian
FastText (news & Wikipedia)  ca       Catalan
FastText (news & Wikipedia)  cz       Czech
FastText (news & Wikipedia)  da       Danish
FastText (news & Wikipedia)  de       German
FastText (news & Wikipedia)  es       Spanish
FastText (news & Wikipedia)  en       English
FastText (news & Wikipedia)  eu       Basque
FastText (news & Wikipedia)  fa       Persian
FastText (news & Wikipedia)  fi       Finnish
FastText (news & Wikipedia)  fr       French
FastText (news & Wikipedia)  he       Hebrew
FastText (news & Wikipedia)  hi       Hindi
FastText (news & Wikipedia)  hr       Croatian
FastText (news & Wikipedia)  id       Indonesian
FastText (news & Wikipedia)  it       Italian
FastText (news & Wikipedia)  ja       Japanese
FastText (news & Wikipedia)  ko       Korean
FastText (news & Wikipedia)  nl       Dutch
FastText (news & Wikipedia)  no       Norwegian
FastText (news & Wikipedia)  pl       Polish
FastText (news & Wikipedia)  pt       Portuguese
FastText (news & Wikipedia)  ro       Romanian
FastText (news & Wikipedia)  ru       Russian
FastText (news & Wikipedia)  si       Slovenian
FastText (news & Wikipedia)  sk       Slovak
FastText (news & Wikipedia)  sr       Serbian
FastText (news & Wikipedia)  sv       Swedish
FastText (news & Wikipedia)  tr       Turkish
FastText (news & Wikipedia)  zh       Chinese

 


Contextual Embeddings

Understanding the contextuality of Flair embeddings

The idea behind contextual string embeddings is that each word embedding should be defined not only by its syntactic-semantic meaning but also by the context it appears in. This means that a word will have a different embedding for every context it appears in.

Each pre-trained Flair model offers a forward version and a backward version. Assume you are processing a language written left to right, as English is. The forward version takes into account the context before the word, on its left-hand side. The backward version works in the opposite direction and takes into account the context after the word, on its right-hand side. If this is true, then two identical words that appear at the beginning of two different sentences should have identical forward embeddings, because their left context is empty. Let’s test this out.

Because we are using a forward model, only the context before a word is taken into account. The word nice has no context on its left-hand side in either sentence, so the two embeddings should be identical, and the comparison below indeed returns TRUE.

FlairEmbeddings <- flair_embeddings()$FlairEmbeddings
embedding <- FlairEmbeddings('news-forward')
s1 <- Sentence("nice shirt") 
s2 <- Sentence("nice pants") 

embedding$embed(s1) 
#> [[1]]
#> Sentence[2]: "nice shirt"
embedding$embed(s2) 
#> [[1]]
#> Sentence[2]: "nice pants"
cat(" s1 sentence:", paste(s1[0], sep = ""), "\n", "s2 sentence:", paste(s2[0], sep = ""))
#>  s1 sentence: Token[0]: "nice" 
#>  s2 sentence: Token[0]: "nice"

We count the positions at which the two 2048-dimensional embeddings of nice are equal and test whether that count equals the vector length, 2048. If it does, the two embeddings are identical, which should theoretically be the case.

length(s1[0]$embedding$numpy()) == sum(s1[0]$embedding$numpy() ==  s2[0]$embedding$numpy())
#> [1] TRUE
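
The comparison above counts matching positions and checks that every position matches. The same check can be sketched on small made-up vectors in base R; `same_everywhere` is a hypothetical helper and the numbers are invented for illustration.

```r
# Made-up vectors standing in for embeddings
a <- c(0.12, -0.40, 0.33)
b <- c(0.12, -0.40, 0.33)
d <- c(0.12, -0.41, 0.33)  # differs from a in the second position

# TRUE only when the number of equal positions equals the vector length
same_everywhere <- function(x, y) length(x) == sum(x == y)

same_everywhere(a, b)  # TRUE
same_everywhere(a, d)  # FALSE

# all() expresses the same test more directly
all(a == b)            # TRUE
```

Exact equality is appropriate here because the two embeddings are expected to be produced from identical input; for vectors that are merely close, a tolerance-based comparison would be the safer choice.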

Now we separately add a few more words, very and pretty, into two sentence objects.

s1 <- Sentence("very nice shirt") 
s2 <- Sentence("pretty nice pants") 

embedding$embed(s1) 
#> [[1]]
#> Sentence[3]: "very nice shirt"
embedding$embed(s2) 
#> [[1]]
#> Sentence[3]: "pretty nice pants"

Token 0 is now very in s1 and pretty in s2. The two embeddings are not identical because the words are different, so the comparison returns FALSE.

length(s1[0]$embedding$numpy()) == sum(s1[0]$embedding$numpy() ==  s2[0]$embedding$numpy())
#> [1] FALSE

Cosine similarity measures how similar two vectors in an inner product space are, based on the angle between them. For two vectors A and B, it is calculated as follows:

\(\text{Cosine Similarity} = \frac{\sum_{i} A_i B_i}{\sqrt{\sum_{i} A_i^2} \cdot \sqrt{\sum_{i} B_i^2}}\)

library(lsa)
vector1 <- as.numeric(s1[0]$embedding$numpy())
vector2 <- as.numeric(s2[0]$embedding$numpy())

We can observe that the similarity between the two words is approximately 0.56.

cosine_similarity <- cosine(vector1, vector2)
print(cosine_similarity)
#>           [,1]
#> [1,] 0.5571664
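
The formula can also be written out directly in base R, which makes each term visible. This is a minimal sketch; `cosine_sim` is a hypothetical helper name, and the input vectors are toy values rather than real embeddings.

```r
# Cosine similarity written term by term, following the formula above
cosine_sim <- function(a, b) {
  stopifnot(length(a) == length(b))
  sum(a * b) / (sqrt(sum(a^2)) * sqrt(sum(b^2)))
}

cosine_sim(c(1, 0), c(0, 1))        # 0: the vectors are orthogonal
cosine_sim(c(1, 2, 3), c(2, 4, 6))  # 1 up to floating-point rounding: same direction
```

Applied to vector1 and vector2 from the example, a function like this would be expected to agree with lsa::cosine, since both follow the same definition.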

 


Extracting Embeddings from BERT

First, we use the flair.embeddings.TransformerWordEmbeddings class to download BERT; more transformer models can be found on Flair NLP’s Hugging Face page.

TransformerWordEmbeddings <- flair_embeddings.TransformerWordEmbeddings("bert-base-uncased")
embedding <- TransformerWordEmbeddings$embed(sentence)

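Traverse each token in the sentence and print it. To view each token, it is necessary to use reticulate::py_str(token), since the sentence is a Python object. The sentence used here is the “one two three one” sentence from the classic word embeddings example, which already carries its 'crawl' embeddings; that appears to be why the first values printed below match the FastText vectors shown earlier.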

# Iterate through each token in the sentence, printing them. 
# Utilize reticulate::py_str(token) to view each token, given that the sentence is a Python object.
for (i in seq_along(sentence$tokens)) {
  cat("Token: ", reticulate::py_str(sentence$tokens[[i]]), "\n")
  # Access the embedding of the token, converting it to an R object, 
  # and print the first 10 elements of the vector.
  token_embedding <- sentence$tokens[[i]]$embedding
  print(head(token_embedding, 10))
}
#> Token:  Token[0]: "one" 
#> tensor([-0.0535, -0.0368, -0.2851, -0.0381, -0.0486,  0.2383, -0.1200,  0.2620,
#>         -0.0575,  0.0228])
#> Token:  Token[1]: "two" 
#> tensor([ 0.0282, -0.0786, -0.1236,  0.1756, -0.1199,  0.0964, -0.1327,  0.4449,
#>         -0.0264, -0.1168])
#> Token:  Token[2]: "three" 
#> tensor([-0.0920, -0.0690, -0.1475,  0.2313, -0.0872,  0.0799, -0.0901,  0.4403,
#>         -0.0103, -0.1494])
#> Token:  Token[3]: "one" 
#> tensor([-0.0535, -0.0368, -0.2851, -0.0381, -0.0486,  0.2383, -0.1200,  0.2620,
#>         -0.0575,  0.0228])

 


Training a Binary Classifier

In this section, we’ll train a sentiment analysis model that can categorize text as either positive or negative. This case study is adapted from pages 116 to 130 of Tadej Magajna’s book, ‘Natural Language Processing with Flair’. The process for training text classifiers in Flair mirrors the sequence followed for sequence labeling models. Specifically, the steps to train text classifiers are:

  • Load a tagged corpus and compute the label dictionary map.
  • Prepare the document embeddings.
  • Initialize the TextClassifier class.
  • Train the model.

Loading a Tagged Corpus

Training text classification models requires a set of text documents (typically, sentences or paragraphs) where each document is associated with one or more classification labels. To train our sentiment analysis text classification model, we will be using the famous Internet Movie Database (IMDb) dataset, which contains 50,000 movie reviews from IMDB, where each review is labeled as either positive or negative. References to this dataset are already baked into Flair, so loading the dataset couldn’t be easier:

library(flaiR)
# load IMDB from flair_datasets module
Corpus <- flair_data()$Corpus
IMDB <- flair_datasets()$IMDB
# load the full corpus (it is downsampled further below)
corpus <- IMDB()
#> 2024-05-09 08:20:17,004 Reading data from /Users/yenchiehliao/.flair/datasets/imdb_v4-rebalanced
#> 2024-05-09 08:20:17,004 Train: /Users/yenchiehliao/.flair/datasets/imdb_v4-rebalanced/train.txt
#> 2024-05-09 08:20:17,004 Dev: None
#> 2024-05-09 08:20:17,004 Test: None
#> 2024-05-09 08:20:17,537 No test split found. Using 0% (i.e. 5000 samples) of the train split as test data
#> 2024-05-09 08:20:17,550 No dev split found. Using 0% (i.e. 4500 samples) of the train split as dev data
#> 2024-05-09 08:20:17,550 Initialized corpus /Users/yenchiehliao/.flair/datasets/imdb_v4-rebalanced (label type name is 'sentiment')
# downsample every split to 5% of its size
corpus$downsample(0.05)
#> <flair.datasets.document_classification.IMDB object at 0x37fde3a60>
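
The split sizes can be checked by hand. The sketch below takes the dataset size stated above (50,000 reviews) and the sample counts from the log (5,000 test and 4,500 dev), and applies the 5% downsampling to each split. The variable names are hypothetical, and rounding to whole documents is an assumption of this sketch.

```r
total_reviews <- 50000   # dataset size stated in the text
test_full <- 5000        # test samples reported in the log
dev_full <- 4500         # dev samples reported in the log
train_full <- total_reviews - test_full - dev_full

# Share of the data each split was drawn from
test_full / total_reviews               # 0.1
dev_full / (total_reviews - test_full)  # 0.1

# Keep 5% of every split, rounded to whole documents
downsampled <- round(c(test = test_full, train = train_full, dev = dev_full) * 0.05)
print(downsampled)
```

These figures give the sizes to expect when the corpus object is inspected after downsampling.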

Print the sizes of the splits in the corpus object in the format “test: %d | train: %d | dev: %d”:

test_size <- length(corpus$test)
train_size <- length(corpus$train)
dev_size <- length(corpus$dev)
output <- sprintf("Corpus object sizes - Test: %d | Train: %d | Dev: %d", test_size, train_size, dev_size)
print(output)
#> [1] "Corpus object sizes - Test: 250 | Train: 2025 | Dev: 225"
# compute the label dictionary for the 'sentiment' label type
lbl_type <- 'sentiment'
label_dict <- corpus$make_label_dictionary(label_type=lbl_type)
#> 2024-05-09 08:20:17,675 Computing label dictionary. Progress:
#> 2024-05-09 08:20:21,249 Dictionary created for label 'sentiment' with 2 values: POSITIVE (seen 1014 times), NEGATIVE (seen 1011 times)
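The split sizes printed above can be checked with a little arithmetic. The sketch below assumes that the full corpus holds the 50,000 reviews mentioned earlier, and that the sample counts in the log (5,000 for test, then 4,500 for dev) each correspond to 10% of what is left at that point. The use of round() is also an assumption and not necessarily how Flair rounds internally.

```r
# Assumption: 50,000 reviews in total, as stated in the text above
n_total <- 50000

# The log reports 5000 test samples, i.e. 10% of the full set
n_test <- round(n_total * 0.10)
# ... and 4500 dev samples, i.e. 10% of the remaining 45,000
n_dev <- round((n_total - n_test) * 0.10)
n_train <- n_total - n_test - n_dev

# downsample(0.05) keeps 5% of each split
full_sizes <- c(train = n_train, dev = n_dev, test = n_test)
down_sizes <- round(full_sizes * 0.05)
print(down_sizes)

# The label dictionary was computed on the training split only
n_labels_seen <- 1014 + 1011
print(n_labels_seen)
```

Under these assumptions the result agrees with the sizes printed by sprintf above, and the training size equals the number of POSITIVE and NEGATIVE labels seen while the dictionary was built.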

Loading the Embeddings

Flair provides several types of document embeddings. Here, we simply use DocumentPoolEmbeddings, which require no training prior to training the classification model itself:

DocumentPoolEmbeddings <- flair_embeddings()$DocumentPoolEmbeddings
WordEmbeddings <- flair_embeddings()$WordEmbeddings
glove = WordEmbeddings('glove')
document_embeddings = DocumentPoolEmbeddings(glove)
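Conceptually, mean pooling averages the word vectors of all tokens in a document, dimension by dimension. The toy example below is a base R sketch with made-up 4-dimensional vectors (the GloVe vectors used here have 100 dimensions). It illustrates the idea and is not Flair's implementation.

```r
# Made-up word vectors for a three-token document (one row per token)
token_vectors <- rbind(
  great = c(1, 0, 2, 4),
  fun   = c(3, 2, 0, 0),
  movie = c(2, 1, 1, 2)
)

# Mean pooling: average each dimension over all tokens
doc_vector <- colMeans(token_vectors)
print(doc_vector)
```

The document vector has the same length as a single word vector, which is why a classifier built on pooled 100-dimensional GloVe vectors expects 100 input features.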

Initializing the TextClassifier

# initiate TextClassifier
TextClassifier <- flair_models()$TextClassifier
classifier <- TextClassifier(document_embeddings,
                             label_dictionary = label_dict,
                             label_type = lbl_type)

The $to method moves the classifier to the device you want to use: CPU, GPU, or a specific MPS device on Mac (such as mps:0, mps:1, mps:2).

classifier$to(flair_device("mps")) 
TextClassifier(
  (embeddings): DocumentPoolEmbeddings(
    fine_tune_mode=none, pooling=mean
    (embeddings): StackedEmbeddings(
      (list_embedding_0): WordEmbeddings(
        'glove'
        (embedding): Embedding(400001, 100)
      )
    )
  )
  (decoder): Linear(in_features=100, out_features=3, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (locked_dropout): LockedDropout(p=0.0)
  (word_dropout): WordDropout(p=0.0)
  (loss_function): CrossEntropyLoss()
)

Training the Model

Training the text classifier model involves two simple steps:

  • Defining the model trainer class by passing in the classifier model and the corpus.
  • Setting off the training process, passing in the required training hyperparameters.

It is worth noting that the ‘L’ in numbers like 32L and 10L is used in R to denote that the number is an integer. Without the ‘L’ suffix, numbers in R are treated as numeric, which are by default double-precision floating-point numbers. In contrast, Python determines the type based on the value of the number itself. Whole numbers (e.g., 5 or 32) are of type int, while numbers with decimal points (e.g., 5.0) are of type float. Floating-point numbers in both languages represent real numbers only approximately because of the way they are stored in memory.
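The difference is easy to check in base R. The short sketch below uses only base functions; the variable names are made up for illustration.

```r
# A literal with the L suffix is stored as an integer
batch_size <- 32L
# Without the suffix, the same value is stored as a double
batch_size_dbl <- 32

print(typeof(batch_size))
print(typeof(batch_size_dbl))

# as.integer() converts a whole-number double when needed
epochs <- as.integer(10)
print(is.integer(epochs))
```

Passing 32L rather than 32 therefore makes sure that a parameter such as mini_batch_size reaches Python as an integer and not as a floating-point number.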

# initiate ModelTrainer
ModelTrainer <- flair_trainers()$ModelTrainer

# create the trainer from the classifier and the corpus
trainer <- ModelTrainer(classifier, corpus)

# start to train
# note: the 'L' in 32L is used in R to denote that the number is an integer.
trainer$train('classifier',
              learning_rate=0.1,
              mini_batch_size=32L,
              # embeddings_storage_mode sets where computed embeddings are kept
              # between epochs; Flair documents "none", "cpu", and "gpu".
              # embeddings_storage_mode = "mps",
              max_epochs=10L)
#> 2024-05-09 08:20:22,897 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:22,897 Model: "TextClassifier(
#>   (embeddings): DocumentPoolEmbeddings(
#>     fine_tune_mode=none, pooling=mean
#>     (embeddings): StackedEmbeddings(
#>       (list_embedding_0): WordEmbeddings(
#>         'glove'
#>         (embedding): Embedding(400001, 100)
#>       )
#>     )
#>   )
#>   (decoder): Linear(in_features=100, out_features=2, bias=True)
#>   (dropout): Dropout(p=0.0, inplace=False)
#>   (locked_dropout): LockedDropout(p=0.0)
#>   (word_dropout): WordDropout(p=0.0)
#>   (loss_function): CrossEntropyLoss()
#>   (weights): None
#>   (weight_tensor) None
#> )"
#> 2024-05-09 08:20:22,897 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:22,897 Corpus: 2025 train + 225 dev + 250 test sentences
#> 2024-05-09 08:20:22,898 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:22,898 Train:  2025 sentences
#> 2024-05-09 08:20:22,898         (train_with_dev=False, train_with_test=False)
#> 2024-05-09 08:20:22,898 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:22,898 Training Params:
#> 2024-05-09 08:20:22,898  - learning_rate: "0.1" 
#> 2024-05-09 08:20:22,898  - mini_batch_size: "32"
#> 2024-05-09 08:20:22,898  - max_epochs: "10"
#> 2024-05-09 08:20:22,898  - shuffle: "True"
#> 2024-05-09 08:20:22,898 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:22,898 Plugins:
#> 2024-05-09 08:20:22,898  - AnnealOnPlateau | patience: '3', anneal_factor: '0.5', min_learning_rate: '0.0001'
#> 2024-05-09 08:20:22,898 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:22,898 Final evaluation on model from best epoch (best-model.pt)
#> 2024-05-09 08:20:22,898  - metric: "('micro avg', 'f1-score')"
#> 2024-05-09 08:20:22,898 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:22,898 Computation:
#> 2024-05-09 08:20:22,898  - compute on device: cpu
#> 2024-05-09 08:20:22,898  - embedding storage: cpu
#> 2024-05-09 08:20:22,898 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:22,898 Model training base path: "classifier"
#> 2024-05-09 08:20:22,898 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:22,898 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:23,469 epoch 1 - iter 6/64 - loss 0.92798599 - time (sec): 0.57 - samples/sec: 336.31 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:24,406 epoch 1 - iter 12/64 - loss 0.95401138 - time (sec): 1.51 - samples/sec: 254.77 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:25,155 epoch 1 - iter 18/64 - loss 0.95685528 - time (sec): 2.26 - samples/sec: 255.29 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:25,952 epoch 1 - iter 24/64 - loss 0.94917999 - time (sec): 3.05 - samples/sec: 251.55 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:26,666 epoch 1 - iter 30/64 - loss 0.94337714 - time (sec): 3.77 - samples/sec: 254.81 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:27,374 epoch 1 - iter 36/64 - loss 0.95555888 - time (sec): 4.48 - samples/sec: 257.38 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:28,279 epoch 1 - iter 42/64 - loss 0.94315202 - time (sec): 5.38 - samples/sec: 249.79 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:29,038 epoch 1 - iter 48/64 - loss 0.92537404 - time (sec): 6.14 - samples/sec: 250.19 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:29,776 epoch 1 - iter 54/64 - loss 0.92101984 - time (sec): 6.88 - samples/sec: 251.26 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:30,530 epoch 1 - iter 60/64 - loss 0.92079034 - time (sec): 7.63 - samples/sec: 251.60 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:30,825 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:30,825 EPOCH 1 done: loss 0.9182 - lr: 0.100000
#> 2024-05-09 08:20:32,049 DEV : loss 0.9580778479576111 - f1-score (micro avg)  0.4533
#> 2024-05-09 08:20:32,674  - 0 epochs without improvement
#> 2024-05-09 08:20:32,675 saving best model
#> 2024-05-09 08:20:33,046 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:33,844 epoch 2 - iter 6/64 - loss 0.89372043 - time (sec): 0.80 - samples/sec: 240.51 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:34,558 epoch 2 - iter 12/64 - loss 0.87390463 - time (sec): 1.51 - samples/sec: 254.07 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:35,456 epoch 2 - iter 18/64 - loss 0.90945973 - time (sec): 2.41 - samples/sec: 239.03 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:36,250 epoch 2 - iter 24/64 - loss 0.92023720 - time (sec): 3.20 - samples/sec: 239.69 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:37,067 epoch 2 - iter 30/64 - loss 0.92514712 - time (sec): 4.02 - samples/sec: 238.73 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:37,820 epoch 2 - iter 36/64 - loss 0.91756206 - time (sec): 4.77 - samples/sec: 241.29 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:38,551 epoch 2 - iter 42/64 - loss 0.91224597 - time (sec): 5.51 - samples/sec: 244.14 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:39,281 epoch 2 - iter 48/64 - loss 0.90670522 - time (sec): 6.23 - samples/sec: 246.36 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:40,060 epoch 2 - iter 54/64 - loss 0.90671374 - time (sec): 7.01 - samples/sec: 246.37 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:40,817 epoch 2 - iter 60/64 - loss 0.90005447 - time (sec): 7.77 - samples/sec: 247.08 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:41,428 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:41,428 EPOCH 2 done: loss 0.8971 - lr: 0.100000
#> 2024-05-09 08:20:42,641 DEV : loss 0.9348878264427185 - f1-score (micro avg)  0.4533
#> 2024-05-09 08:20:43,457  - 0 epochs without improvement
#> 2024-05-09 08:20:43,458 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:44,353 epoch 3 - iter 6/64 - loss 0.87230729 - time (sec): 0.90 - samples/sec: 214.51 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:45,147 epoch 3 - iter 12/64 - loss 0.88936495 - time (sec): 1.69 - samples/sec: 227.33 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:45,898 epoch 3 - iter 18/64 - loss 0.86102090 - time (sec): 2.44 - samples/sec: 236.12 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:46,659 epoch 3 - iter 24/64 - loss 0.84525002 - time (sec): 3.20 - samples/sec: 239.98 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:47,611 epoch 3 - iter 30/64 - loss 0.84292530 - time (sec): 4.15 - samples/sec: 231.15 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:48,441 epoch 3 - iter 36/64 - loss 0.82797504 - time (sec): 4.98 - samples/sec: 231.19 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:49,191 epoch 3 - iter 42/64 - loss 0.83439555 - time (sec): 5.73 - samples/sec: 234.44 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:49,957 epoch 3 - iter 48/64 - loss 0.84227631 - time (sec): 6.50 - samples/sec: 236.35 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:50,822 epoch 3 - iter 54/64 - loss 0.84142225 - time (sec): 7.36 - samples/sec: 234.65 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:51,714 epoch 3 - iter 60/64 - loss 0.84776874 - time (sec): 8.26 - samples/sec: 232.56 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:52,286 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:52,286 EPOCH 3 done: loss 0.8506 - lr: 0.100000
#> 2024-05-09 08:20:53,487 DEV : loss 0.7087231874465942 - f1-score (micro avg)  0.5511
#> 2024-05-09 08:20:54,162  - 0 epochs without improvement
#> 2024-05-09 08:20:54,163 saving best model
#> 2024-05-09 08:20:54,459 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:20:55,248 epoch 4 - iter 6/64 - loss 0.83985214 - time (sec): 0.79 - samples/sec: 243.49 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:56,040 epoch 4 - iter 12/64 - loss 0.85478169 - time (sec): 1.58 - samples/sec: 242.92 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:56,869 epoch 4 - iter 18/64 - loss 0.88271719 - time (sec): 2.41 - samples/sec: 239.04 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:57,610 epoch 4 - iter 24/64 - loss 0.87416061 - time (sec): 3.15 - samples/sec: 243.70 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:58,652 epoch 4 - iter 30/64 - loss 0.88960058 - time (sec): 4.19 - samples/sec: 228.94 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:20:59,371 epoch 4 - iter 36/64 - loss 0.88747313 - time (sec): 4.91 - samples/sec: 234.55 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:00,165 epoch 4 - iter 42/64 - loss 0.88250834 - time (sec): 5.71 - samples/sec: 235.56 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:00,943 epoch 4 - iter 48/64 - loss 0.87871186 - time (sec): 6.48 - samples/sec: 236.89 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:01,707 epoch 4 - iter 54/64 - loss 0.87959897 - time (sec): 7.25 - samples/sec: 238.40 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:02,396 epoch 4 - iter 60/64 - loss 0.87840113 - time (sec): 7.94 - samples/sec: 241.91 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:02,866 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:21:02,866 EPOCH 4 done: loss 0.8811 - lr: 0.100000
#> 2024-05-09 08:21:03,884 DEV : loss 0.7050172686576843 - f1-score (micro avg)  0.5511
#> 2024-05-09 08:21:04,515  - 0 epochs without improvement
#> 2024-05-09 08:21:04,516 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:21:05,297 epoch 5 - iter 6/64 - loss 0.91935858 - time (sec): 0.78 - samples/sec: 245.85 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:06,007 epoch 5 - iter 12/64 - loss 0.90264572 - time (sec): 1.49 - samples/sec: 257.49 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:06,752 epoch 5 - iter 18/64 - loss 0.89750031 - time (sec): 2.24 - samples/sec: 257.60 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:07,513 epoch 5 - iter 24/64 - loss 0.90278538 - time (sec): 3.00 - samples/sec: 256.23 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:08,284 epoch 5 - iter 30/64 - loss 0.89582215 - time (sec): 3.77 - samples/sec: 254.77 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:09,058 epoch 5 - iter 36/64 - loss 0.89115169 - time (sec): 4.54 - samples/sec: 253.62 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:09,758 epoch 5 - iter 42/64 - loss 0.88016130 - time (sec): 5.24 - samples/sec: 256.37 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:10,486 epoch 5 - iter 48/64 - loss 0.87231719 - time (sec): 5.97 - samples/sec: 257.28 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:11,257 epoch 5 - iter 54/64 - loss 0.87672888 - time (sec): 6.74 - samples/sec: 256.35 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:12,316 epoch 5 - iter 60/64 - loss 0.86810185 - time (sec): 7.80 - samples/sec: 246.14 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:12,603 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:21:12,603 EPOCH 5 done: loss 0.8675 - lr: 0.100000
#> 2024-05-09 08:21:13,615 DEV : loss 0.7054558396339417 - f1-score (micro avg)  0.5689
#> 2024-05-09 08:21:14,238  - 0 epochs without improvement
#> 2024-05-09 08:21:14,239 saving best model
#> 2024-05-09 08:21:14,519 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:21:15,277 epoch 6 - iter 6/64 - loss 0.75694018 - time (sec): 0.76 - samples/sec: 253.36 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:16,019 epoch 6 - iter 12/64 - loss 0.80699679 - time (sec): 1.50 - samples/sec: 256.06 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:16,762 epoch 6 - iter 18/64 - loss 0.82347439 - time (sec): 2.24 - samples/sec: 256.84 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:17,854 epoch 6 - iter 24/64 - loss 0.81694798 - time (sec): 3.33 - samples/sec: 230.32 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:18,549 epoch 6 - iter 30/64 - loss 0.81142912 - time (sec): 4.03 - samples/sec: 238.25 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:19,285 epoch 6 - iter 36/64 - loss 0.80620797 - time (sec): 4.76 - samples/sec: 241.76 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:20,048 epoch 6 - iter 42/64 - loss 0.81629599 - time (sec): 5.53 - samples/sec: 243.10 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:20,889 epoch 6 - iter 48/64 - loss 0.81234421 - time (sec): 6.37 - samples/sec: 241.15 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:21,672 epoch 6 - iter 54/64 - loss 0.81319698 - time (sec): 7.15 - samples/sec: 241.61 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:22,432 epoch 6 - iter 60/64 - loss 0.82112388 - time (sec): 7.91 - samples/sec: 242.65 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:22,939 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:21:22,940 EPOCH 6 done: loss 0.8152 - lr: 0.100000
#> 2024-05-09 08:21:23,968 DEV : loss 0.9520094394683838 - f1-score (micro avg)  0.4533
#> 2024-05-09 08:21:24,614  - 1 epochs without improvement
#> 2024-05-09 08:21:24,614 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:21:25,441 epoch 7 - iter 6/64 - loss 0.79382354 - time (sec): 0.83 - samples/sec: 232.45 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:26,172 epoch 7 - iter 12/64 - loss 0.81419498 - time (sec): 1.56 - samples/sec: 246.59 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:26,988 epoch 7 - iter 18/64 - loss 0.80050886 - time (sec): 2.37 - samples/sec: 242.74 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:27,698 epoch 7 - iter 24/64 - loss 0.78892901 - time (sec): 3.08 - samples/sec: 249.12 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:28,487 epoch 7 - iter 30/64 - loss 0.79873842 - time (sec): 3.87 - samples/sec: 247.89 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:29,452 epoch 7 - iter 36/64 - loss 0.79810941 - time (sec): 4.84 - samples/sec: 238.17 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:30,179 epoch 7 - iter 42/64 - loss 0.80939396 - time (sec): 5.56 - samples/sec: 241.54 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:30,936 epoch 7 - iter 48/64 - loss 0.80912796 - time (sec): 6.32 - samples/sec: 242.99 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:31,674 epoch 7 - iter 54/64 - loss 0.81644665 - time (sec): 7.06 - samples/sec: 244.79 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:32,432 epoch 7 - iter 60/64 - loss 0.82480310 - time (sec): 7.82 - samples/sec: 245.62 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:32,945 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:21:32,945 EPOCH 7 done: loss 0.8255 - lr: 0.100000
#> 2024-05-09 08:21:33,945 DEV : loss 0.9209567308425903 - f1-score (micro avg)  0.4533
#> 2024-05-09 08:21:34,563  - 2 epochs without improvement
#> 2024-05-09 08:21:34,563 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:21:35,384 epoch 8 - iter 6/64 - loss 0.81569169 - time (sec): 0.82 - samples/sec: 234.13 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:36,166 epoch 8 - iter 12/64 - loss 0.81156640 - time (sec): 1.60 - samples/sec: 239.60 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:36,949 epoch 8 - iter 18/64 - loss 0.81359380 - time (sec): 2.39 - samples/sec: 241.42 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:37,680 epoch 8 - iter 24/64 - loss 0.81111881 - time (sec): 3.12 - samples/sec: 246.39 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:38,370 epoch 8 - iter 30/64 - loss 0.80478209 - time (sec): 3.81 - samples/sec: 252.21 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:39,098 epoch 8 - iter 36/64 - loss 0.79061179 - time (sec): 4.53 - samples/sec: 254.05 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:39,829 epoch 8 - iter 42/64 - loss 0.79852152 - time (sec): 5.27 - samples/sec: 255.24 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:40,631 epoch 8 - iter 48/64 - loss 0.80004252 - time (sec): 6.07 - samples/sec: 253.13 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:41,449 epoch 8 - iter 54/64 - loss 0.80379323 - time (sec): 6.89 - samples/sec: 250.96 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:42,532 epoch 8 - iter 60/64 - loss 0.80295331 - time (sec): 7.97 - samples/sec: 240.94 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:42,849 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:21:42,849 EPOCH 8 done: loss 0.8067 - lr: 0.100000
#> 2024-05-09 08:21:44,214 DEV : loss 0.8891550302505493 - f1-score (micro avg)  0.4533
#> 2024-05-09 08:21:44,653  - 3 epochs without improvement
#> 2024-05-09 08:21:44,653 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:21:45,516 epoch 9 - iter 6/64 - loss 0.79579540 - time (sec): 0.86 - samples/sec: 222.62 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:46,322 epoch 9 - iter 12/64 - loss 0.77182314 - time (sec): 1.67 - samples/sec: 230.18 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:47,020 epoch 9 - iter 18/64 - loss 0.77825414 - time (sec): 2.37 - samples/sec: 243.43 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:47,973 epoch 9 - iter 24/64 - loss 0.77474442 - time (sec): 3.32 - samples/sec: 231.34 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:48,704 epoch 9 - iter 30/64 - loss 0.76791755 - time (sec): 4.05 - samples/sec: 237.00 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:49,431 epoch 9 - iter 36/64 - loss 0.78972670 - time (sec): 4.78 - samples/sec: 241.16 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:50,222 epoch 9 - iter 42/64 - loss 0.78559951 - time (sec): 5.57 - samples/sec: 241.37 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:50,950 epoch 9 - iter 48/64 - loss 0.78007192 - time (sec): 6.30 - samples/sec: 243.96 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:51,745 epoch 9 - iter 54/64 - loss 0.77606843 - time (sec): 7.09 - samples/sec: 243.66 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:52,518 epoch 9 - iter 60/64 - loss 0.78353175 - time (sec): 7.86 - samples/sec: 244.14 - lr: 0.100000 - momentum: 0.000000
#> 2024-05-09 08:21:53,041 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:21:53,041 EPOCH 9 done: loss 0.7893 - lr: 0.100000
#> 2024-05-09 08:21:54,052 DEV : loss 1.6272923946380615 - f1-score (micro avg)  0.5467
#> 2024-05-09 08:21:54,681  - 4 epochs without improvement (above 'patience')-> annealing learning_rate to [0.05]
#> 2024-05-09 08:21:54,682 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:21:55,434 epoch 10 - iter 6/64 - loss 0.96271145 - time (sec): 0.75 - samples/sec: 255.22 - lr: 0.050000 - momentum: 0.000000
#> 2024-05-09 08:21:56,169 epoch 10 - iter 12/64 - loss 0.80516695 - time (sec): 1.49 - samples/sec: 258.15 - lr: 0.050000 - momentum: 0.000000
#> 2024-05-09 08:21:56,913 epoch 10 - iter 18/64 - loss 0.75819152 - time (sec): 2.23 - samples/sec: 258.17 - lr: 0.050000 - momentum: 0.000000
#> 2024-05-09 08:21:57,740 epoch 10 - iter 24/64 - loss 0.73239921 - time (sec): 3.06 - samples/sec: 251.09 - lr: 0.050000 - momentum: 0.000000
#> 2024-05-09 08:21:58,709 epoch 10 - iter 30/64 - loss 0.71001181 - time (sec): 4.03 - samples/sec: 238.39 - lr: 0.050000 - momentum: 0.000000
#> 2024-05-09 08:21:59,488 epoch 10 - iter 36/64 - loss 0.69095225 - time (sec): 4.81 - samples/sec: 239.69 - lr: 0.050000 - momentum: 0.000000
#> 2024-05-09 08:22:00,271 epoch 10 - iter 42/64 - loss 0.70209982 - time (sec): 5.59 - samples/sec: 240.48 - lr: 0.050000 - momentum: 0.000000
#> 2024-05-09 08:22:01,028 epoch 10 - iter 48/64 - loss 0.70447307 - time (sec): 6.35 - samples/sec: 242.04 - lr: 0.050000 - momentum: 0.000000
#> 2024-05-09 08:22:01,802 epoch 10 - iter 54/64 - loss 0.70349394 - time (sec): 7.12 - samples/sec: 242.68 - lr: 0.050000 - momentum: 0.000000
#> 2024-05-09 08:22:02,520 epoch 10 - iter 60/64 - loss 0.70357497 - time (sec): 7.84 - samples/sec: 244.95 - lr: 0.050000 - momentum: 0.000000
#> 2024-05-09 08:22:03,017 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:03,017 EPOCH 10 done: loss 0.6941 - lr: 0.050000
#> 2024-05-09 08:22:04,043 DEV : loss 0.6851814389228821 - f1-score (micro avg)  0.5689
#> 2024-05-09 08:22:04,674  - 0 epochs without improvement
#> 2024-05-09 08:22:04,955 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:04,956 Loading model from best epoch ...
#> 2024-05-09 08:22:06,230 
#> Results:
#> - F-score (micro) 0.52
#> - F-score (macro) 0.4099
#> - Accuracy 0.52
#> 
#> By class:
#>               precision    recall  f1-score   support
#> 
#>     NEGATIVE     0.5021    0.9835    0.6648       121
#>     POSITIVE     0.8462    0.0853    0.1549       129
#> 
#>     accuracy                         0.5200       250
#>    macro avg     0.6741    0.5344    0.4099       250
#> weighted avg     0.6796    0.5200    0.4017       250
#> 
#> 2024-05-09 08:22:06,230 ----------------------------------------------------------------------------------------------------
#> $test_score
#> [1] 0.52
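To see how the summary scores relate to the per-class table, the sketch below recomputes them in base R. The confusion counts are not printed by Flair. They are back-calculated here from the reported precision, recall, and support, so treat them as an assumption.

```r
# Back-calculated confusion counts (assumption, see the note above)
tp_neg <- 119  # NEGATIVE reviews predicted NEGATIVE (recall 119/121)
fn_neg <- 2    # NEGATIVE reviews predicted POSITIVE
tp_pos <- 11   # POSITIVE reviews predicted POSITIVE (recall 11/129)
fn_pos <- 118  # POSITIVE reviews predicted NEGATIVE

# F1 expressed through true positives, false positives, false negatives
f1 <- function(tp, fp, fn) 2 * tp / (2 * tp + fp + fn)

# With two classes, a false negative of one class is a false positive of the other
f1_neg <- f1(tp_neg, fp = fn_pos, fn = fn_neg)
f1_pos <- f1(tp_pos, fp = fn_neg, fn = fn_pos)

support <- c(NEGATIVE = 121, POSITIVE = 129)
macro_f1 <- mean(c(f1_neg, f1_pos))
weighted_f1 <- sum(c(f1_neg, f1_pos) * support) / sum(support)
accuracy <- (tp_neg + tp_pos) / sum(support)

print(round(c(f1_neg = f1_neg, f1_pos = f1_pos, macro = macro_f1,
              weighted = weighted_f1, accuracy = accuracy), 4))
```

The gap between the accuracy of 0.52 and the much lower macro F1 reflects that this small model labels almost every review as NEGATIVE, which the recall of 0.0853 for POSITIVE also shows.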

Loading and Using the Classifiers

After training the text classification model, the resulting classifier will already be stored in memory as part of the classifier variable. It is possible, however, that your R session exited after training. If so, you’ll need to load the model into memory with the following:

TextClassifier <- flair_models()$TextClassifier
classifier <- TextClassifier$load('classifier/best-model.pt')

We import the Sentence object. Now, we can generate predictions on some example text inputs.

Sentence <- flair_data()$Sentence
sentence <- Sentence("great")
classifier$predict(sentence)
print(sentence$labels)
#> [[1]]
#> 'Sentence[1]: "great"'/'POSITIVE' (0.9958)
sentence <- Sentence("sad")
classifier$predict(sentence)
print(sentence$labels)
#> [[1]]
#> 'Sentence[1]: "sad"'/'NEGATIVE' (0.8379)

 


Training RNNs

Here, we train a sentiment analysis model to categorize text. In this case, we also include a pipeline that uses Recurrent Neural Networks (RNNs), which process tokens in order and are therefore particularly effective for tasks involving sequential data. This section also shows you how to implement one of the most powerful features in Flair, stacked embeddings. You can stack multiple embeddings with different layers and let the classifier learn from different types of features. In Flair NLP, and with the flaiR package, it’s very easy to accomplish this task.

Import Necessary Modules

library(flaiR)
WordEmbeddings <- flair_embeddings()$WordEmbeddings
FlairEmbeddings <- flair_embeddings()$FlairEmbeddings
DocumentRNNEmbeddings <- flair_embeddings()$DocumentRNNEmbeddings
TextClassifier <- flair_models()$TextClassifier
ModelTrainer <- flair_trainers()$ModelTrainer

Get the IMDB Corpus

The IMDB movie review dataset, a commonly used dataset for sentiment analysis, is used here. Calling the $downsample(0.1) method means that only 10% of the dataset is used, which allows for a faster demonstration.

# load the IMDB file and downsize it to 0.1
IMDB <- flair_datasets()$IMDB
corpus <- IMDB()$downsample(0.1) 
#> 2024-05-09 08:22:06,713 Reading data from /Users/yenchiehliao/.flair/datasets/imdb_v4-rebalanced
#> 2024-05-09 08:22:06,713 Train: /Users/yenchiehliao/.flair/datasets/imdb_v4-rebalanced/train.txt
#> 2024-05-09 08:22:06,713 Dev: None
#> 2024-05-09 08:22:06,713 Test: None
#> 2024-05-09 08:22:07,268 No test split found. Using 0% (i.e. 5000 samples) of the train split as test data
#> 2024-05-09 08:22:07,281 No dev split found. Using 0% (i.e. 4500 samples) of the train split as dev data
#> 2024-05-09 08:22:07,281 Initialized corpus /Users/yenchiehliao/.flair/datasets/imdb_v4-rebalanced (label type name is 'sentiment')
# create the label dictionary
lbl_type <- 'sentiment'
label_dict <- corpus$make_label_dictionary(label_type=lbl_type)
#> 2024-05-09 08:22:07,297 Computing label dictionary. Progress:
#> 2024-05-09 08:22:14,756 Dictionary created for label 'sentiment' with 2 values: POSITIVE (seen 2056 times), NEGATIVE (seen 1994 times)

Stacked Embeddings

This is one of Flair’s most powerful features: stacking combines several embeddings so that the model can learn from different types of features at once. Three embeddings are used here: GloVe word embeddings and two Flair embeddings (forward and backward). Each of them converts words into vectors.

# make a list of word embeddings
word_embeddings <- list(WordEmbeddings('glove'),
                        FlairEmbeddings('news-forward-fast'),
                        FlairEmbeddings('news-backward-fast'))

# initialize the document embeddings
document_embeddings <- DocumentRNNEmbeddings(word_embeddings, 
                                             hidden_size = 512L,
                                             reproject_words = TRUE,
                                             reproject_words_dimension = 256L)
# create a Text Classifier with the embeddings and label dictionary;
# label_type must match the label type of the corpus ('sentiment')
classifier <- TextClassifier(document_embeddings, 
                            label_dictionary=label_dict, label_type=lbl_type)

# initialize the text classifier trainer with our corpus
trainer <- ModelTrainer(classifier, corpus)
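The idea behind stacking can be illustrated without Flair: for each token, the vectors produced by the individual embeddings are joined into one longer vector. The base R sketch below uses made-up vectors with tiny dimensions; the real embeddings are much larger, and the names are placeholders.

```r
# Made-up vectors for one token from three different embeddings
glove_vec    <- c(0.1, 0.2)        # stands in for a GloVe vector
forward_vec  <- c(0.3, 0.4, 0.5)   # stands in for a forward Flair vector
backward_vec <- c(0.6, 0.7, 0.8)   # stands in for a backward Flair vector

# Stacking concatenates them into one longer vector for that token
stacked_vec <- c(glove_vec, forward_vec, backward_vec)
print(length(stacked_vec))
```

The length of the stacked vector is the sum of the individual lengths, so every embedding added to the list widens the input that the RNN sees for each token.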

Start the Training

For the sake of this example, we set max_epochs to 5. You might want to increase this for better performance.

It is worth noting that the learning rate is a parameter that determines the step size at each iteration while moving towards a minimum of the loss function. A smaller learning rate can slow down the learning process, but it can lead to more precise convergence. mini_batch_size determines the number of samples that are used to compute the gradient at each step. The ‘L’ in 32L is used in R to denote that the number is an integer.
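Both parameters can be made concrete with a small base R sketch. The first part computes how many mini-batches one epoch contains for the 2025 training sentences of the earlier IMDB run, which agrees with the 64 iterations per epoch shown in that log. The second part runs plain gradient descent on a toy loss chosen purely for illustration; it is not what Flair optimizes.

```r
# Number of mini-batches (iterations) per epoch
n_train <- 2025          # training sentences in the earlier IMDB run
mini_batch_size <- 32L
iters_per_epoch <- ceiling(n_train / mini_batch_size)
print(iters_per_epoch)

# Effect of the learning rate on a toy loss f(w) = (w - 3)^2
descend <- function(learning_rate, steps = 50, w = 0) {
  for (i in seq_len(steps)) {
    gradient <- 2 * (w - 3)            # derivative of the toy loss
    w <- w - learning_rate * gradient  # one gradient step
  }
  w
}

w_fast <- descend(learning_rate = 0.1)
w_slow <- descend(learning_rate = 0.01)
print(c(w_fast = w_fast, w_slow = w_slow))
```

After the same number of steps, the run with the larger learning rate ends up much closer to the minimum at w = 3, while the smaller learning rate is still on its way there.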

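patience is a hyperparameter that guards against wasted training and overfitting. It sets the number of epochs the training process will tolerate without improvement before reacting; in the training log shown earlier, the AnnealOnPlateau plugin (patience 3, anneal_factor 0.5) halves the learning rate once that count is exceeded rather than stopping immediately. Setting max_epochs to 5 means the algorithm will make at most five passes through the dataset, so with patience=5L the limit on epochs is reached first.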
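The way patience interacts with the learning rate can be replayed in base R using the dev scores from the 10-epoch training log earlier in this document (losses rounded to four decimals). This is a simplified sketch and not Flair's AnnealOnPlateau code. In particular, the tie-breaking rule is an assumption, chosen because it reproduces the "epochs without improvement" counts in that log.

```r
# Dev F1 and dev loss per epoch, copied from the earlier training log
dev_f1 <- c(0.4533, 0.4533, 0.5511, 0.5511, 0.5689,
            0.4533, 0.4533, 0.4533, 0.5467, 0.5689)
dev_loss <- c(0.9581, 0.9349, 0.7087, 0.7050, 0.7055,
              0.9520, 0.9210, 0.8892, 1.6273, 0.6852)

patience <- 3         # as printed in the log
anneal_factor <- 0.5  # as printed in the log
lr <- 0.1

best_f1 <- -Inf
best_loss <- Inf
bad_epochs <- 0
lr_used <- numeric(length(dev_f1))

for (epoch in seq_along(dev_f1)) {
  lr_used[epoch] <- lr  # learning rate used while training this epoch
  # Assumption: a tie on F1 still counts as progress if the dev loss went down
  better_score <- dev_f1[epoch] > best_f1
  tie_lower_loss <- dev_f1[epoch] == best_f1 && dev_loss[epoch] < best_loss
  if (better_score || tie_lower_loss) {
    best_f1 <- dev_f1[epoch]
    best_loss <- dev_loss[epoch]
    bad_epochs <- 0
  } else {
    bad_epochs <- bad_epochs + 1
  }
  # Once patience is exceeded, shrink the learning rate and start counting again
  if (bad_epochs > patience) {
    lr <- lr * anneal_factor
    bad_epochs <- 0
  }
}
print(lr_used)
```

In this replay the learning rate stays at 0.1 for nine epochs and drops to 0.05 for the tenth, which is the behaviour the log reports after four epochs without improvement.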

# note: the 'L' in 32L is used in R to denote that the number is an integer.
trainer$train('models/sentiment',
              learning_rate=0.1,
              mini_batch_size=32L,
              patience=5L,
              max_epochs=5L)  

To Apply the Trained Model for Prediction

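# wrap the raw text in a Sentence object before predicting
Sentence <- flair_data()$Sentence
sentence <- Sentence("This movie was really exciting!")
classifier$predict(sentence)
print(sentence$labels)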

 


Finetune Transformers

We use data from The Temporal Focus of Campaign Communication (2020 JOP) as an example. Let’s assume we receive the training data at different times. First, suppose you have a dataset of 1000 entries called cc_muller_old. On another day, with the help of nice friends, you receive another set of data, adding 2000 entries in a dataset called cc_muller_new. Both subsets are from data(cc_muller). We will show how to fine-tune a transformer model with cc_muller_old, and then continue with another round of fine-tuning using cc_muller_new.

library(flaiR)

Fine-tuning a Transformers Model

Step 1 Load Necessary Modules from Flair

Load the necessary classes from the flair package.

# Sentence is a class for holding a text sentence
Sentence <- flair_data()$Sentence

# Corpus is a class for text corpora
Corpus <- flair_data()$Corpus

# TransformerDocumentEmbeddings is a class for loading transformer-based document embeddings
TransformerDocumentEmbeddings <- flair_embeddings()$TransformerDocumentEmbeddings

# TextClassifier is a class for text classification
TextClassifier <- flair_models()$TextClassifier

# ModelTrainer is a class for training and evaluating models
ModelTrainer <- flair_trainers()$ModelTrainer

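Step 2 Convert Data to Sentence and Corpus

We use map from purrr to turn each text into a Sentence (from flair_data()), then map2 to attach the label to each sentence, and finally split the sentences into training, test, and development sets before wrapping them in a Corpus.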

library(purrr)

data(cc_muller)
cc_muller_old <- cc_muller[1:1000,]

old_text <- map(cc_muller_old$text, Sentence)
old_labels <- as.character(cc_muller_old$class)

old_text <- map2(old_text, old_labels, ~ {
   
  .x$add_label("classification", .y)
  .x
})
print(length(old_text))
#> [1] 1000
set.seed(2046)
# randomly assign about 80% of the sentences to the training set
train_id <- sample(c(TRUE, FALSE), length(old_text), replace=TRUE, prob=c(0.8, 0.2))
old_train   <- old_text[train_id]
old_heldout <- old_text[!train_id]

# split the held-out sentences evenly into test and dev sets
test_id <- sample(c(TRUE, FALSE), length(old_heldout), replace=TRUE, prob=c(0.5, 0.5))
old_test <- old_heldout[test_id]
old_dev  <- old_heldout[!test_id]
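An alternative way to build the three splits, sketched here on toy data, is to draw a single split name per item so that every item lands in exactly one of train, dev, or test. The vector items and the 80/10/10 proportions are assumptions made for this illustration; with real data, items would be the list of Sentence objects.

```r
set.seed(2046)
items <- paste("text", 1:1000)   # toy stand-in for a list of Sentence objects

# draw one split name per item
split_id <- sample(c("train", "dev", "test"), length(items),
                   replace = TRUE, prob = c(0.8, 0.1, 0.1))

toy_train <- items[split_id == "train"]
toy_dev   <- items[split_id == "dev"]
toy_test  <- items[split_id == "test"]

# every item is used exactly once and the splits do not overlap
length(toy_train) + length(toy_dev) + length(toy_test) == length(items)
#> [1] TRUE
length(intersect(toy_train, toy_test)) == 0
#> [1] TRUE
```

Because each item receives exactly one split name, no subset can be overwritten by accident before another one is derived from it.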

If you do not provide a development split (dev split), Flair automatically holds out part of the training data as a development set, as the message below shows. The training set is used to fit the model and the test set is used to evaluate its final performance, whereas the development set (dev set) is used for adjusting model parameters and preventing overfitting, or in other words, for early stopping of the model.

old_corpus <- Corpus(train = old_train, test = old_test)
#> 2024-05-09 08:22:17,035 No dev split found. Using 0% (i.e. 80 samples) of the train split as dev data

Step 3 Load distilbert Transformer

document_embeddings <- TransformerDocumentEmbeddings('distilbert-base-uncased', fine_tune=TRUE)

First, the $make_label_dictionary function is used to automatically create a label dictionary for the classification task. The label dictionary is a mapping from label to index, which is used to map the labels to a tensor of label indices. Besides classification tasks, Flair also supports other label types for training custom models, such as ner, pos, and sentiment. For the training split used here, the dictionary contains three labels: Future (seen 380 times), Present (seen 232 times), and Past (seen 111 times).

old_label_dict <- old_corpus$make_label_dictionary(label_type="classification")
#> 2024-05-09 08:22:18,287 Computing label dictionary. Progress:
#> 2024-05-09 08:22:18,293 Dictionary created for label 'classification' with 3 values: Future (seen 380 times), Present (seen 232 times), Past (seen 111 times)
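Conceptually, a label dictionary of this kind can be sketched in base R as follows. This is only an illustration of the label-to-index idea, not Flair's implementation: the toy labels vector is made up, and the 0-based indices are an assumption chosen to mirror Python.

```r
# toy labels, one per document (made up for illustration)
labels <- c("Future", "Present", "Future", "Past", "Future", "Present")

# count how often each label occurs, most frequent first
label_counts <- sort(table(labels), decreasing = TRUE)

# map each label to an integer index (0-based here, an assumption)
label_to_index <- setNames(seq_along(label_counts) - 1L, names(label_counts))

# translate the character labels into their indices
label_indices <- unname(label_to_index[labels])
label_indices
#> [1] 0 1 0 2 0 1
```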

TextClassifier is used to create a text classifier. The classifier takes the document embeddings (loaded from 'distilbert-base-uncased' on Hugging Face) and the label dictionary as input. The label type is also specified as classification.

old_classifier <- TextClassifier(document_embeddings,
                                 label_dictionary = old_label_dict, 
                                 label_type='classification')

Step 4 Start Training

ModelTrainer is used to train the model.

old_trainer <- ModelTrainer(model = old_classifier, corpus = old_corpus)
old_trainer$train("vignettes/inst/muller-campaign-communication",  
                  learning_rate=0.02,              
                  mini_batch_size=8L,              
                  anneal_with_restarts = TRUE,
                  save_final_model=TRUE,
                  max_epochs=1L)   
#> 2024-05-09 08:22:18,453 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:18,453 Model: "TextClassifier(
#>   (embeddings): TransformerDocumentEmbeddings(
#>     (model): DistilBertModel(
#>       (embeddings): Embeddings(
#>         (word_embeddings): Embedding(30523, 768)
#>         (position_embeddings): Embedding(512, 768)
#>         (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#>         (dropout): Dropout(p=0.1, inplace=False)
#>       )
#>       (transformer): Transformer(
#>         (layer): ModuleList(
#>           (0-5): 6 x TransformerBlock(
#>             (attention): MultiHeadSelfAttention(
#>               (dropout): Dropout(p=0.1, inplace=False)
#>               (q_lin): Linear(in_features=768, out_features=768, bias=True)
#>               (k_lin): Linear(in_features=768, out_features=768, bias=True)
#>               (v_lin): Linear(in_features=768, out_features=768, bias=True)
#>               (out_lin): Linear(in_features=768, out_features=768, bias=True)
#>             )
#>             (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#>             (ffn): FFN(
#>               (dropout): Dropout(p=0.1, inplace=False)
#>               (lin1): Linear(in_features=768, out_features=3072, bias=True)
#>               (lin2): Linear(in_features=3072, out_features=768, bias=True)
#>               (activation): GELUActivation()
#>             )
#>             (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#>           )
#>         )
#>       )
#>     )
#>   )
#>   (decoder): Linear(in_features=768, out_features=3, bias=True)
#>   (dropout): Dropout(p=0.0, inplace=False)
#>   (locked_dropout): LockedDropout(p=0.0)
#>   (word_dropout): WordDropout(p=0.0)
#>   (loss_function): CrossEntropyLoss()
#>   (weights): None
#>   (weight_tensor) None
#> )"
#> 2024-05-09 08:22:18,453 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:18,453 Corpus: 723 train + 80 dev + 85 test sentences
#> 2024-05-09 08:22:18,453 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:18,453 Train:  723 sentences
#> 2024-05-09 08:22:18,453         (train_with_dev=False, train_with_test=False)
#> 2024-05-09 08:22:18,453 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:18,453 Training Params:
#> 2024-05-09 08:22:18,453  - learning_rate: "0.02" 
#> 2024-05-09 08:22:18,453  - mini_batch_size: "8"
#> 2024-05-09 08:22:18,453  - max_epochs: "1"
#> 2024-05-09 08:22:18,453  - shuffle: "True"
#> 2024-05-09 08:22:18,453 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:18,453 Plugins:
#> 2024-05-09 08:22:18,453  - AnnealOnPlateau | patience: '3', anneal_factor: '0.5', min_learning_rate: '0.0001'
#> 2024-05-09 08:22:18,453 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:18,453 Final evaluation on model from best epoch (best-model.pt)
#> 2024-05-09 08:22:18,453  - metric: "('micro avg', 'f1-score')"
#> 2024-05-09 08:22:18,454 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:18,454 Computation:
#> 2024-05-09 08:22:18,454  - compute on device: cpu
#> 2024-05-09 08:22:18,454  - embedding storage: cpu
#> 2024-05-09 08:22:18,454 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:18,454 Model training base path: "vignettes/inst/muller-campaign-communication"
#> 2024-05-09 08:22:18,454 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:18,454 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:21,258 epoch 1 - iter 9/91 - loss 1.11748180 - time (sec): 2.80 - samples/sec: 25.68 - lr: 0.020000 - momentum: 0.000000
#> 2024-05-09 08:22:23,784 epoch 1 - iter 18/91 - loss 1.00897902 - time (sec): 5.33 - samples/sec: 27.01 - lr: 0.020000 - momentum: 0.000000
#> 2024-05-09 08:22:25,944 epoch 1 - iter 27/91 - loss 0.96961984 - time (sec): 7.49 - samples/sec: 28.84 - lr: 0.020000 - momentum: 0.000000
#> 2024-05-09 08:22:28,274 epoch 1 - iter 36/91 - loss 0.91258188 - time (sec): 9.82 - samples/sec: 29.33 - lr: 0.020000 - momentum: 0.000000
#> 2024-05-09 08:22:30,353 epoch 1 - iter 45/91 - loss 0.83585964 - time (sec): 11.90 - samples/sec: 30.25 - lr: 0.020000 - momentum: 0.000000
#> 2024-05-09 08:22:32,487 epoch 1 - iter 54/91 - loss 0.81807950 - time (sec): 14.03 - samples/sec: 30.78 - lr: 0.020000 - momentum: 0.000000
#> 2024-05-09 08:22:34,807 epoch 1 - iter 63/91 - loss 0.75969536 - time (sec): 16.35 - samples/sec: 30.82 - lr: 0.020000 - momentum: 0.000000
#> 2024-05-09 08:22:36,923 epoch 1 - iter 72/91 - loss 0.72803484 - time (sec): 18.47 - samples/sec: 31.19 - lr: 0.020000 - momentum: 0.000000
#> 2024-05-09 08:22:39,125 epoch 1 - iter 81/91 - loss 0.69332746 - time (sec): 20.67 - samples/sec: 31.35 - lr: 0.020000 - momentum: 0.000000
#> 2024-05-09 08:22:41,195 epoch 1 - iter 90/91 - loss 0.67458037 - time (sec): 22.74 - samples/sec: 31.66 - lr: 0.020000 - momentum: 0.000000
#> 2024-05-09 08:22:41,309 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:41,309 EPOCH 1 done: loss 0.6730 - lr: 0.020000
#> 2024-05-09 08:22:42,054 DEV : loss 0.41487812995910645 - f1-score (micro avg)  0.875
#> 2024-05-09 08:22:42,056  - 0 epochs without improvement
#> 2024-05-09 08:22:42,056 saving best model
#> 2024-05-09 08:22:42,753 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:42,755 Loading model from best epoch ...
#> 2024-05-09 08:22:44,617 
#> Results:
#> - F-score (micro) 0.8471
#> - F-score (macro) 0.8385
#> - Accuracy 0.8471
#> 
#> By class:
#>               precision    recall  f1-score   support
#> 
#>       Future     0.8511    0.9302    0.8889        43
#>      Present     0.8261    0.7037    0.7600        27
#>         Past     0.8667    0.8667    0.8667        15
#> 
#>     accuracy                         0.8471        85
#>    macro avg     0.8479    0.8335    0.8385        85
#> weighted avg     0.8459    0.8471    0.8440        85
#> 
#> 2024-05-09 08:22:44,617 ----------------------------------------------------------------------------------------------------
#> $test_score
#> [1] 0.8470588
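The log above reports 91 iterations per epoch. This number follows from the size of the training split and mini_batch_size, as the short calculation below shows; the variable names are chosen for this illustration only.

```r
n_train <- 723L          # training sentences reported in the log
mini_batch_size <- 8L    # value passed to old_trainer$train()

# each iteration processes one mini-batch; the last one may be smaller
iterations_per_epoch <- ceiling(n_train / mini_batch_size)
iterations_per_epoch
#> [1] 91

# number of sentences left for the final, incomplete mini-batch
last_batch_size <- n_train %% mini_batch_size
last_batch_size
#> [1] 3
```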

Continue Fine-tuning with New Dataset

Now we can continue fine-tuning the already fine-tuned model with the additional 2000 entries in cc_muller_new. The steps are the same as before, except that we don’t need to split the dataset again: we pass all 2000 entries as the training set (Flair again holds out part of them as a dev set, as the log below shows) and reuse old_test to evaluate how well the refined model performs.

Step 1 Load the muller-campaign-communication Model

Load the model you fine-tuned in the previous stage (old_model) so that it can be fine-tuned further with the new data, new_corpus.

old_model <- TextClassifier$load("vignettes/inst/muller-campaign-communication/best-model.pt")

Step 2 Convert the New Data to Sentence and Corpus

library(purrr)
cc_muller_new <- cc_muller[1001:3000,]
new_text <- map(cc_muller_new$text, Sentence)
new_labels <- as.character(cc_muller_new$class)

new_text <- map2(new_text, new_labels, ~ {
  .x$add_label("classification", .y)
  .x
})
new_corpus <- Corpus(train=new_text, test=old_test)
#> 2024-05-09 08:22:46,356 No dev split found. Using 0% (i.e. 200 samples) of the train split as dev data

Step 3 Create a New Model Trainer with the Old Model and New Corpus

new_trainer <- ModelTrainer(old_model, new_corpus)

Step 4 Train the New Model

new_trainer$train("vignettes/inst/new-muller-campaign-communication",
                  learning_rate=0.002, 
                  mini_batch_size=8L,  
                  max_epochs=1L)    
#> 2024-05-09 08:22:46,450 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:46,451 Model: "TextClassifier(
#>   (embeddings): TransformerDocumentEmbeddings(
#>     (model): DistilBertModel(
#>       (embeddings): Embeddings(
#>         (word_embeddings): Embedding(30523, 768, padding_idx=0)
#>         (position_embeddings): Embedding(512, 768)
#>         (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#>         (dropout): Dropout(p=0.1, inplace=False)
#>       )
#>       (transformer): Transformer(
#>         (layer): ModuleList(
#>           (0-5): 6 x TransformerBlock(
#>             (attention): MultiHeadSelfAttention(
#>               (dropout): Dropout(p=0.1, inplace=False)
#>               (q_lin): Linear(in_features=768, out_features=768, bias=True)
#>               (k_lin): Linear(in_features=768, out_features=768, bias=True)
#>               (v_lin): Linear(in_features=768, out_features=768, bias=True)
#>               (out_lin): Linear(in_features=768, out_features=768, bias=True)
#>             )
#>             (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#>             (ffn): FFN(
#>               (dropout): Dropout(p=0.1, inplace=False)
#>               (lin1): Linear(in_features=768, out_features=3072, bias=True)
#>               (lin2): Linear(in_features=3072, out_features=768, bias=True)
#>               (activation): GELUActivation()
#>             )
#>             (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#>           )
#>         )
#>       )
#>     )
#>   )
#>   (decoder): Linear(in_features=768, out_features=3, bias=True)
#>   (dropout): Dropout(p=0.0, inplace=False)
#>   (locked_dropout): LockedDropout(p=0.0)
#>   (word_dropout): WordDropout(p=0.0)
#>   (loss_function): CrossEntropyLoss()
#>   (weights): None
#>   (weight_tensor) None
#> )"
#> 2024-05-09 08:22:46,451 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:46,451 Corpus: 1800 train + 200 dev + 85 test sentences
#> 2024-05-09 08:22:46,451 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:46,451 Train:  1800 sentences
#> 2024-05-09 08:22:46,451         (train_with_dev=False, train_with_test=False)
#> 2024-05-09 08:22:46,451 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:46,451 Training Params:
#> 2024-05-09 08:22:46,451  - learning_rate: "0.002" 
#> 2024-05-09 08:22:46,451  - mini_batch_size: "8"
#> 2024-05-09 08:22:46,451  - max_epochs: "1"
#> 2024-05-09 08:22:46,451  - shuffle: "True"
#> 2024-05-09 08:22:46,451 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:46,451 Plugins:
#> 2024-05-09 08:22:46,451  - AnnealOnPlateau | patience: '3', anneal_factor: '0.5', min_learning_rate: '0.0001'
#> 2024-05-09 08:22:46,451 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:46,451 Final evaluation on model from best epoch (best-model.pt)
#> 2024-05-09 08:22:46,451  - metric: "('micro avg', 'f1-score')"
#> 2024-05-09 08:22:46,451 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:46,451 Computation:
#> 2024-05-09 08:22:46,451  - compute on device: cpu
#> 2024-05-09 08:22:46,452  - embedding storage: cpu
#> 2024-05-09 08:22:46,452 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:46,452 Model training base path: "vignettes/inst/new-muller-campaign-communication"
#> 2024-05-09 08:22:46,452 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:46,452 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:22:51,771 epoch 1 - iter 22/225 - loss 0.44434727 - time (sec): 5.32 - samples/sec: 33.09 - lr: 0.002000 - momentum: 0.000000
#> 2024-05-09 08:22:57,917 epoch 1 - iter 44/225 - loss 0.44069791 - time (sec): 11.46 - samples/sec: 30.70 - lr: 0.002000 - momentum: 0.000000
#> 2024-05-09 08:23:02,973 epoch 1 - iter 66/225 - loss 0.41245846 - time (sec): 16.52 - samples/sec: 31.96 - lr: 0.002000 - momentum: 0.000000
#> 2024-05-09 08:23:08,369 epoch 1 - iter 88/225 - loss 0.38905306 - time (sec): 21.92 - samples/sec: 32.12 - lr: 0.002000 - momentum: 0.000000
#> 2024-05-09 08:23:14,113 epoch 1 - iter 110/225 - loss 0.38627105 - time (sec): 27.66 - samples/sec: 31.81 - lr: 0.002000 - momentum: 0.000000
#> 2024-05-09 08:23:19,490 epoch 1 - iter 132/225 - loss 0.39458834 - time (sec): 33.04 - samples/sec: 31.96 - lr: 0.002000 - momentum: 0.000000
#> 2024-05-09 08:23:25,127 epoch 1 - iter 154/225 - loss 0.39355057 - time (sec): 38.67 - samples/sec: 31.86 - lr: 0.002000 - momentum: 0.000000
#> 2024-05-09 08:23:31,204 epoch 1 - iter 176/225 - loss 0.36979927 - time (sec): 44.75 - samples/sec: 31.46 - lr: 0.002000 - momentum: 0.000000
#> 2024-05-09 08:23:37,076 epoch 1 - iter 198/225 - loss 0.37568280 - time (sec): 50.62 - samples/sec: 31.29 - lr: 0.002000 - momentum: 0.000000
#> 2024-05-09 08:23:42,106 epoch 1 - iter 220/225 - loss 0.37887175 - time (sec): 55.65 - samples/sec: 31.62 - lr: 0.002000 - momentum: 0.000000
#> 2024-05-09 08:23:43,270 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:23:43,270 EPOCH 1 done: loss 0.3757 - lr: 0.002000
#> 2024-05-09 08:23:45,330 DEV : loss 0.38609111309051514 - f1-score (micro avg)  0.88
#> 2024-05-09 08:23:45,333  - 0 epochs without improvement
#> 2024-05-09 08:23:45,334 saving best model
#> 2024-05-09 08:23:45,926 ----------------------------------------------------------------------------------------------------
#> 2024-05-09 08:23:45,930 Loading model from best epoch ...
#> 2024-05-09 08:23:47,911 
#> Results:
#> - F-score (micro) 0.8471
#> - F-score (macro) 0.8583
#> - Accuracy 0.8471
#> 
#> By class:
#>               precision    recall  f1-score   support
#> 
#>       Future     0.8605    0.8605    0.8605        43
#>      Present     0.7586    0.8148    0.7857        27
#>         Past     1.0000    0.8667    0.9286        15
#> 
#>     accuracy                         0.8471        85
#>    macro avg     0.8730    0.8473    0.8583        85
#> weighted avg     0.8527    0.8471    0.8487        85
#> 
#> 2024-05-09 08:23:47,911 ----------------------------------------------------------------------------------------------------
#> $test_score
#> [1] 0.8470588
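Comparing the two evaluation reports, the micro F-score on the test set is unchanged at 0.8471, while the macro F-score rises from 0.8385 to 0.8583. The macro average is the unweighted mean of the per-class F1 scores, which the sketch below recomputes from the values printed in the two reports. The helper macro_f1 and the vector names are hypothetical, and the inputs are rounded to four decimals as in the logs.

```r
# per-class F1 scores (Future, Present, Past) copied from the two reports
f1_first_round  <- c(Future = 0.8889, Present = 0.7600, Past = 0.8667)
f1_second_round <- c(Future = 0.8605, Present = 0.7857, Past = 0.9286)

# macro F1: every class counts equally, regardless of its support
macro_f1 <- function(f1_by_class) mean(f1_by_class)

round(macro_f1(f1_first_round), 4)
#> [1] 0.8385
round(macro_f1(f1_second_round), 4)
#> [1] 0.8583
```

Because every class counts equally in the macro average, the higher F1 for the smaller Present and Past classes outweighs the lower F1 for Future in the second round.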

For more R tutorials and documentation, see here.