Tutorial
Yen-Chieh (David) Liao | Sohini Timbadia | Stefan Müller
University of Birmingham & University College Dublin
Overview
Flair NLP is an open-source library for Natural Language Processing (NLP) developed by Zalando Research. Known for its state-of-the-art solutions, such as contextual string embeddings for NLP tasks like Named Entity Recognition (NER), Part-of-Speech tagging (POS), and more, it has garnered the attention of the NLP community for its ease of use and powerful functionalities.
In addition, Flair NLP offers pre-trained models for various languages and tasks, and is compatible with fine-tuned transformers hosted on Hugging Face.
Sentence and Token
Sentence and Token are fundamental classes.
Sentence
A Sentence in Flair is an object that contains a sequence of Token objects, and it can be annotated with labels, such as named entities, part-of-speech tags, and more. It can also store embeddings for the sentence as a whole and different kinds of linguistic annotations.
Here’s a simple example of how you create a Sentence:
# Creating a Sentence object
library(flaiR)
string <- "What I see in UCD today, what I have seen of UCD in its impact on my own life and the life of Ireland."
Sentence <- flair_data()$Sentence
sentence <- Sentence(string)
Sentence[26] means that there are a total of 26 tokens in the sentence.
print(sentence)
#> Sentence[26]: "What I see in UCD today, what I have seen of UCD in its impact on my own life and the life of Ireland."
Token
When you use Flair to handle text data, Sentence and Token objects often play central roles in many use cases. When you create a Sentence object, it automatically tokenizes the text, removing the need to create Token objects manually.
Unlike R, which indexes from 1, Python indexes from 0. Therefore, when using a for loop, we use seq_along(sentence) - 1. The output should be something like:
# The Sentence object has automatically created and contains multiple Token objects
# We can iterate through the Sentence object to view each Token
for (i in seq_along(sentence)-1) {
print(sentence[[i]])
}
#> Token[0]: "What"
#> Token[1]: "I"
#> Token[2]: "see"
#> Token[3]: "in"
#> Token[4]: "UCD"
#> Token[5]: "today"
#> Token[6]: ","
#> Token[7]: "what"
#> Token[8]: "I"
#> Token[9]: "have"
#> Token[10]: "seen"
#> Token[11]: "of"
#> Token[12]: "UCD"
#> Token[13]: "in"
#> Token[14]: "its"
#> Token[15]: "impact"
#> Token[16]: "on"
#> Token[17]: "my"
#> Token[18]: "own"
#> Token[19]: "life"
#> Token[20]: "and"
#> Token[21]: "the"
#> Token[22]: "life"
#> Token[23]: "of"
#> Token[24]: "Ireland"
#> Token[25]: "."
Alternatively, you can use $tokens to print all tokens directly.
print(sentence$tokens)
#> [[1]]
#> Token[0]: "What"
#>
#> [[2]]
#> Token[1]: "I"
#>
#> [[3]]
#> Token[2]: "see"
#>
#> [[4]]
#> Token[3]: "in"
#>
#> [[5]]
#> Token[4]: "UCD"
#>
#> [[6]]
#> Token[5]: "today"
#>
#> [[7]]
#> Token[6]: ","
#>
#> [[8]]
#> Token[7]: "what"
#>
#> [[9]]
#> Token[8]: "I"
#>
#> [[10]]
#> Token[9]: "have"
#>
#> [[11]]
#> Token[10]: "seen"
#>
#> [[12]]
#> Token[11]: "of"
#>
#> [[13]]
#> Token[12]: "UCD"
#>
#> [[14]]
#> Token[13]: "in"
#>
#> [[15]]
#> Token[14]: "its"
#>
#> [[16]]
#> Token[15]: "impact"
#>
#> [[17]]
#> Token[16]: "on"
#>
#> [[18]]
#> Token[17]: "my"
#>
#> [[19]]
#> Token[18]: "own"
#>
#> [[20]]
#> Token[19]: "life"
#>
#> [[21]]
#> Token[20]: "and"
#>
#> [[22]]
#> Token[21]: "the"
#>
#> [[23]]
#> Token[22]: "life"
#>
#> [[24]]
#> Token[23]: "of"
#>
#> [[25]]
#> Token[24]: "Ireland"
#>
#> [[26]]
#> Token[25]: "."
Retrieve the Token
To comprehend the string representation format of the Sentence object, tagging at least one token is adequate. Python's get_token(n) method allows us to retrieve the Token object for a particular token. Additionally, we can use [] to index a specific token. Note that get_token() counts tokens from 1, whereas [] uses Python's 0-based indexing, as the outputs below show.
# method in Python
sentence$get_token(5)
#> Token[4]: "UCD"
# indexing in R
sentence[6]
#> Token[6]: ","
Each word (and punctuation) in the text is treated as an individual Token object. These Token objects store text information and other possible linguistic information (such as part-of-speech tags or named entity tags) and embeddings (if you used a model to generate them).
While you do not need to create Token objects manually, understanding how to manage them is useful in situations where you might want to fine-tune the tokenization process. For example, you can control the exactness of tokenization by manually creating Token objects from a Sentence object.
This makes Flair very flexible when handling text data since the automatic tokenization feature can be used for rapid development, while also allowing users to fine-tune their tokenization.
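For instance, one way to take control of tokenization is to pass pre-tokenized text yourself. A minimal sketch, assuming the Sentence constructor accepts a list of token strings (as in recent Flair releases):
# hypothetical example: supply your own tokens instead of letting Flair tokenize
Sentence <- flair_data()$Sentence
pre_tokenized <- Sentence(list("What", "I", "see", "in", "UCD", "today", "."))
print(pre_tokenized)  # should report a Sentence with exactly 7 tokens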
Annotate POS tag and NER tag
The add_label(label_type, value) method can be employed to assign a label to a token. Using Universal POS tags, sentence[10], the token "seen", can be tagged as VERB, indicating it is a past participle form of a verb.
sentence[10]$add_label('manual-pos', 'VERB')
print(sentence[10])
#> Token[10]: "seen" → VERB (1.0000)
We can also add an NER (Named Entity Recognition) tag, ORG, to sentence[4], "UCD", identifying it as an organization (a university in Dublin).
sentence[4]$add_label('ner', 'ORG')
print(sentence[4])
#> Token[4]: "UCD" → ORG (1.0000)
If we print the sentence object, Sentence[26] provides information for the 26 tokens → ["UCD"/ORG, "seen"/VERB], thus displaying the two pieces of tagging information.
print(sentence)
#> Sentence[26]: "What I see in UCD today, what I have seen of UCD in its impact on my own life and the life of Ireland." → ["UCD"/ORG, "seen"/VERB]
Corpus
The Corpus object in Flair is a fundamental data structure that represents a dataset containing text samples, usually comprising of a training set, a development set (or validation set), and a test set. It’s designed to work smoothly with Flair’s models for tasks like named entity recognition, text classification, and more.
Attributes:
- train: a list of sentences (List[Sentence]) that form the training dataset.
- dev (or development): a list of sentences (List[Sentence]) that form the development (or validation) dataset.
- test: a list of sentences (List[Sentence]) that form the test dataset.
Important methods:
- downsample: allows you to downsample (reduce) the number of sentences in the train, dev, and test splits.
- obtain_statistics: gives a quick overview of the statistics of the corpus, including the number of sentences and the distribution of labels.
- make_vocab_dictionary: used to create a vocabulary dictionary from the corpus (a short sketch follows the statistics output below).
library(flaiR)
Corpus <- flair_data()$Corpus
Sentence <- flair_data()$Sentence
# Create some example sentences
train <- list(Sentence('This is a training example.'))
dev <- list(Sentence('This is a validation example.'))
test <- list(Sentence('This is a test example.'))
# Create a corpus using the custom data splits
corp <- Corpus(train = train, dev = dev, test = test)
The $obtain_statistics() method of the Corpus object in the Flair library provides an overview of the dataset statistics. The method returns a Python dictionary with details about the training, validation (development), and test datasets that make up the corpus. In R, you can use the jsonlite package to format the JSON output.
library(jsonlite)
data <- fromJSON(corp$obtain_statistics())
formatted_str <- toJSON(data, pretty=TRUE)
print(formatted_str)
#> {
#> "TRAIN": {
#> "dataset": ["TRAIN"],
#> "total_number_of_documents": [1],
#> "number_of_documents_per_class": {},
#> "number_of_tokens_per_tag": {},
#> "number_of_tokens": {
#> "total": [6],
#> "min": [6],
#> "max": [6],
#> "avg": [6]
#> }
#> },
#> "TEST": {
#> "dataset": ["TEST"],
#> "total_number_of_documents": [1],
#> "number_of_documents_per_class": {},
#> "number_of_tokens_per_tag": {},
#> "number_of_tokens": {
#> "total": [6],
#> "min": [6],
#> "max": [6],
#> "avg": [6]
#> }
#> },
#> "DEV": {
#> "dataset": ["DEV"],
#> "total_number_of_documents": [1],
#> "number_of_documents_per_class": {},
#> "number_of_tokens_per_tag": {},
#> "number_of_tokens": {
#> "total": [6],
#> "min": [6],
#> "max": [6],
#> "avg": [6]
#> }
#> }
#> }
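The make_vocab_dictionary helper listed above can be called on the same object. A minimal sketch, assuming its default arguments; it returns a Flair Dictionary built from the tokens in the corpus:
# build a token-level vocabulary dictionary from the toy corpus
vocab <- corp$make_vocab_dictionary()
print(vocab)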
In R
Below, we use data from the article "The Temporal Focus of Campaign Communication" by Stefan Müller, published in the Journal of Politics in 2020, as an example.
First, we vectorize cc_muller$text using the Sentence function to transform it into a list object. Then, we reformat cc_muller$class_pro_retro as a factor. It is essential to note that R handles numerical values differently than Python: in R, numerical values are represented as floating-point numbers by default, so it is advisable to convert them into factors or strings. Lastly, we employ the map2 function from the purrr package to assign a label to each Sentence object using the $add_label method.
library(purrr)
#>
#> Attaching package: 'purrr'
#> The following object is masked from 'package:jsonlite':
#>
#> flatten
data(cc_muller)
# The `Sentence` object tokenizes text
text <- lapply( cc_muller$text, Sentence)
# split sentence object to train and test.
labels <- as.factor(cc_muller$class_pro_retro)
# `$add_label` method assigns the corresponding coded type to each Sentence corpus.
text <- map2(text, labels, ~ .x$add_label("classification", .y), .progress = TRUE)
To perform a train-test split using base R, we can follow these steps:
set.seed(2046)
sample <- sample(c(TRUE, FALSE), length(text), replace=TRUE, prob=c(0.8, 0.2))
train <- text[sample]
test <- text[!sample]
sprintf("Corpus object sizes - Train: %d | Test: %d", length(train), length(test))
#> [1] "Corpus object sizes - Train: 4710 | Test: 1148"
If you don't provide a dev set, flaiR will not force you to carve out a portion of your test set to serve as one. Instead, when only the train and test sets are provided, Flair automatically takes a fraction of the train set (e.g., 10%) to use as a dev set (#2259). This offers a mechanism for model selection and helps prevent the model from overfitting on the train set.
In the Corpus function, the dev dataset is selected at random. To ensure reproducibility, we need to set the seed in the flaiR framework. We can accomplish this by calling the top-level module flair from {flaiR} and using $set_seed(1964L) to set the seed.
flair <- import_flair()
flair$set_seed(1964L)
corp <- Corpus(train=train,
# dev=test,
test=test)
#> 2024-11-28 20:56:22,882 No dev split found. Using 10% (i.e. 471 samples) of the train split as dev data
sprintf("Corpus object sizes - Train: %d | Test: %d | Dev: %d",
length(corp$train),
length(corp$test),
length(corp$dev))
#> [1] "Corpus object sizes - Train: 4239 | Test: 1148 | Dev: 471"
In the later sections, there will be more similar processing using the Corpus. Following that, we will focus on advanced NLP applications.
Sequence Tagging
Tag Entities in Text
Let’s run named entity recognition over the following example sentence: “I love Berlin and New York”. To do this, all you need to do is make a Sentence object for this text, load a pre-trained model and use it to predict tags for the object.
# attach flaiR in R
library(flaiR)
# make a sentence
Sentence <- flair_data()$Sentence
sentence <- Sentence('I love Berlin and New York.')
# load the NER tagger
Classifier <- flair_nn()$Classifier
tagger <- Classifier$load('ner')
#> 2024-11-28 20:56:24,174 SequenceTagger predicts: Dictionary with 20 tags: <unk>, O, S-ORG, S-MISC, B-PER, E-PER, S-LOC, B-ORG, E-ORG, I-PER, S-PER, B-MISC, I-MISC, E-MISC, I-ORG, B-LOC, E-LOC, I-LOC, <START>, <STOP>
# run NER over sentence
tagger$predict(sentence)
To print all annotations:
# print the sentence with all annotations
print(sentence)
#> Sentence[7]: "I love Berlin and New York." → ["Berlin"/LOC, "New York"/LOC]
You can also print each entity annotation individually with a for loop. Although Python indexes from 0, $get_labels() is returned to R as a list, so standard 1-based indexing with seq_along() works, as shown below.
Tag Part-of-Speech in Text
We use flair/pos-english, one of the standard models hosted on Hugging Face, for POS tagging.
# attach flaiR in R
library(flaiR)
# make a sentence
Sentence <- flair_data()$Sentence
sentence <- Sentence('I love Berlin and New York.')
# load the POS tagger
Classifier <- flair_nn()$Classifier
tagger <- Classifier$load('pos')
#> 2024-11-28 20:56:25,184 SequenceTagger predicts: Dictionary with 53 tags: <unk>, O, UH, ,, VBD, PRP, VB, PRP$, NN, RB, ., DT, JJ, VBP, VBG, IN, CD, NNS, NNP, WRB, VBZ, WDT, CC, TO, MD, VBN, WP, :, RP, EX, JJR, FW, XX, HYPH, POS, RBR, JJS, PDT, NNPS, RBS, AFX, WP$, -LRB-, -RRB-, ``, '', LS, $, SYM, ADD
# run POS tagging over the sentence
tagger$predict(sentence)
To print all annotations:
# print the sentence with all annotations
print(sentence)
#> Sentence[7]: "I love Berlin and New York." → ["I"/PRP, "love"/VBP, "Berlin"/NNP, "and"/CC, "New"/NNP, "York"/NNP, "."/.]
Use a for loop to print out each POS tag.
for (i in seq_along(sentence$get_labels())) {
print(sentence$get_labels()[[i]])
}
#> 'Token[0]: "I"'/'PRP' (1.0)
#> 'Token[1]: "love"'/'VBP' (1.0)
#> 'Token[2]: "Berlin"'/'NNP' (0.9999)
#> 'Token[3]: "and"'/'CC' (1.0)
#> 'Token[4]: "New"'/'NNP' (1.0)
#> 'Token[5]: "York"'/'NNP' (1.0)
#> 'Token[6]: "."'/'.' (1.0)
Detect Sentiment
Let’s run sentiment analysis over the same sentence to determine whether it is POSITIVE or NEGATIVE.
You can do this with essentially the same code as above. Instead of loading the ‘ner’ model, you now load the ‘sentiment’ model:
# attach flaiR in R
library(flaiR)
# make a sentence
Sentence <- flair_data()$Sentence
sentence <- Sentence('I love Berlin and New York.')
# load the Classifier tagger from flair.nn module
Classifier <- flair_nn()$Classifier
tagger <- Classifier$load('sentiment')
# run sentiment analysis over sentence
tagger$predict(sentence)
# print the sentence with all annotations
print(sentence)
#> Sentence[7]: "I love Berlin and New York." → POSITIVE (0.9982)
Tagging Parts-of-Speech with Flair Models
You can load the pre-trained model "pos-fast". For more pre-trained models, see https://flairnlp.github.io/docs/tutorial-basics/part-of-speech-tagging#-in-english.
texts <- c("UCD is one of the best universities in Ireland.",
"UCD has a good campus but is very far from my apartment in Dublin.",
"Essex is famous for social science research.",
"Essex is not in the Russell Group, but it is famous for political science research and in 1994 Group.",
"TCD is the oldest university in Ireland.",
"TCD is similar to Oxford.")
doc_ids <- c("doc1", "doc2", "doc3", "doc4", "doc5", "doc6")
tagger_pos <- load_tagger_pos("pos-fast")
#> 2024-11-28 20:56:28,696 SequenceTagger predicts: Dictionary with 53 tags: <unk>, O, UH, ,, VBD, PRP, VB, PRP$, NN, RB, ., DT, JJ, VBP, VBG, IN, CD, NNS, NNP, WRB, VBZ, WDT, CC, TO, MD, VBN, WP, :, RP, EX, JJR, FW, XX, HYPH, POS, RBR, JJS, PDT, NNPS, RBS, AFX, WP$, -LRB-, -RRB-, ``, '', LS, $, SYM, ADD
results <- get_pos(texts, doc_ids, tagger_pos)
head(results, n = 10)
#> doc_id token_id text_id token tag precision
#> <char> <num> <lgcl> <char> <char> <num>
#> 1: doc1 0 NA UCD NNP 0.9967
#> 2: doc1 1 NA is VBZ 1.0000
#> 3: doc1 2 NA one CD 0.9993
#> 4: doc1 3 NA of IN 1.0000
#> 5: doc1 4 NA the DT 1.0000
#> 6: doc1 5 NA best JJS 0.9988
#> 7: doc1 6 NA universities NNS 0.9997
#> 8: doc1 7 NA in IN 1.0000
#> 9: doc1 8 NA Ireland NNP 1.0000
#> 10: doc1 9 NA . . 0.9998
Tagging Entities with Flair Models
Load the pre-trained model "ner". For more pre-trained models, see https://flairnlp.github.io/docs/tutorial-basics/tagging-entities.
tagger_ner <- load_tagger_ner("ner")
#> 2024-11-28 20:56:29,504 SequenceTagger predicts: Dictionary with 20 tags: <unk>, O, S-ORG, S-MISC, B-PER, E-PER, S-LOC, B-ORG, E-ORG, I-PER, S-PER, B-MISC, I-MISC, E-MISC, I-ORG, B-LOC, E-LOC, I-LOC, <START>, <STOP>
results <- get_entities(texts, doc_ids, tagger_ner)
head(results, n = 10)
#> doc_id entity tag
#> <char> <char> <char>
#> 1: doc1 UCD ORG
#> 2: doc1 Ireland LOC
#> 3: doc2 UCD ORG
#> 4: doc2 Dublin LOC
#> 5: doc3 Essex ORG
#> 6: doc4 Essex ORG
#> 7: doc4 Russell Group ORG
#> 8: doc5 TCD ORG
#> 9: doc5 Ireland LOC
#> 10: doc6 TCD ORG
Tagging Sentiment
Load the pre-trained model "sentiment". The pre-trained models "sentiment", "sentiment-fast", and "de-offensive-language" are currently available. For more pre-trained models, see https://flairnlp.github.io/docs/tutorial-basics/tagging-sentiment.
tagger_sent <- load_tagger_sentiments("sentiment")
results <- get_sentiments(texts, doc_ids, tagger_sent)
head(results, n = 10)
#> doc_id sentiment score
#> <char> <char> <num>
#> 1: doc1 POSITIVE 0.9970598
#> 2: doc2 NEGATIVE 0.8472342
#> 3: doc3 POSITIVE 0.9928006
#> 4: doc4 POSITIVE 0.9901404
#> 5: doc5 POSITIVE 0.9952670
#> 6: doc6 POSITIVE 0.9291791
Embedding
Flair is a very popular natural language processing library, providing a variety of embedding methods for text representation. Flair Embeddings is a word embedding framework developed by Zalando. It focuses on word-level representation and can capture contextual information of words, allowing the same word to have different embeddings in different contexts. Unlike traditional word embeddings (such as Word2Vec or GloVe), Flair can dynamically generate word embeddings based on context and has achieved excellent results in various NLP tasks. Below are some key points about Flair Embeddings:
Context-Aware
Flair is a dynamic word embedding technique that can understand the meaning of words based on context. In contrast, static word embeddings, such as Word2Vec or GloVe, provide a fixed embedding for each word without considering its context in a sentence.
Therefore, context-sensitive embedding techniques, such as Flair, can capture the meaning of words in specific sentences more accurately, thus enhancing the performance of language models in various tasks.
Example:
Consider the following two English sentences:
- “I am interested in the bank of the river.”
- “I need to go to the bank to withdraw money.”
Here, the word “bank” has two different meanings. In the first sentence, it refers to the edge or shore of a river. In the second sentence, it refers to a financial institution.
For static embeddings, the word “bank” might have an embedding that lies somewhere between these two meanings because it doesn’t consider context. But for dynamic embeddings like Flair, “bank” in the first sentence will have an embedding related to rivers, and in the second sentence, it will have an embedding related to finance.
FlairEmbeddings <- flair_embeddings()$FlairEmbeddings
Sentence <- flair_data()$Sentence
# Initialize Flair embeddings
flair_embedding_forward <- FlairEmbeddings('news-forward')
# Define the two sentences
sentence1 <- Sentence("I am interested in the bank of the river.")
sentence2 <- Sentence("I need to go to the bank to withdraw money.")
# Get the embeddings
flair_embedding_forward$embed(sentence1)
#> [[1]]
#> Sentence[10]: "I am interested in the bank of the river."
flair_embedding_forward$embed(sentence2)
#> [[1]]
#> Sentence[11]: "I need to go to the bank to withdraw money."
# Extract the embedding for "bank" from the sentences
bank_embedding_sentence1 = sentence1[5]$embedding # "bank" is the sixth word (index 5)
bank_embedding_sentence2 = sentence2[6]$embedding # "bank" is the seventh word (index 6)
Same word, but the vector representations differ by context. In this way, you can see how the dynamic embeddings for "bank" in the two sentences differ. If you printed the embeddings, you would see high-dimensional vectors with a lot of numbers. For a more intuitive view of the difference, you can compute the cosine similarity or other metrics between the two embeddings.
This is just a simple demonstration. In practice, you can also combine multiple embedding techniques, such as WordEmbeddings and FlairEmbeddings, to get richer word vectors.
library(lsa)
#> Loading required package: SnowballC
cosine(as.numeric( bank_embedding_sentence1$numpy()),
as.numeric( bank_embedding_sentence2$numpy()))
#> [,1]
#> [1,] 0.7329551
Character-Based
Flair uses a character-level language model, meaning it can generate embeddings for rare words or even misspelled words. This is an important feature because it allows the model to understand and process words that have never appeared in the training data. Flair uses a bidirectional LSTM (Long Short-Term Memory) network that operates at a character level. This allows it to feed individual characters into the LSTM instead of words.
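For example, a misspelled word that is unlikely to appear in any training vocabulary still receives a full embedding. A minimal sketch, reusing the news-forward model (whose vectors, as noted later in this tutorial, have 2048 dimensions):
FlairEmbeddings <- flair_embeddings()$FlairEmbeddings
Sentence <- flair_data()$Sentence
embedding <- FlairEmbeddings('news-forward')
# "lovee" is a deliberate misspelling
typo_sentence <- Sentence("I lovee Dublin")
embedding$embed(typo_sentence)
# the misspelled token still gets an embedding vector (length 2048 for news-forward)
length(typo_sentence[1]$embedding$numpy())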
Multilingual Support
Flair provides various pre-trained character-level language models, supporting contextual word embeddings for multiple languages. It allows you to easily combine different word embeddings (e.g., Flair Embeddings, Word2Vec, GloVe, etc.) to create powerful stacked embeddings.
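For example, a forward character-level model trained on German text can be loaded by its ID. A minimal sketch, assuming the 'de-forward' model listed in the Flair documentation is available:
# load a German forward character-level language model and embed a German sentence
FlairEmbeddings <- flair_embeddings()$FlairEmbeddings
Sentence <- flair_data()$Sentence
de_embedding <- FlairEmbeddings('de-forward')
de_sentence <- Sentence("Ich liebe Berlin")
de_embedding$embed(de_sentence)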
Classic Word Embeddings
In Flair, the simplest form of embeddings that still contains semantic information about the word are called classic word embeddings. These embeddings are pre-trained and non-contextual.
Let's retrieve a few word embeddings, using FastText embeddings, with the following code. To do so, we instantiate a WordEmbeddings class by passing in the ID of the embedding of our choice. Then we simply wrap our text into a Sentence object and call the embed(sentence) method on our WordEmbeddings class.
WordEmbeddings <- flair_embeddings()$WordEmbeddings
Sentence <- flair_data()$Sentence
embedding <- WordEmbeddings('crawl')
sentence <- Sentence("one two three one")
embedding$embed(sentence)
#> [[1]]
#> Sentence[4]: "one two three one"
for (i in seq_along(sentence$tokens)) {
  print(head(sentence$tokens[[i]]$embedding))
}
#> tensor([-0.0535, -0.0368, -0.2851, -0.0381, -0.0486, 0.2383])
#> tensor([ 0.0282, -0.0786, -0.1236, 0.1756, -0.1199, 0.0964])
#> tensor([-0.0920, -0.0690, -0.1475, 0.2313, -0.0872, 0.0799])
#> tensor([-0.0535, -0.0368, -0.2851, -0.0381, -0.0486, 0.2383])
Flair supports a range of classic word embeddings, each offering unique features and application scopes. Below is an overview, detailing the ID required to load each embedding and its corresponding language.
Embedding Type | ID | Language |
---|---|---|
GloVe | glove | English |
Komninos | extvec | English |
| | English |
Turian (small) | turian | English |
FastText (crawl) | crawl | English |
FastText (news & Wikipedia) | ar | Arabic |
FastText (news & Wikipedia) | bg | Bulgarian |
FastText (news & Wikipedia) | ca | Catalan |
FastText (news & Wikipedia) | cz | Czech |
FastText (news & Wikipedia) | da | Danish |
FastText (news & Wikipedia) | de | German |
FastText (news & Wikipedia) | es | Spanish |
FastText (news & Wikipedia) | en | English |
FastText (news & Wikipedia) | eu | Basque |
FastText (news & Wikipedia) | fa | Persian |
FastText (news & Wikipedia) | fi | Finnish |
FastText (news & Wikipedia) | fr | French |
FastText (news & Wikipedia) | he | Hebrew |
FastText (news & Wikipedia) | hi | Hindi |
FastText (news & Wikipedia) | hr | Croatian |
FastText (news & Wikipedia) | id | Indonesian |
FastText (news & Wikipedia) | it | Italian |
FastText (news & Wikipedia) | ja | Japanese |
FastText (news & Wikipedia) | ko | Korean |
FastText (news & Wikipedia) | nl | Dutch |
FastText (news & Wikipedia) | no | Norwegian |
FastText (news & Wikipedia) | pl | Polish |
FastText (news & Wikipedia) | pt | Portuguese |
FastText (news & Wikipedia) | ro | Romanian |
FastText (news & Wikipedia) | ru | Russian |
FastText (news & Wikipedia) | si | Slovenian |
FastText (news & Wikipedia) | sk | Slovak |
FastText (news & Wikipedia) | sr | Serbian |
FastText (news & Wikipedia) | sv | Swedish |
FastText (news & Wikipedia) | tr | Turkish |
FastText (news & Wikipedia) | zh | Chinese |
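To load one of the non-English embeddings above, pass its ID to WordEmbeddings in exactly the same way. A minimal sketch using the French ID from the table (the model is downloaded on first use):
WordEmbeddings <- flair_embeddings()$WordEmbeddings
Sentence <- flair_data()$Sentence
fr_embedding <- WordEmbeddings('fr')
fr_sentence <- Sentence("J'aime Paris")
fr_embedding$embed(fr_sentence)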
Contextual Embeddings
The idea behind contextual string embeddings is that each word embedding should be defined by not only its syntactic-semantic meaning but also the context it appears in. What this means is that each word will have a different embedding for every context it appears in.
Each pre-trained Flair model offers a forward version and a backward version. Let's assume you are processing a language that, just like this text, uses a left-to-right script. The forward version takes into account the context that happens before the word, on the left-hand side. The backward version works in the opposite direction: it takes into account the context after the word, on the right-hand side. If this is true, then two identical words that appear at the beginning of two different sentences should have identical forward embeddings, because their context is null. Let's test this out:
Because we are using a forward model, it only takes into account the context that occurs before a word. Since the word nice has no context on the left-hand side of its position in either sentence, the two embeddings should be identical, and the check below indeed returns TRUE.
FlairEmbeddings <- flair_embeddings()$FlairEmbeddings
embedding <- FlairEmbeddings('news-forward')
s1 <- Sentence("nice shirt")
s2 <- Sentence("nice pants")
embedding$embed(s1)
#> [[1]]
#> Sentence[2]: "nice shirt"
embedding$embed(s2)
#> [[1]]
#> Sentence[2]: "nice pants"
cat(" s1 sentence:", paste(s1[0], sep = ""), "\n", "s2 sentence:", paste(s2[0], sep = ""))
#> s1 sentence: Token[0]: "nice"
#> s2 sentence: Token[0]: "nice"
We test whether every one of the 2048 dimensions of the two embeddings of nice is identical, by checking that the number of element-wise matches equals the embedding length (2048). If this returns TRUE, the embedding results are consistent, which should theoretically be the case.
length(s1[0]$embedding$numpy()) == sum(s1[0]$embedding$numpy() == s2[0]$embedding$numpy())
#> [1] TRUE
Now we create two new sentence objects, prepending the words very and pretty before nice.
s1 <- Sentence("very nice shirt")
s2 <- Sentence("pretty nice pants")
embedding$embed(s1)
#> [[1]]
#> Sentence[3]: "very nice shirt"
embedding$embed(s2)
#> [[1]]
#> Sentence[3]: "pretty nice pants"
The two embeddings of nice are no longer identical because the preceding context differs, so the check returns FALSE.
length(s1[0]$embedding$numpy()) == sum(s1[0]$embedding$numpy() == s2[0]$embedding$numpy())
#> [1] FALSE
The measure of similarity between two vectors in an inner product space is known as cosine similarity. The formula for calculating cosine similarity between two vectors, such as vectors A and B, is as follows:
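$$\text{cosine similarity}(A, B) = \frac{A \cdot B}{\lVert A \rVert \, \lVert B \rVert} = \frac{\sum_{i=1}^{n} A_i B_i}{\sqrt{\sum_{i=1}^{n} A_i^2}\,\sqrt{\sum_{i=1}^{n} B_i^2}}$$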
library(lsa)
vector1 <- as.numeric(s1[0]$embedding$numpy())
vector2 <- as.numeric(s2[0]$embedding$numpy())
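Computing the similarity with lsa's cosine() gives the value reported in the next sentence:
# cosine similarity between the two context-dependent embeddings of "nice"
cosine(vector1, vector2)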
We can observe that the similarity between the two contextual embeddings of nice is about 0.55.
Extracting Embeddings from BERT
First, we utilize the flair.embeddings.TransformerWordEmbeddings class to download BERT; more transformer models can be found on Flair NLP's Hugging Face page.
library(flaiR)
TransformerWordEmbeddings <- flair_embeddings()$TransformerWordEmbeddings("bert-base-uncased")
embedding <- TransformerWordEmbeddings$embed(sentence)
Next, we traverse each token in the sentence and print them.
# Iterate through each token in the sentence, printing them.
# Utilize reticulate::py_str(token) to view each token, given that the sentence is a Python object.
for (i in seq_along(sentence$tokens)) {
cat("Token: ", reticulate::py_str(sentence$tokens[[i]]), "\n")
# Access the embedding of the token, converting it to an R object,
# and print the first 10 elements of the vector.
token_embedding <- sentence$tokens[[i]]$embedding
print(head(token_embedding, 10))
}
#> Token: Token[0]: "one"
#> tensor([-0.0535, -0.0368, -0.2851, -0.0381, -0.0486, 0.2383, -0.1200, 0.2620,
#> -0.0575, 0.0228])
#> Token: Token[1]: "two"
#> tensor([ 0.0282, -0.0786, -0.1236, 0.1756, -0.1199, 0.0964, -0.1327, 0.4449,
#> -0.0264, -0.1168])
#> Token: Token[2]: "three"
#> tensor([-0.0920, -0.0690, -0.1475, 0.2313, -0.0872, 0.0799, -0.0901, 0.4403,
#> -0.0103, -0.1494])
#> Token: Token[3]: "one"
#> tensor([-0.0535, -0.0368, -0.2851, -0.0381, -0.0486, 0.2383, -0.1200, 0.2620,
#> -0.0575, 0.0228])
Visualizing Embeddings
Word Embeddings (GloVe)
GloVe embeddings are PyTorch vectors of dimensionality 100.
For English, Flair provides a few more options. Here, you can use ‘en-glove’ and ‘en-extvec’ with the WordEmbeddings class.
# Tokenize & Embed
# load the GloVe embeddings from Flair NLP via flaiR
WordEmbeddings <- flair_embeddings()$WordEmbeddings
embedding <- WordEmbeddings("glove")
# Tokenize the text
sentence <- Sentence("King Queen man woman Paris London apple orange Taiwan Dublin Bamberg")
# Embed the sentence text using the loaded model.
embedding$embed(sentence)
#> [[1]]
#> Sentence[11]: "King Queen man woman Paris London apple orange Taiwan Dublin Bamberg"
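As a quick check of the dimensionality mentioned above, a minimal sketch; each GloVe-embedded token vector should have length 100:
# inspect the length of the first token's embedding vector
length(sentence$tokens[[1]]$embedding$numpy())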
- Each token in the sentence is embedded with the corresponding vector from the model; we store these vectors in a list.
sen_list <- list()
for (i in seq_along(sentence$tokens)) {
# store the tensor vectors to numeric vectors
sen_list[[i]] <- as.vector(sentence$tokens[[i]]$embedding$numpy())
}
- Extract the token texts into an R character vector.
token_texts <- sapply(sentence$tokens, function(token) token$text)
- Form the dataframe.
sen_df <- do.call(rbind, lapply(sen_list, function(x) t(data.frame(x))))
sen_df <- as.data.frame(sen_df)
rownames(sen_df) <- token_texts
print(sen_df[,1:20])
#> V1 V2 V3 V4 V5 V6
#> King -0.3230700 -0.876160 0.219770 0.252680 0.2297600 0.73880
#> Queen -0.5004500 -0.708260 0.553880 0.673000 0.2248600 0.60281
#> man 0.3729300 0.385030 0.710860 -0.659110 -0.0010128 0.92715
#> woman 0.5936800 0.448250 0.593200 0.074134 0.1114100 1.27930
#> Paris 0.9260500 -0.228180 -0.255240 0.739970 0.5007200 0.26424
#> London 0.6055300 -0.050886 -0.154610 -0.123270 0.6627000 -0.28506
#> apple -0.5985000 -0.463210 0.130010 -0.019576 0.4603000 -0.30180
#> orange -0.1496900 0.164770 -0.355320 -0.719150 0.6213000 0.74140
#> Taiwan 0.0061832 0.117350 0.535380 0.787290 0.6427700 -0.56057
#> Dublin -0.4281400 -0.168970 0.035079 0.133170 0.4115600 1.03810
#> Bamberg 0.4854000 -0.296800 0.103520 -0.250310 0.4100900 0.45147
#> V7 V8 V9 V10 V11 V12
#> King -0.37954 -0.353070 -0.84369 -1.1113000 -0.302660 0.331780
#> Queen -0.26194 0.738720 -0.65383 -0.2160600 -0.338060 0.244980
#> man 0.27615 -0.056203 -0.24294 0.2463200 -0.184490 0.313980
#> woman 0.16656 0.240700 0.39045 0.3276600 -0.750340 0.350070
#> Paris 0.40056 0.561450 0.17908 0.0504640 0.024095 -0.064805
#> London -0.68844 0.491350 -0.68924 0.3892600 0.143590 -0.488020
#> apple 0.89770 -0.656340 0.66858 -0.4916400 0.037557 -0.050889
#> orange 0.68959 0.403710 -0.24239 0.1774000 -0.950790 -0.188870
#> Taiwan -0.35941 -0.157720 0.97407 -0.1026900 -0.852620 -0.058598
#> Dublin -0.32697 0.333970 -0.16726 -0.0034566 -0.361420 -0.067648
#> Bamberg -0.08002 -0.264430 -0.47231 0.0170920 0.036594 -0.483970
#> V13 V14 V15 V16 V17 V18
#> King -0.25113 0.30448 -0.077491 -0.8981500 0.092496 -1.140700
#> Queen -0.51497 0.85680 -0.371990 -0.5882400 0.306370 -0.306680
#> man 0.48983 0.09256 0.329580 0.1505600 0.573170 -0.185290
#> woman 0.76057 0.38067 0.175170 0.0317910 0.468490 -0.216530
#> Paris -0.25491 0.29661 -0.476020 0.2424400 -0.067045 -0.460290
#> London 0.15746 0.83178 -0.279230 0.0094755 -0.112070 -0.520990
#> apple 0.64510 -0.53882 -0.376500 -0.0431200 0.513840 0.177830
#> orange -0.02344 0.49681 0.081903 -0.3694400 1.225700 -0.119000
#> Taiwan 1.19080 0.19279 -0.266930 -0.7671900 0.681310 -0.240430
#> Dublin -0.45075 1.43470 -0.591370 -0.3136400 0.602490 0.145310
#> Bamberg -0.18393 0.68727 0.249500 0.2045100 0.517300 0.084214
#> V19 V20
#> King -0.583240 0.66869
#> Queen -0.218700 0.78369
#> man -0.522770 0.46191
#> woman -0.462820 0.39967
#> Paris -0.384060 -0.36540
#> London -0.371590 -0.37951
#> apple 0.285960 0.92063
#> orange 0.955710 -0.19501
#> Taiwan -0.086499 -0.18486
#> Dublin -0.351880 0.18191
#> Bamberg -0.115300 -0.53820
Dimension Reduction (PCA)
# Set the seed for reproducibility
set.seed(123)
# Execute PCA
pca_result <- prcomp(sen_df, center = TRUE, scale. = TRUE)
word_embeddings_matrix <- as.data.frame(pca_result$x[,1:3] )
rownames(word_embeddings_matrix) <- token_texts
word_embeddings_matrix
#> PC1 PC2 PC3
#> King -2.9120910 1.285200 -1.95053854
#> Queen -2.2413804 2.266714 -1.09020972
#> man -5.6381902 2.984461 3.55462010
#> woman -6.4891003 2.458607 3.56693660
#> Paris 3.0702212 5.039061 -2.65962020
#> London 5.3196216 4.368433 -2.60726627
#> apple 0.3362535 -8.679358 -0.44752722
#> orange -0.0485467 -4.404101 0.77151480
#> Taiwan -2.7993829 -4.149287 -6.33296039
#> Dublin 5.8994096 1.063291 -0.09271925
#> Bamberg 5.5031854 -2.233020 7.28777009
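To see how much of the variance the first few components capture, you can inspect the PCA summary with base R (no additional packages needed):
# proportion of variance explained by each principal component
summary(pca_result)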
2D Plot
library(ggplot2)
glove_plot2D <- ggplot(word_embeddings_matrix, aes(x = PC1, y = PC2, color = PC3,
label = rownames(word_embeddings_matrix))) +
geom_point(size = 3) +
geom_text(vjust = 1.5, hjust = 0.5) +
scale_color_gradient(low = "blue", high = "red") +
theme_minimal() +
labs(title = "", x = "PC1", y = "PC2", color = "PC3")
# guides(color = "none")
glove_plot2D
Stacked Embeddings Method (GloVe + Backward/Forward FlairEmbeddings or More)
# Tokenize & Embed
# load WordEmbeddings and FlairEmbeddings
WordEmbeddings <- flair_embeddings()$WordEmbeddings
FlairEmbeddings <- flair_embeddings()$FlairEmbeddings
StackedEmbeddings <- flair_embeddings()$StackedEmbeddings
# init standard GloVe embedding
glove_embedding = WordEmbeddings('glove')
# init Flair forward and backwards embeddings
flair_embedding_forward <- FlairEmbeddings('news-forward')
flair_embedding_backward <- FlairEmbeddings('news-backward')
embedding <- WordEmbeddings("glove")
# create a StackedEmbedding object that combines glove and forward/backward flair embeddings
stacked_embeddings <- StackedEmbeddings(c(glove_embedding,
flair_embedding_forward,
flair_embedding_backward))
Sentence <- flair_data()$Sentence
sentence <- Sentence("King Queen man woman Paris London apple orange Taiwan Dublin Bamberg")
# just embed a sentence using the StackedEmbedding as you would with any single embedding.
stacked_embeddings$embed(sentence)
# Each token in the sentence is embedded; store the vectors in a list.
sen_list <- list()
for (i in seq_along(sentence$tokens)) {
# store the tensor vectors to numeric vectors
sen_list[[i]] <- as.vector(sentence$tokens[[i]]$embedding$numpy())
}
# Extract the token texts into an R character vector.
token_texts <- sapply(sentence$tokens, function(token) token$text)
# Form the dataframe.
sen_df <- do.call(rbind, lapply(sen_list, function(x) t(data.frame(x))))
sen_df <- as.data.frame(sen_df)
rownames(sen_df) <- token_texts
print(sen_df[,1:20])
#> V1 V2 V3 V4 V5 V6
#> King -0.3230700 -0.876160 0.219770 0.252680 0.2297600 0.73880
#> Queen -0.5004500 -0.708260 0.553880 0.673000 0.2248600 0.60281
#> man 0.3729300 0.385030 0.710860 -0.659110 -0.0010128 0.92715
#> woman 0.5936800 0.448250 0.593200 0.074134 0.1114100 1.27930
#> Paris 0.9260500 -0.228180 -0.255240 0.739970 0.5007200 0.26424
#> London 0.6055300 -0.050886 -0.154610 -0.123270 0.6627000 -0.28506
#> apple -0.5985000 -0.463210 0.130010 -0.019576 0.4603000 -0.30180
#> orange -0.1496900 0.164770 -0.355320 -0.719150 0.6213000 0.74140
#> Taiwan 0.0061832 0.117350 0.535380 0.787290 0.6427700 -0.56057
#> Dublin -0.4281400 -0.168970 0.035079 0.133170 0.4115600 1.03810
#> Bamberg 0.4854000 -0.296800 0.103520 -0.250310 0.4100900 0.45147
#> V7 V8 V9 V10 V11 V12
#> King -0.37954 -0.353070 -0.84369 -1.1113000 -0.302660 0.331780
#> Queen -0.26194 0.738720 -0.65383 -0.2160600 -0.338060 0.244980
#> man 0.27615 -0.056203 -0.24294 0.2463200 -0.184490 0.313980
#> woman 0.16656 0.240700 0.39045 0.3276600 -0.750340 0.350070
#> Paris 0.40056 0.561450 0.17908 0.0504640 0.024095 -0.064805
#> London -0.68844 0.491350 -0.68924 0.3892600 0.143590 -0.488020
#> apple 0.89770 -0.656340 0.66858 -0.4916400 0.037557 -0.050889
#> orange 0.68959 0.403710 -0.24239 0.1774000 -0.950790 -0.188870
#> Taiwan -0.35941 -0.157720 0.97407 -0.1026900 -0.852620 -0.058598
#> Dublin -0.32697 0.333970 -0.16726 -0.0034566 -0.361420 -0.067648
#> Bamberg -0.08002 -0.264430 -0.47231 0.0170920 0.036594 -0.483970
#> V13 V14 V15 V16 V17 V18
#> King -0.25113 0.30448 -0.077491 -0.8981500 0.092496 -1.140700
#> Queen -0.51497 0.85680 -0.371990 -0.5882400 0.306370 -0.306680
#> man 0.48983 0.09256 0.329580 0.1505600 0.573170 -0.185290
#> woman 0.76057 0.38067 0.175170 0.0317910 0.468490 -0.216530
#> Paris -0.25491 0.29661 -0.476020 0.2424400 -0.067045 -0.460290
#> London 0.15746 0.83178 -0.279230 0.0094755 -0.112070 -0.520990
#> apple 0.64510 -0.53882 -0.376500 -0.0431200 0.513840 0.177830
#> orange -0.02344 0.49681 0.081903 -0.3694400 1.225700 -0.119000
#> Taiwan 1.19080 0.19279 -0.266930 -0.7671900 0.681310 -0.240430
#> Dublin -0.45075 1.43470 -0.591370 -0.3136400 0.602490 0.145310
#> Bamberg -0.18393 0.68727 0.249500 0.2045100 0.517300 0.084214
#> V19 V20
#> King -0.583240 0.66869
#> Queen -0.218700 0.78369
#> man -0.522770 0.46191
#> woman -0.462820 0.39967
#> Paris -0.384060 -0.36540
#> London -0.371590 -0.37951
#> apple 0.285960 0.92063
#> orange 0.955710 -0.19501
#> Taiwan -0.086499 -0.18486
#> Dublin -0.351880 0.18191
#> Bamberg -0.115300 -0.53820
# Dimension Reduction
# Set the seed for reproducibility
set.seed(123)
# Execute PCA
pca_result <- prcomp(sen_df, center = TRUE, scale. = TRUE)
word_embeddings_matrix <- as.data.frame(pca_result$x[,1:3] )
rownames(word_embeddings_matrix) <- token_texts
word_embeddings_matrix
#> PC1 PC2 PC3
#> King -8.607474 67.2291112 32.4862807
#> Queen 1.757707 12.0477210 -26.8302480
#> man 70.603191 -6.6184707 13.1651688
#> woman 22.532043 -8.1126267 -0.9073998
#> Paris -11.395619 -0.3051693 -17.5197067
#> London -8.709174 -2.7450626 -14.1780531
#> apple -8.739477 -15.7725211 -6.3796349
#> orange -25.178329 -38.8501308 51.4907636
#> Taiwan -9.132397 -5.0252091 -11.0918877
#> Dublin -10.925014 -3.3407329 -10.1367729
#> Bamberg -12.205457 1.4930909 -10.0985100
# 2D Plot
library(ggplot2)
stacked_plot2D <- ggplot(word_embeddings_matrix, aes(x = PC1, y = PC2, color = PC3,
label = rownames(word_embeddings_matrix))) +
geom_point(size = 2) +
geom_text(vjust = 1.5, hjust = 0.5) +
scale_color_gradient(low = "blue", high = "red") +
theme_minimal() +
labs(title = "", x = "PC1", y = "PC2", color = "PC3")
# guides(color = "none")
stacked_plot2D
Transformer Embeddings (BERT or More)
library(flaiR)
# Load Sentence and BERT model
Sentence <- flair_data()$Sentence
TransformerWordEmbeddings <- flair_embeddings()$TransformerWordEmbeddings("bert-base-uncased")
sentence <- Sentence("King Queen man woman Paris London apple orange Taiwan Dublin Bamberg")
TransformerWordEmbeddings$embed(sentence)
#> [[1]]
#> Sentence[11]: "King Queen man woman Paris London apple orange Taiwan Dublin Bamberg"
# Each token in the sentence is embedded; store the vectors in a list.
sen_list <- list()
for (i in seq_along(sentence$tokens)) {
# store the tensor vectors to numeric vectors
sen_list[[i]] <- as.vector(sentence$tokens[[i]]$embedding$numpy())
}
# Extract the token texts into an R character vector.
token_texts <- sapply(sentence$tokens, function(token) token$text)
# Format the dataframe.
sen_df <- do.call(rbind, lapply(sen_list, function(x) t(data.frame(x))))
sen_df <- as.data.frame(sen_df)
rownames(sen_df) <- token_texts
print(sen_df[,1:20])
#> V1 V2 V3 V4 V5
#> King 0.20243725 -0.302875251 0.88937163 -0.39685369 -0.1486676
#> Queen 0.38636127 -0.136571497 1.07043457 -0.39234018 0.3492773
#> man -0.11609044 -0.278030753 1.11280990 -0.26518247 0.5489236
#> woman -0.34034440 0.002125926 0.70898271 -0.01405379 0.1248749
#> Paris -0.11196213 0.049277060 0.57255691 -0.17925967 -0.1082409
#> London -0.24880084 0.129327983 0.16446528 -0.38713518 -0.3233015
#> apple 0.01825429 0.289205372 0.08635567 -0.14463882 0.4227849
#> orange -0.08206877 -0.346834421 0.14041321 -0.01703753 0.9596846
#> Taiwan -0.21529624 0.164859354 0.55788797 0.08166986 -0.3301849
#> Dublin -0.26578656 0.304162681 0.24737860 -0.05140827 -0.6028203
#> Bamberg -0.02511016 -0.568264186 0.43023247 -0.07210951 0.1909387
#> V6 V7 V8 V9 V10
#> King 0.008254345 -0.03680899 0.06456959 -0.368780434 -0.32244465
#> Queen 0.004886352 0.44847178 0.65026569 0.058996066 -0.47662011
#> man 0.342995614 0.20067003 0.30138496 0.188928455 -0.53515178
#> woman 0.437581152 0.50878483 0.29988331 -0.082988761 -0.02024770
#> Paris 0.158250496 -0.25374153 0.23834446 -0.014040919 0.05501739
#> London -0.247915953 -0.51886940 0.10059015 0.064994000 -0.13463272
#> apple 0.303519875 -0.24322695 0.55958885 0.132094145 -0.30050987
#> orange 0.240883887 -0.36664891 0.37507194 -0.170591414 -0.57967031
#> Taiwan -0.088986777 -0.17876987 0.29112476 0.422848225 0.22081339
#> Dublin -0.238457471 0.05760814 0.11389303 0.008800384 0.24732894
#> Bamberg -0.657174826 0.12236608 0.04169786 -1.401770711 0.27092379
#> V11 V12 V13 V14 V15
#> King -0.21798825 0.094209142 0.28516501 0.15710688 -0.35424700
#> Queen -0.38251543 -0.486830354 -0.02046072 0.58473915 -0.58738035
#> man -0.33577001 -0.177165955 -0.09695894 0.24606071 -0.45691946
#> woman -0.18178533 0.001003223 0.26221192 0.06231856 -0.38738111
#> Paris -0.22901714 0.457912028 0.04404202 -0.25944969 0.06393644
#> London -0.18075228 0.514925241 0.12762237 -0.24772680 -0.05417314
#> apple -0.15327577 0.442788988 0.21187572 -0.26953286 0.29568702
#> orange -0.10221374 -0.048551828 0.07183206 0.26971415 -0.12249112
#> Taiwan -0.36020637 0.418745756 0.25126299 -0.63927531 0.20376265
#> Dublin -0.40517145 0.282728165 0.16409895 -0.12579527 0.14734903
#> Bamberg -0.01255565 -0.380868316 0.24334061 0.05542289 -1.02990639
#> V16 V17 V18 V19 V20
#> King -0.10926335 -0.04176838 0.08259771 -0.03663497 -0.10094819
#> Queen 0.16076291 -0.14089347 -0.03677157 -0.13579193 -0.19312018
#> man -0.39715117 -0.21910432 0.14614260 -0.38266426 -0.35972801
#> woman -0.00680415 -0.20422488 0.07532670 0.10188285 -0.04957973
#> Paris 0.29756585 -0.06224668 -0.12700970 0.46103561 0.43724096
#> London -0.02949974 0.09688052 -0.07327853 0.96583605 0.40061006
#> apple -0.08441193 0.05031339 0.22658241 0.33377019 0.24120221
#> orange 0.22620890 0.15824266 0.48752186 -0.78691489 -0.11886403
#> Taiwan 0.07144606 0.28152949 -0.17945638 0.49592388 0.26298687
#> Dublin 0.16665313 0.41901654 -0.24815717 0.71733290 0.39634189
#> Bamberg 0.92171913 0.52353865 -0.23750052 -0.27257800 0.32791424
# Dimension Reduction using PCA
set.seed(123) # Set the seed for reproducibility
pca_result <- prcomp(sen_df, center = TRUE, scale. = TRUE)
word_embeddings_matrix <- as.data.frame(pca_result$x[,1:3] )
rownames(word_embeddings_matrix) <- token_texts
# 2D Plot
library(ggplot2)
bert_plot2D <- ggplot(word_embeddings_matrix, aes(x = PC1, y = PC2, color = PC3,
label = rownames(word_embeddings_matrix))) +
geom_point(size = 2) +
geom_text(vjust = 1.5, hjust = 0.5) +
scale_color_gradient(low = "blue", high = "red") +
theme_minimal() +
labs(title = "", x = "PC1", y = "PC2", color = "PC3")
# guides(color = "none")
bert_plot2D
Training a Binary Classifier
In this section, we’ll train a sentiment analysis model that can categorize text as either positive or negative. This case study is adapted from pages 116 to 130 of Tadej Magajna’s book, ‘Natural Language Processing with Flair’. The process for training text classifiers in Flair mirrors the process followed for sequence labeling models. Specifically, the steps to train text classifiers are:
- Load a tagged corpus and compute the label dictionary map.
- Prepare the document embeddings.
- Initialize the TextClassifier class.
- Train the model.
Loading a Tagged Corpus
Training text classification models requires a set of text documents (typically, sentences or paragraphs) where each document is associated with one or more classification labels. To train our sentiment analysis text classification model, we will be using the famous Internet Movie Database (IMDb) dataset, which contains 50,000 movie reviews from IMDB, where each review is labeled as either positive or negative. References to this dataset are already baked into Flair, so loading the dataset couldn’t be easier:
library(flaiR)
# load IMDB from flair_datasets module
Corpus <- flair_data()$Corpus
IMDB <- flair_datasets()$IMDB
# load the IMDB dataset (we downsample to 5% below)
corpus = IMDB()
#> 2024-11-28 20:56:49,027 Reading data from /Users/yenchiehliao/.flair/datasets/imdb_v4-rebalanced
#> 2024-11-28 20:56:49,027 Train: /Users/yenchiehliao/.flair/datasets/imdb_v4-rebalanced/train.txt
#> 2024-11-28 20:56:49,027 Dev: None
#> 2024-11-28 20:56:49,027 Test: None
#> 2024-11-28 20:56:49,571 No test split found. Using 10% (i.e. 5000 samples) of the train split as test data
#> 2024-11-28 20:56:49,582 No dev split found. Using 10% (i.e. 4500 samples) of the train split as dev data
#> 2024-11-28 20:56:49,582 Initialized corpus /Users/yenchiehliao/.flair/datasets/imdb_v4-rebalanced (label type name is 'sentiment')
corpus$downsample(0.05)
#> <flair.datasets.document_classification.IMDB object at 0x3926e1290>
Print the sizes of the corpus object, formatted as "Test: %d | Train: %d | Dev: %d":
test_size <- length(corpus$test)
train_size <- length(corpus$train)
dev_size <- length(corpus$dev)
output <- sprintf("Corpus object sizes - Test: %d | Train: %d | Dev: %d", test_size, train_size, dev_size)
print(output)
#> [1] "Corpus object sizes - Test: 250 | Train: 2025 | Dev: 225"
lbl_type = 'sentiment'
label_dict = corpus$make_label_dictionary(label_type=lbl_type)
#> 2024-11-28 20:56:49,675 Computing label dictionary. Progress:
#> 2024-11-28 20:56:53,006 Dictionary created for label 'sentiment' with 2 values: POSITIVE (seen 1014 times), NEGATIVE (seen 1011 times)
Loading the Embeddings
flaiR covers all the different types of document embeddings that we can use. Here, we simply use DocumentPoolEmbeddings. They require no training prior to training the classification model itself:
DocumentPoolEmbeddings <- flair_embeddings()$DocumentPoolEmbeddings
WordEmbeddings <- flair_embeddings()$WordEmbeddings
glove = WordEmbeddings('glove')
document_embeddings = DocumentPoolEmbeddings(glove)
Initializing the TextClassifier
# initiate TextClassifier
TextClassifier <- flair_models()$TextClassifier
classifier <- TextClassifier(document_embeddings,
label_dictionary = label_dict,
label_type = lbl_type)
$to allows you to set the device to CPU, GPU, or a specific MPS device on Mac (such as mps:0, mps:1, mps:2).
classifier$to(flair_device("mps"))
TextClassifier(
(embeddings): DocumentPoolEmbeddings(
fine_tune_mode=none, pooling=mean
(embeddings): StackedEmbeddings(
(list_embedding_0): WordEmbeddings(
'glove'
(embedding): Embedding(400001, 100)
)
)
)
(decoder): Linear(in_features=100, out_features=3, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(locked_dropout): LockedDropout(p=0.0)
(word_dropout): WordDropout(p=0.0)
(loss_function): CrossEntropyLoss()
)
Training the Model
Training the text classifier model involves two simple steps:
- Defining the model trainer class by passing in the classifier model and the corpus
- Setting off the training process passing in the required training hyper-parameters.
It is worth noting that the ‘L’ in numbers like 32L and 5L is used in R to denote that the number is an integer. Without the ‘L’ suffix, numbers in R are treated as numeric, which are by default double-precision floating-point numbers. In contrast, Python determines the type based on the value of the number itself. Whole numbers (e.g., 5 or 32) are of type int, while numbers with decimal points (e.g., 5.0) are of type float. Floating-point numbers in both languages are representations of real numbers but can have some approximation due to the way they are stored in memory.
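A quick illustration of the distinction in base R:
# 32L is an integer; 32 without the suffix is a double-precision numeric
is.integer(32L)  # TRUE
is.integer(32)   # FALSE
typeof(5.0)      # "double"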
# initiate ModelTrainer
ModelTrainer <- flair_trainers()$ModelTrainer
# fit the model
trainer <- ModelTrainer(classifier, corpus)
# start to train
# note: the 'L' in 32L is used in R to denote that the number is an integer.
trainer$train('classifier',
learning_rate=0.1,
mini_batch_size=32L,
# specifies how embeddings are stored in RAM, ie."cpu", "cuda", "gpu", "mps".
# embeddings_storage_mode = "mps",
max_epochs=10L)
#> 2024-11-28 20:56:55,154 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:56:55,155 Model: "TextClassifier(
#> (embeddings): DocumentPoolEmbeddings(
#> fine_tune_mode=none, pooling=mean
#> (embeddings): StackedEmbeddings(
#> (list_embedding_0): WordEmbeddings(
#> 'glove'
#> (embedding): Embedding(400001, 100)
#> )
#> )
#> )
#> (decoder): Linear(in_features=100, out_features=2, bias=True)
#> (dropout): Dropout(p=0.0, inplace=False)
#> (locked_dropout): LockedDropout(p=0.0)
#> (word_dropout): WordDropout(p=0.0)
#> (loss_function): CrossEntropyLoss()
#> (weights): None
#> (weight_tensor) None
#> )"
#> 2024-11-28 20:56:55,155 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:56:55,155 Corpus: 2025 train + 225 dev + 250 test sentences
#> 2024-11-28 20:56:55,155 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:56:55,155 Train: 2025 sentences
#> 2024-11-28 20:56:55,155 (train_with_dev=False, train_with_test=False)
#> 2024-11-28 20:56:55,155 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:56:55,155 Training Params:
#> 2024-11-28 20:56:55,155 - learning_rate: "0.1"
#> 2024-11-28 20:56:55,155 - mini_batch_size: "32"
#> 2024-11-28 20:56:55,155 - max_epochs: "10"
#> 2024-11-28 20:56:55,155 - shuffle: "True"
#> 2024-11-28 20:56:55,155 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:56:55,155 Plugins:
#> 2024-11-28 20:56:55,155 - AnnealOnPlateau | patience: '3', anneal_factor: '0.5', min_learning_rate: '0.0001'
#> 2024-11-28 20:56:55,155 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:56:55,155 Final evaluation on model from best epoch (best-model.pt)
#> 2024-11-28 20:56:55,155 - metric: "('micro avg', 'f1-score')"
#> 2024-11-28 20:56:55,155 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:56:55,155 Computation:
#> 2024-11-28 20:56:55,155 - compute on device: cpu
#> 2024-11-28 20:56:55,155 - embedding storage: cpu
#> 2024-11-28 20:56:55,155 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:56:55,156 Model training base path: "classifier"
#> 2024-11-28 20:56:55,156 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:56:55,156 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:56:55,927 epoch 1 - iter 6/64 - loss 0.91174122 - time (sec): 0.77 - samples/sec: 248.83 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:56:56,866 epoch 1 - iter 12/64 - loss 0.96861458 - time (sec): 1.71 - samples/sec: 224.55 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:56:57,578 epoch 1 - iter 18/64 - loss 0.97410540 - time (sec): 2.42 - samples/sec: 237.80 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:56:58,235 epoch 1 - iter 24/64 - loss 0.98070454 - time (sec): 3.08 - samples/sec: 249.39 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:56:59,097 epoch 1 - iter 30/64 - loss 0.96716065 - time (sec): 3.94 - samples/sec: 243.56 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:56:59,833 epoch 1 - iter 36/64 - loss 0.96751174 - time (sec): 4.68 - samples/sec: 246.32 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:00,522 epoch 1 - iter 42/64 - loss 0.95211656 - time (sec): 5.37 - samples/sec: 250.46 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:01,169 epoch 1 - iter 48/64 - loss 0.95586962 - time (sec): 6.01 - samples/sec: 255.42 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:02,084 epoch 1 - iter 54/64 - loss 0.95892662 - time (sec): 6.93 - samples/sec: 249.41 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:02,777 epoch 1 - iter 60/64 - loss 0.94931542 - time (sec): 7.62 - samples/sec: 251.93 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:03,256 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:57:03,256 EPOCH 1 done: loss 0.9538 - lr: 0.100000
#> 2024-11-28 20:57:04,178 DEV : loss 0.6814143061637878 - f1-score (micro avg) 0.5511
#> 2024-11-28 20:57:04,550 - 0 epochs without improvement
#> 2024-11-28 20:57:04,552 saving best model
#> 2024-11-28 20:57:04,909 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:57:06,008 epoch 2 - iter 6/64 - loss 0.88309590 - time (sec): 1.10 - samples/sec: 174.91 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:06,695 epoch 2 - iter 12/64 - loss 0.91989129 - time (sec): 1.79 - samples/sec: 215.09 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:07,430 epoch 2 - iter 18/64 - loss 0.90330828 - time (sec): 2.52 - samples/sec: 228.54 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:08,173 epoch 2 - iter 24/64 - loss 0.90877422 - time (sec): 3.26 - samples/sec: 235.37 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:08,874 epoch 2 - iter 30/64 - loss 0.90456812 - time (sec): 3.96 - samples/sec: 242.15 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:09,740 epoch 2 - iter 36/64 - loss 0.90967931 - time (sec): 4.83 - samples/sec: 238.51 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:10,449 epoch 2 - iter 42/64 - loss 0.91712019 - time (sec): 5.54 - samples/sec: 242.64 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:11,172 epoch 2 - iter 48/64 - loss 0.90972924 - time (sec): 6.26 - samples/sec: 245.28 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:11,881 epoch 2 - iter 54/64 - loss 0.89838140 - time (sec): 6.97 - samples/sec: 247.88 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:12,583 epoch 2 - iter 60/64 - loss 0.89817547 - time (sec): 7.67 - samples/sec: 250.22 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:13,053 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:57:13,053 EPOCH 2 done: loss 0.8927 - lr: 0.100000
#> 2024-11-28 20:57:14,206 DEV : loss 0.6868590712547302 - f1-score (micro avg) 0.5556
#> 2024-11-28 20:57:14,571 - 0 epochs without improvement
#> 2024-11-28 20:57:14,573 saving best model
#> 2024-11-28 20:57:14,927 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:57:15,729 epoch 3 - iter 6/64 - loss 0.92060424 - time (sec): 0.80 - samples/sec: 239.52 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:16,475 epoch 3 - iter 12/64 - loss 0.92075667 - time (sec): 1.55 - samples/sec: 248.09 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:17,347 epoch 3 - iter 18/64 - loss 0.91408771 - time (sec): 2.42 - samples/sec: 238.11 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:18,046 epoch 3 - iter 24/64 - loss 0.90718419 - time (sec): 3.12 - samples/sec: 246.32 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:18,772 epoch 3 - iter 30/64 - loss 0.90797503 - time (sec): 3.84 - samples/sec: 249.74 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:19,538 epoch 3 - iter 36/64 - loss 0.89226388 - time (sec): 4.61 - samples/sec: 249.86 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:20,272 epoch 3 - iter 42/64 - loss 0.88858810 - time (sec): 5.34 - samples/sec: 251.48 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:21,180 epoch 3 - iter 48/64 - loss 0.88652510 - time (sec): 6.25 - samples/sec: 245.68 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:21,890 epoch 3 - iter 54/64 - loss 0.88052210 - time (sec): 6.96 - samples/sec: 248.19 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:22,564 epoch 3 - iter 60/64 - loss 0.88498759 - time (sec): 7.64 - samples/sec: 251.41 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:23,057 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:57:23,057 EPOCH 3 done: loss 0.8869 - lr: 0.100000
#> 2024-11-28 20:57:23,965 DEV : loss 0.695308268070221 - f1-score (micro avg) 0.5644
#> 2024-11-28 20:57:24,326 - 0 epochs without improvement
#> 2024-11-28 20:57:24,330 saving best model
#> 2024-11-28 20:57:24,637 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:57:25,669 epoch 4 - iter 6/64 - loss 0.84620678 - time (sec): 1.03 - samples/sec: 186.05 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:26,389 epoch 4 - iter 12/64 - loss 0.85321340 - time (sec): 1.75 - samples/sec: 219.26 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:27,146 epoch 4 - iter 18/64 - loss 0.85991294 - time (sec): 2.51 - samples/sec: 229.63 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:27,825 epoch 4 - iter 24/64 - loss 0.88514151 - time (sec): 3.19 - samples/sec: 240.95 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:28,702 epoch 4 - iter 30/64 - loss 0.87729849 - time (sec): 4.07 - samples/sec: 236.16 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:29,443 epoch 4 - iter 36/64 - loss 0.86967195 - time (sec): 4.81 - samples/sec: 239.71 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:30,141 epoch 4 - iter 42/64 - loss 0.86305476 - time (sec): 5.50 - samples/sec: 244.20 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:30,879 epoch 4 - iter 48/64 - loss 0.87007949 - time (sec): 6.24 - samples/sec: 246.09 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:31,787 epoch 4 - iter 54/64 - loss 0.86743812 - time (sec): 7.15 - samples/sec: 241.71 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:32,455 epoch 4 - iter 60/64 - loss 0.85504824 - time (sec): 7.82 - samples/sec: 245.61 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:32,908 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:57:32,909 EPOCH 4 done: loss 0.8562 - lr: 0.100000
#> 2024-11-28 20:57:33,826 DEV : loss 0.6955223679542542 - f1-score (micro avg) 0.5733
#> 2024-11-28 20:57:34,194 - 0 epochs without improvement
#> 2024-11-28 20:57:34,195 saving best model
#> 2024-11-28 20:57:34,455 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:57:35,499 epoch 5 - iter 6/64 - loss 0.79712009 - time (sec): 1.04 - samples/sec: 184.00 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:36,197 epoch 5 - iter 12/64 - loss 0.83594465 - time (sec): 1.74 - samples/sec: 220.46 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:36,916 epoch 5 - iter 18/64 - loss 0.82634424 - time (sec): 2.46 - samples/sec: 234.08 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:37,866 epoch 5 - iter 24/64 - loss 0.84943491 - time (sec): 3.41 - samples/sec: 225.21 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:38,626 epoch 5 - iter 30/64 - loss 0.83855354 - time (sec): 4.17 - samples/sec: 230.18 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:39,352 epoch 5 - iter 36/64 - loss 0.83839231 - time (sec): 4.90 - samples/sec: 235.24 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:40,056 epoch 5 - iter 42/64 - loss 0.84415555 - time (sec): 5.60 - samples/sec: 239.97 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:40,795 epoch 5 - iter 48/64 - loss 0.84729431 - time (sec): 6.34 - samples/sec: 242.27 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:41,696 epoch 5 - iter 54/64 - loss 0.84063320 - time (sec): 7.24 - samples/sec: 238.66 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:42,439 epoch 5 - iter 60/64 - loss 0.84694536 - time (sec): 7.98 - samples/sec: 240.51 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:42,696 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:57:42,696 EPOCH 5 done: loss 0.8443 - lr: 0.100000
#> 2024-11-28 20:57:43,822 DEV : loss 0.6921943426132202 - f1-score (micro avg) 0.5822
#> 2024-11-28 20:57:44,187 - 0 epochs without improvement
#> 2024-11-28 20:57:44,189 saving best model
#> 2024-11-28 20:57:44,457 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:57:45,507 epoch 6 - iter 6/64 - loss 0.82745030 - time (sec): 1.05 - samples/sec: 182.86 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:46,246 epoch 6 - iter 12/64 - loss 0.83163678 - time (sec): 1.79 - samples/sec: 214.64 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:46,906 epoch 6 - iter 18/64 - loss 0.83676998 - time (sec): 2.45 - samples/sec: 235.19 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:47,655 epoch 6 - iter 24/64 - loss 0.83001364 - time (sec): 3.20 - samples/sec: 240.17 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:48,541 epoch 6 - iter 30/64 - loss 0.83004258 - time (sec): 4.08 - samples/sec: 235.05 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:49,177 epoch 6 - iter 36/64 - loss 0.83843245 - time (sec): 4.72 - samples/sec: 244.05 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:49,934 epoch 6 - iter 42/64 - loss 0.82743987 - time (sec): 5.48 - samples/sec: 245.38 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:50,643 epoch 6 - iter 48/64 - loss 0.82295681 - time (sec): 6.19 - samples/sec: 248.30 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:51,493 epoch 6 - iter 54/64 - loss 0.82645491 - time (sec): 7.04 - samples/sec: 245.60 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:52,206 epoch 6 - iter 60/64 - loss 0.82694597 - time (sec): 7.75 - samples/sec: 247.78 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:52,676 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:57:52,676 EPOCH 6 done: loss 0.8203 - lr: 0.100000
#> 2024-11-28 20:57:53,602 DEV : loss 1.2241613864898682 - f1-score (micro avg) 0.4533
#> 2024-11-28 20:57:53,990 - 1 epochs without improvement
#> 2024-11-28 20:57:53,992 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:57:55,041 epoch 7 - iter 6/64 - loss 0.99479450 - time (sec): 1.05 - samples/sec: 183.18 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:55,768 epoch 7 - iter 12/64 - loss 0.84938438 - time (sec): 1.78 - samples/sec: 216.28 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:56,518 epoch 7 - iter 18/64 - loss 0.86148151 - time (sec): 2.53 - samples/sec: 228.03 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:57,234 epoch 7 - iter 24/64 - loss 0.88028894 - time (sec): 3.24 - samples/sec: 236.96 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:57,911 epoch 7 - iter 30/64 - loss 0.85370332 - time (sec): 3.92 - samples/sec: 244.96 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:58,851 epoch 7 - iter 36/64 - loss 0.84295499 - time (sec): 4.86 - samples/sec: 237.09 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:57:59,577 epoch 7 - iter 42/64 - loss 0.83776256 - time (sec): 5.58 - samples/sec: 240.66 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:00,277 epoch 7 - iter 48/64 - loss 0.84896051 - time (sec): 6.28 - samples/sec: 244.41 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:00,963 epoch 7 - iter 54/64 - loss 0.82845668 - time (sec): 6.97 - samples/sec: 247.88 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:01,839 epoch 7 - iter 60/64 - loss 0.82379258 - time (sec): 7.85 - samples/sec: 244.68 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:02,143 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:02,143 EPOCH 7 done: loss 0.8225 - lr: 0.100000
#> 2024-11-28 20:58:03,250 DEV : loss 0.8091477751731873 - f1-score (micro avg) 0.5556
#> 2024-11-28 20:58:03,615 - 2 epochs without improvement
#> 2024-11-28 20:58:03,617 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:04,429 epoch 8 - iter 6/64 - loss 0.76075288 - time (sec): 0.81 - samples/sec: 236.68 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:05,355 epoch 8 - iter 12/64 - loss 0.78303260 - time (sec): 1.74 - samples/sec: 220.96 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:06,110 epoch 8 - iter 18/64 - loss 0.77066376 - time (sec): 2.49 - samples/sec: 231.12 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:06,810 epoch 8 - iter 24/64 - loss 0.75232087 - time (sec): 3.19 - samples/sec: 240.59 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:07,515 epoch 8 - iter 30/64 - loss 0.77737618 - time (sec): 3.90 - samples/sec: 246.33 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:08,215 epoch 8 - iter 36/64 - loss 0.77567254 - time (sec): 4.60 - samples/sec: 250.58 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:09,102 epoch 8 - iter 42/64 - loss 0.78603464 - time (sec): 5.48 - samples/sec: 245.05 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:09,787 epoch 8 - iter 48/64 - loss 0.78363027 - time (sec): 6.17 - samples/sec: 248.97 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:10,519 epoch 8 - iter 54/64 - loss 0.78447807 - time (sec): 6.90 - samples/sec: 250.39 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:11,180 epoch 8 - iter 60/64 - loss 0.78428696 - time (sec): 7.56 - samples/sec: 253.87 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:11,666 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:11,666 EPOCH 8 done: loss 0.7806 - lr: 0.100000
#> 2024-11-28 20:58:12,605 DEV : loss 0.619835615158081 - f1-score (micro avg) 0.6489
#> 2024-11-28 20:58:13,234 - 0 epochs without improvement
#> 2024-11-28 20:58:13,237 saving best model
#> 2024-11-28 20:58:13,502 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:14,192 epoch 9 - iter 6/64 - loss 0.81246133 - time (sec): 0.69 - samples/sec: 278.03 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:14,938 epoch 9 - iter 12/64 - loss 0.76851736 - time (sec): 1.44 - samples/sec: 267.32 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:15,672 epoch 9 - iter 18/64 - loss 0.76903655 - time (sec): 2.17 - samples/sec: 265.44 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:16,594 epoch 9 - iter 24/64 - loss 0.77481109 - time (sec): 3.09 - samples/sec: 248.36 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:17,323 epoch 9 - iter 30/64 - loss 0.77168306 - time (sec): 3.82 - samples/sec: 251.24 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:18,056 epoch 9 - iter 36/64 - loss 0.77215956 - time (sec): 4.55 - samples/sec: 252.97 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:18,722 epoch 9 - iter 42/64 - loss 0.76924835 - time (sec): 5.22 - samples/sec: 257.45 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:19,480 epoch 9 - iter 48/64 - loss 0.75900491 - time (sec): 5.98 - samples/sec: 256.92 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:20,311 epoch 9 - iter 54/64 - loss 0.77255435 - time (sec): 6.81 - samples/sec: 253.76 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:21,080 epoch 9 - iter 60/64 - loss 0.77476379 - time (sec): 7.58 - samples/sec: 253.36 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:21,349 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:21,349 EPOCH 9 done: loss 0.7797 - lr: 0.100000
#> 2024-11-28 20:58:22,475 DEV : loss 0.7148910760879517 - f1-score (micro avg) 0.5467
#> 2024-11-28 20:58:22,841 - 1 epochs without improvement
#> 2024-11-28 20:58:22,843 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:23,866 epoch 10 - iter 6/64 - loss 0.78717700 - time (sec): 1.02 - samples/sec: 187.79 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:24,598 epoch 10 - iter 12/64 - loss 0.74860297 - time (sec): 1.75 - samples/sec: 218.87 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:25,353 epoch 10 - iter 18/64 - loss 0.73957860 - time (sec): 2.51 - samples/sec: 229.53 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:26,016 epoch 10 - iter 24/64 - loss 0.76316903 - time (sec): 3.17 - samples/sec: 242.05 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:26,962 epoch 10 - iter 30/64 - loss 0.77254111 - time (sec): 4.12 - samples/sec: 233.09 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:27,723 epoch 10 - iter 36/64 - loss 0.77815889 - time (sec): 4.88 - samples/sec: 236.10 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:28,469 epoch 10 - iter 42/64 - loss 0.76594733 - time (sec): 5.63 - samples/sec: 238.90 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:29,162 epoch 10 - iter 48/64 - loss 0.76528511 - time (sec): 6.32 - samples/sec: 243.09 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:30,074 epoch 10 - iter 54/64 - loss 0.76086970 - time (sec): 7.23 - samples/sec: 238.98 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:30,753 epoch 10 - iter 60/64 - loss 0.76908471 - time (sec): 7.91 - samples/sec: 242.75 - lr: 0.100000 - momentum: 0.000000
#> 2024-11-28 20:58:31,039 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:31,039 EPOCH 10 done: loss 0.7660 - lr: 0.100000
#> 2024-11-28 20:58:32,218 DEV : loss 0.6605172157287598 - f1-score (micro avg) 0.5956
#> 2024-11-28 20:58:32,644 - 2 epochs without improvement
#> 2024-11-28 20:58:32,909 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:32,909 Loading model from best epoch ...
#> 2024-11-28 20:58:34,348
#> Results:
#> - F-score (micro) 0.628
#> - F-score (macro) 0.589
#> - Accuracy 0.628
#>
#> By class:
#> precision recall f1-score support
#>
#> NEGATIVE 0.5680 0.9669 0.7156 121
#> POSITIVE 0.9091 0.3101 0.4624 129
#>
#> accuracy 0.6280 250
#> macro avg 0.7385 0.6385 0.5890 250
#> weighted avg 0.7440 0.6280 0.5850 250
#>
#> 2024-11-28 20:58:34,348 ----------------------------------------------------------------------------------------------------
#> $test_score
#> [1] 0.628
Loading and Using the Classifiers
After training the text classification model, the resulting classifier is already stored in memory as the classifier variable. If, however, your Python session has exited since training, you will need to load the saved model back into memory as follows:
TextClassifier <- flair_models()$TextClassifier
classifier <- TextClassifier$load('classifier/best-model.pt')
We import the Sentence object. Now, we can generate predictions on some example text inputs.
Sentence <- flair_data()$Sentence
sentence <- Sentence("great")
classifier$predict(sentence)
print(sentence$labels)
#> [[1]]
#> 'Sentence[1]: "great"'/'POSITIVE' (1.0)
sentence <- Sentence("sad")
classifier$predict(sentence)
print(sentence$labels)
#> [[1]]
#> 'Sentence[1]: "sad"'/'NEGATIVE' (0.6436)
Training RNNs
Here, we train a sentiment analysis model to categorize text, this time with a pipeline built around a Recurrent Neural Network (RNN), which is particularly effective for sequential data. This section also shows you how to implement one of the most powerful features in flaiR, stacked embeddings: you can stack multiple embeddings and let the classifier learn from the different types of features. In Flair NLP, and with the flaiR package, this is very easy to accomplish.
Import Necessary Modules
library(flaiR)
WordEmbeddings <- flair_embeddings()$WordEmbeddings
FlairEmbeddings <- flair_embeddings()$FlairEmbeddings
DocumentRNNEmbeddings <- flair_embeddings()$DocumentRNNEmbeddings
TextClassifier <- flair_models()$TextClassifier
ModelTrainer <- flair_trainers()$ModelTrainer
Get the IMDB Corpus
The IMDB movie review dataset is used here, a commonly used dataset for sentiment analysis. The $downsample(0.1) method means only 10% of the dataset is used, allowing for a faster demonstration.
# load the IMDB file and downsize it to 0.1
IMDB <- flair_datasets()$IMDB
corpus <- IMDB()$downsample(0.1)
#> 2024-11-28 20:58:34,764 Reading data from /Users/yenchiehliao/.flair/datasets/imdb_v4-rebalanced
#> 2024-11-28 20:58:34,764 Train: /Users/yenchiehliao/.flair/datasets/imdb_v4-rebalanced/train.txt
#> 2024-11-28 20:58:34,764 Dev: None
#> 2024-11-28 20:58:34,764 Test: None
#> 2024-11-28 20:58:35,311 No test split found. Using 10% (i.e. 5000 samples) of the train split as test data
#> 2024-11-28 20:58:35,324 No dev split found. Using 10% (i.e. 4500 samples) of the train split as dev data
#> 2024-11-28 20:58:35,324 Initialized corpus /Users/yenchiehliao/.flair/datasets/imdb_v4-rebalanced (label type name is 'sentiment')
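Before building the label dictionary, it can be useful to confirm how large the downsampled splits actually are. A minimal sketch using standard Flair Corpus utilities on the corpus object created above:
# quick sanity check on the downsampled corpus: train/dev/test sizes
print(corpus)
# per-label statistics, returned as a JSON-formatted string
cat(corpus$obtain_statistics(label_type = "sentiment"))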
# create the label dictionary
lbl_type <- 'sentiment'
label_dict <- corpus$make_label_dictionary(label_type=lbl_type)
#> 2024-11-28 20:58:35,339 Computing label dictionary. Progress:
#> 2024-11-28 20:58:42,236 Dictionary created for label 'sentiment' with 2 values: POSITIVE (seen 2056 times), NEGATIVE (seen 1994 times)
Stacked Embeddings
This is one of Flair’s most powerful features: it allows different embeddings to be stacked so that the model can learn from a richer combination of features. Three types of embeddings are used here: GloVe embeddings and two types of Flair embeddings (forward and backward). The word embeddings convert words into vectors.
# make a list of word embeddings
word_embeddings <- list(WordEmbeddings('glove'),
FlairEmbeddings('news-forward-fast'),
FlairEmbeddings('news-backward-fast'))
# initialize the document embeddings
document_embeddings <- DocumentRNNEmbeddings(word_embeddings,
hidden_size = 512L,
reproject_words = TRUE,
reproject_words_dimension = 256L)
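As an optional check, you can embed a single sentence with the stacked document embeddings and inspect the resulting vector. This is a minimal sketch using the objects created above; the example sentence is illustrative.
# embed one sentence and inspect the document-level vector
Sentence <- flair_data()$Sentence
example <- Sentence("Stacked embeddings are easy to use in flaiR.")
document_embeddings$embed(example)
# for a unidirectional DocumentRNNEmbeddings the vector length typically equals hidden_size
print(example$get_embedding()$shape)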
# create a Text Classifier with the embeddings and label dictionary
classifier <- TextClassifier(document_embeddings,
label_dictionary=label_dict, label_type='class')
# initialize the text classifier trainer with our corpus
trainer <- ModelTrainer(classifier, corpus)
Start the Training
For the sake of this example, max_epochs is set to 5. You might want to increase this for better performance.
It is worth noting that the learning rate is a parameter that determines the step size at each iteration while moving towards a minimum of the loss function. A smaller learning rate slows down the learning process but can lead to more precise convergence. mini_batch_size determines the number of samples used to compute the gradient at each step; the 'L' in 32L is used in R to denote that the number is an integer. patience is a hyper-parameter used in conjunction with early stopping to avoid overfitting: it sets the number of epochs the training process will tolerate without improvement before stopping. Setting max_epochs to 5 means the algorithm will make five passes through the dataset.
# note: the 'L' in 32L is used in R to denote that the number is an integer.
trainer$train('models/sentiment',
learning_rate=0.1,
mini_batch_size=32L,
patience=5L,
max_epochs=5L)
Apply the Trained Model for Prediction
sentence <- "This movie was really exciting!"
classifier$predict(sentence)
print(sentence.labels)
Fine-tune Transformers
We use data from The Temporal Focus of Campaign Communication (2020 JOP) as an example. Let’s assume we receive the training data at different times. First, suppose you have a dataset of 1000 entries called cc_muller_old. On another day, with the help of nice friends, you receive another 2000 entries in a dataset called cc_muller_new. Both subsets come from data(cc_muller). We will show how to fine-tune a transformer model with cc_muller_old, and then continue with another round of fine-tuning using cc_muller_new.
Fine-tuning a Transformer Model
Step 1 Load Necessary Modules from Flair
Load the necessary classes from the flair package.
# Sentence is a class for holding a text sentence
Sentence <- flair_data()$Sentence
# Corpus is a class for text corpora
Corpus <- flair_data()$Corpus
# TransformerDocumentEmbeddings is a class for loading transformer-based document embeddings
TransformerDocumentEmbeddings <- flair_embeddings()$TransformerDocumentEmbeddings
# TextClassifier is a class for text classification
TextClassifier <- flair_models()$TextClassifier
# ModelTrainer is a class for training and evaluating models
ModelTrainer <- flair_trainers()$ModelTrainer
Step 2 Convert the Data to Sentence and Corpus
We use purrr to map the text onto Sentence objects from flair_data(), then use map2 to attach the labels, and finally use Corpus to bundle the data into a corpus.
library(purrr)
data(cc_muller)
cc_muller_old <- cc_muller[1:1000,]
old_text <- map(cc_muller_old$text, Sentence)
old_labels <- as.character(cc_muller_old$class)
old_text <- map2(old_text, old_labels, ~ {
.x$add_label("classification", .y)
.x
})
set.seed(2046)
sample <- sample(c(TRUE, FALSE), length(old_text), replace=TRUE, prob=c(0.8, 0.2))
old_train <- old_text[sample]
old_test <- old_text[!sample]
test_id <- sample(c(TRUE, FALSE), length(old_test), replace=TRUE, prob=c(0.5, 0.5))
# split the held-out portion into dev and test before overwriting old_test
old_dev <- old_test[!test_id]
old_test <- old_test[test_id]
If you do not provide a development set (dev set) while using Flair, it will automatically split the training data into training and development datasets. The test set is held out to evaluate the model’s final performance, whereas the development set is used for adjusting model parameters and preventing overfitting, or in other words, for early stopping.
old_corpus <- Corpus(train = old_train, test = old_test)
#> 2024-11-28 20:58:44,639 No dev split found. Using 10% (i.e. 80 samples) of the train split as dev data
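If you prefer to control the split yourself, you can pass the dev set explicitly instead of letting Flair carve out 10% of the training data. A minimal sketch using the old_dev list created above:
# optional: supply train, dev, and test splits explicitly
old_corpus_with_dev <- Corpus(train = old_train, dev = old_dev, test = old_test)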
Step 3 Load distilbert Transformer
document_embeddings <- TransformerDocumentEmbeddings('distilbert-base-uncased', fine_tune=TRUE)
First, the $make_label_dictionary function is used to automatically create a label dictionary for the classification task. The label dictionary is a mapping from label to index, which is used to map the labels to a tensor of label indices. Besides classification tasks, flaiR also supports other label types for training custom models, such as ner, pos, and sentiment. For the cc_muller training split, the dictionary contains three labels: Future, Present, and Past (the counts are shown in the output below).
old_label_dict <- old_corpus$make_label_dictionary(label_type="classification")
#> 2024-11-28 20:58:45,759 Computing label dictionary. Progress:
#> 2024-11-28 20:58:45,763 Dictionary created for label 'classification' with 3 values: Future (seen 380 times), Present (seen 232 times), Past (seen 111 times)
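You can also inspect the resulting dictionary directly. A minimal sketch using Flair’s Dictionary methods on the object created above; the label looked up is illustrative.
# the label dictionary maps each label string to an integer index
print(old_label_dict$get_items())
# look up the index of a single label
print(old_label_dict$get_idx_for_item("Future"))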
TextClassifier is used to create a text classifier. The classifier takes the document embeddings (loaded from 'distilbert-base-uncased' on Hugging Face) and the label dictionary as input. The label type is also specified as classification.
old_classifier <- TextClassifier(document_embeddings,
label_dictionary = old_label_dict,
label_type='classification')
Step 4 Start Training
ModelTrainer is used to train the model.
old_trainer <- ModelTrainer(model = old_classifier, corpus = old_corpus)
old_trainer$train("vignettes/inst/muller-campaign-communication",
learning_rate=0.02,
mini_batch_size=8L,
anneal_with_restarts = TRUE,
save_final_model=TRUE,
max_epochs=1L)
#> 2024-11-28 20:58:45,894 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:45,894 Model: "TextClassifier(
#> (embeddings): TransformerDocumentEmbeddings(
#> (model): DistilBertModel(
#> (embeddings): Embeddings(
#> (word_embeddings): Embedding(30523, 768, padding_idx=0)
#> (position_embeddings): Embedding(512, 768)
#> (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#> (dropout): Dropout(p=0.1, inplace=False)
#> )
#> (transformer): Transformer(
#> (layer): ModuleList(
#> (0-5): 6 x TransformerBlock(
#> (attention): MultiHeadSelfAttention(
#> (dropout): Dropout(p=0.1, inplace=False)
#> (q_lin): Linear(in_features=768, out_features=768, bias=True)
#> (k_lin): Linear(in_features=768, out_features=768, bias=True)
#> (v_lin): Linear(in_features=768, out_features=768, bias=True)
#> (out_lin): Linear(in_features=768, out_features=768, bias=True)
#> )
#> (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#> (ffn): FFN(
#> (dropout): Dropout(p=0.1, inplace=False)
#> (lin1): Linear(in_features=768, out_features=3072, bias=True)
#> (lin2): Linear(in_features=3072, out_features=768, bias=True)
#> (activation): GELUActivation()
#> )
#> (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#> )
#> )
#> )
#> )
#> )
#> (decoder): Linear(in_features=768, out_features=3, bias=True)
#> (dropout): Dropout(p=0.0, inplace=False)
#> (locked_dropout): LockedDropout(p=0.0)
#> (word_dropout): WordDropout(p=0.0)
#> (loss_function): CrossEntropyLoss()
#> (weights): None
#> (weight_tensor) None
#> )"
#> 2024-11-28 20:58:45,894 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:45,894 Corpus: 723 train + 80 dev + 85 test sentences
#> 2024-11-28 20:58:45,894 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:45,894 Train: 723 sentences
#> 2024-11-28 20:58:45,894 (train_with_dev=False, train_with_test=False)
#> 2024-11-28 20:58:45,894 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:45,894 Training Params:
#> 2024-11-28 20:58:45,895 - learning_rate: "0.02"
#> 2024-11-28 20:58:45,895 - mini_batch_size: "8"
#> 2024-11-28 20:58:45,895 - max_epochs: "1"
#> 2024-11-28 20:58:45,895 - shuffle: "True"
#> 2024-11-28 20:58:45,895 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:45,895 Plugins:
#> 2024-11-28 20:58:45,895 - AnnealOnPlateau | patience: '3', anneal_factor: '0.5', min_learning_rate: '0.0001'
#> 2024-11-28 20:58:45,895 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:45,895 Final evaluation on model from best epoch (best-model.pt)
#> 2024-11-28 20:58:45,895 - metric: "('micro avg', 'f1-score')"
#> 2024-11-28 20:58:45,895 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:45,895 Computation:
#> 2024-11-28 20:58:45,895 - compute on device: cpu
#> 2024-11-28 20:58:45,895 - embedding storage: cpu
#> 2024-11-28 20:58:45,895 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:45,895 Model training base path: "vignettes/inst/muller-campaign-communication"
#> 2024-11-28 20:58:45,895 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:45,895 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:58:48,495 epoch 1 - iter 9/91 - loss 1.20928653 - time (sec): 2.60 - samples/sec: 27.70 - lr: 0.020000 - momentum: 0.000000
#> 2024-11-28 20:58:51,255 epoch 1 - iter 18/91 - loss 1.01873306 - time (sec): 5.36 - samples/sec: 26.87 - lr: 0.020000 - momentum: 0.000000
#> 2024-11-28 20:58:53,985 epoch 1 - iter 27/91 - loss 0.95770972 - time (sec): 8.09 - samples/sec: 26.70 - lr: 0.020000 - momentum: 0.000000
#> 2024-11-28 20:58:56,915 epoch 1 - iter 36/91 - loss 0.85857756 - time (sec): 11.02 - samples/sec: 26.13 - lr: 0.020000 - momentum: 0.000000
#> 2024-11-28 20:58:59,340 epoch 1 - iter 45/91 - loss 0.82960190 - time (sec): 13.44 - samples/sec: 26.78 - lr: 0.020000 - momentum: 0.000000
#> 2024-11-28 20:59:01,961 epoch 1 - iter 54/91 - loss 0.79808861 - time (sec): 16.07 - samples/sec: 26.89 - lr: 0.020000 - momentum: 0.000000
#> 2024-11-28 20:59:04,597 epoch 1 - iter 63/91 - loss 0.76187870 - time (sec): 18.70 - samples/sec: 26.95 - lr: 0.020000 - momentum: 0.000000
#> 2024-11-28 20:59:07,637 epoch 1 - iter 72/91 - loss 0.74997700 - time (sec): 21.74 - samples/sec: 26.49 - lr: 0.020000 - momentum: 0.000000
#> 2024-11-28 20:59:10,753 epoch 1 - iter 81/91 - loss 0.68867619 - time (sec): 24.86 - samples/sec: 26.07 - lr: 0.020000 - momentum: 0.000000
#> 2024-11-28 20:59:13,309 epoch 1 - iter 90/91 - loss 0.67390569 - time (sec): 27.41 - samples/sec: 26.26 - lr: 0.020000 - momentum: 0.000000
#> 2024-11-28 20:59:13,494 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:59:13,494 EPOCH 1 done: loss 0.6712 - lr: 0.020000
#> 2024-11-28 20:59:14,578 DEV : loss 0.44294458627700806 - f1-score (micro avg) 0.875
#> 2024-11-28 20:59:14,580 - 0 epochs without improvement
#> 2024-11-28 20:59:14,583 saving best model
#> 2024-11-28 20:59:15,419 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:59:15,420 Loading model from best epoch ...
#> 2024-11-28 20:59:17,727
#> Results:
#> - F-score (micro) 0.8471
#> - F-score (macro) 0.8356
#> - Accuracy 0.8471
#>
#> By class:
#> precision recall f1-score support
#>
#> Future 0.8333 0.9302 0.8791 43
#> Present 0.8696 0.7407 0.8000 27
#> Past 0.8571 0.8000 0.8276 15
#>
#> accuracy 0.8471 85
#> macro avg 0.8533 0.8237 0.8356 85
#> weighted avg 0.8490 0.8471 0.8449 85
#>
#> 2024-11-28 20:59:17,728 ----------------------------------------------------------------------------------------------------
#> $test_score
#> [1] 0.8470588
Continue Fine-tuning with New Dataset
Now we can continue to fine-tune the already fine-tuned model with an additional 2000 entries, stored in a dataset called cc_muller_new. The steps are the same as before, except that in this case we do not need to split the dataset again: the entire 2000 entries serve as the training set, and the old_test set is reused to evaluate how well the refined model performs.
Step 1 Load the muller-campaign-communication Model
Load the model (old_model) you fine-tuned in the previous stage and continue fine-tuning it with the new data, new_corpus.
old_model <- TextClassifier$load("vignettes/inst/muller-campaign-communication/best-model.pt")
Step 2 Convert the New Data to Sentence and Corpus
library(purrr)
cc_muller_new <- cc_muller[1001:3000,]
new_text <- map(cc_muller_new$text, Sentence)
new_labels <- as.character(cc_muller_new$class)
new_text <- map2(new_text, new_labels, ~ {
.x$add_label("classification", .y)
.x
})
new_corpus <- Corpus(train=new_text, test=old_test)
#> 2024-11-28 20:59:19,293 No dev split found. Using 10% (i.e. 200 samples) of the train split as dev data
Step 3 Create a New Model Trainer with the Old Model and New Corpus
new_trainer <- ModelTrainer(old_model, new_corpus)
Step 4 Train the New Model
new_trainer$train("vignettes/inst/new-muller-campaign-communication",
learning_rate=0.002,
mini_batch_size=8L,
max_epochs=1L)
#> 2024-11-28 20:59:19,377 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:59:19,378 Model: "TextClassifier(
#> (embeddings): TransformerDocumentEmbeddings(
#> (model): DistilBertModel(
#> (embeddings): Embeddings(
#> (word_embeddings): Embedding(30523, 768, padding_idx=0)
#> (position_embeddings): Embedding(512, 768)
#> (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#> (dropout): Dropout(p=0.1, inplace=False)
#> )
#> (transformer): Transformer(
#> (layer): ModuleList(
#> (0-5): 6 x TransformerBlock(
#> (attention): MultiHeadSelfAttention(
#> (dropout): Dropout(p=0.1, inplace=False)
#> (q_lin): Linear(in_features=768, out_features=768, bias=True)
#> (k_lin): Linear(in_features=768, out_features=768, bias=True)
#> (v_lin): Linear(in_features=768, out_features=768, bias=True)
#> (out_lin): Linear(in_features=768, out_features=768, bias=True)
#> )
#> (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#> (ffn): FFN(
#> (dropout): Dropout(p=0.1, inplace=False)
#> (lin1): Linear(in_features=768, out_features=3072, bias=True)
#> (lin2): Linear(in_features=3072, out_features=768, bias=True)
#> (activation): GELUActivation()
#> )
#> (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#> )
#> )
#> )
#> )
#> )
#> (decoder): Linear(in_features=768, out_features=3, bias=True)
#> (dropout): Dropout(p=0.0, inplace=False)
#> (locked_dropout): LockedDropout(p=0.0)
#> (word_dropout): WordDropout(p=0.0)
#> (loss_function): CrossEntropyLoss()
#> (weights): None
#> (weight_tensor) None
#> )"
#> 2024-11-28 20:59:19,378 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:59:19,378 Corpus: 1800 train + 200 dev + 85 test sentences
#> 2024-11-28 20:59:19,378 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:59:19,378 Train: 1800 sentences
#> 2024-11-28 20:59:19,378 (train_with_dev=False, train_with_test=False)
#> 2024-11-28 20:59:19,378 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:59:19,378 Training Params:
#> 2024-11-28 20:59:19,378 - learning_rate: "0.002"
#> 2024-11-28 20:59:19,378 - mini_batch_size: "8"
#> 2024-11-28 20:59:19,378 - max_epochs: "1"
#> 2024-11-28 20:59:19,378 - shuffle: "True"
#> 2024-11-28 20:59:19,378 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:59:19,378 Plugins:
#> 2024-11-28 20:59:19,378 - AnnealOnPlateau | patience: '3', anneal_factor: '0.5', min_learning_rate: '0.0001'
#> 2024-11-28 20:59:19,378 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:59:19,378 Final evaluation on model from best epoch (best-model.pt)
#> 2024-11-28 20:59:19,378 - metric: "('micro avg', 'f1-score')"
#> 2024-11-28 20:59:19,378 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:59:19,378 Computation:
#> 2024-11-28 20:59:19,378 - compute on device: cpu
#> 2024-11-28 20:59:19,378 - embedding storage: cpu
#> 2024-11-28 20:59:19,378 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:59:19,378 Model training base path: "vignettes/inst/new-muller-campaign-communication"
#> 2024-11-28 20:59:19,378 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:59:19,379 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 20:59:26,609 epoch 1 - iter 22/225 - loss 0.50694836 - time (sec): 7.23 - samples/sec: 24.34 - lr: 0.002000 - momentum: 0.000000
#> 2024-11-28 20:59:33,416 epoch 1 - iter 44/225 - loss 0.46326165 - time (sec): 14.04 - samples/sec: 25.08 - lr: 0.002000 - momentum: 0.000000
#> 2024-11-28 20:59:39,774 epoch 1 - iter 66/225 - loss 0.47215962 - time (sec): 20.40 - samples/sec: 25.89 - lr: 0.002000 - momentum: 0.000000
#> 2024-11-28 20:59:46,417 epoch 1 - iter 88/225 - loss 0.42928298 - time (sec): 27.04 - samples/sec: 26.04 - lr: 0.002000 - momentum: 0.000000
#> 2024-11-28 20:59:52,861 epoch 1 - iter 110/225 - loss 0.42027618 - time (sec): 33.48 - samples/sec: 26.28 - lr: 0.002000 - momentum: 0.000000
#> 2024-11-28 20:59:59,573 epoch 1 - iter 132/225 - loss 0.41898603 - time (sec): 40.19 - samples/sec: 26.27 - lr: 0.002000 - momentum: 0.000000
#> 2024-11-28 21:00:06,474 epoch 1 - iter 154/225 - loss 0.40655317 - time (sec): 47.09 - samples/sec: 26.16 - lr: 0.002000 - momentum: 0.000000
#> 2024-11-28 21:00:14,315 epoch 1 - iter 176/225 - loss 0.40555196 - time (sec): 54.94 - samples/sec: 25.63 - lr: 0.002000 - momentum: 0.000000
#> 2024-11-28 21:00:21,413 epoch 1 - iter 198/225 - loss 0.40328091 - time (sec): 62.03 - samples/sec: 25.53 - lr: 0.002000 - momentum: 0.000000
#> 2024-11-28 21:00:28,455 epoch 1 - iter 220/225 - loss 0.39744571 - time (sec): 69.08 - samples/sec: 25.48 - lr: 0.002000 - momentum: 0.000000
#> 2024-11-28 21:00:29,971 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 21:00:29,971 EPOCH 1 done: loss 0.3989 - lr: 0.002000
#> 2024-11-28 21:00:33,145 DEV : loss 0.44265758991241455 - f1-score (micro avg) 0.85
#> 2024-11-28 21:00:33,150 - 0 epochs without improvement
#> 2024-11-28 21:00:33,151 saving best model
#> 2024-11-28 21:00:33,994 ----------------------------------------------------------------------------------------------------
#> 2024-11-28 21:00:33,995 Loading model from best epoch ...
#> 2024-11-28 21:00:36,286
#> Results:
#> - F-score (micro) 0.8824
#> - F-score (macro) 0.8804
#> - Accuracy 0.8824
#>
#> By class:
#> precision recall f1-score support
#>
#> Future 0.9268 0.8837 0.9048 43
#> Present 0.7812 0.9259 0.8475 27
#> Past 1.0000 0.8000 0.8889 15
#>
#> accuracy 0.8824 85
#> macro avg 0.9027 0.8699 0.8804 85
#> weighted avg 0.8935 0.8824 0.8838 85
#>
#> 2024-11-28 21:00:36,286 ----------------------------------------------------------------------------------------------------
#> $test_score
#> [1] 0.8823529
Model Performance Metrics: Pre and Post Fine-tuning
After one additional epoch of fine-tuning on cc_muller_new, the model showed improved performance on the same held-out test set (old_test).
| Evaluation Metric | Pre-finetune | Post-finetune | Improvement |
|---|---|---|---|
| F-score (micro) | 0.8471 | 0.8824 | +0.0353 |
| F-score (macro) | 0.8356 | 0.8804 | +0.0448 |
| Accuracy | 0.8471 | 0.8824 | +0.0353 |
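These figures can be reproduced by re-evaluating both saved checkpoints on the same held-out set (old_test). The sketch below is an optional check that assumes the two best-model.pt files written by the training runs above are available; evaluate() and main_score are standard Flair evaluation APIs.
# reload both checkpoints so the comparison is unaffected by in-place training
pre_model  <- TextClassifier$load("vignettes/inst/muller-campaign-communication/best-model.pt")
post_model <- TextClassifier$load("vignettes/inst/new-muller-campaign-communication/best-model.pt")
# evaluate both on the identical held-out sentences
pre_result  <- pre_model$evaluate(old_test,  gold_label_type = "classification")
post_result <- post_model$evaluate(old_test, gold_label_type = "classification")
# main_score is the micro-averaged F1 reported in the logs
print(pre_result$main_score)
print(post_result$main_score)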
More R tutorials and documentation are available here.