Text Classification

Text classification is the task of assigning a label to a piece of text. This can be used for a variety of tasks, such as sentiment analysis, spam detection, and more. We support two types of text classification models: Traditional and NLI.

Traditional classifiers are trained on labeled text data, where each input is associated with a category label. The model learns to predict the category of new inputs based on the patterns it has learned from the training data.
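
As a rough illustration of what "learning patterns from labeled data" means, the sketch below implements a tiny multinomial Naive Bayes classifier over word counts in plain Python. It is not the model the `truestate` library trains; the function names `train_naive_bayes` and `predict` and the toy data are assumptions made for this example only.

```python
import math
from collections import Counter, defaultdict


def train_naive_bayes(texts, labels):
    """Count words per label; these counts are the 'patterns' the model learns."""
    word_counts = defaultdict(Counter)  # label -> word -> count
    label_counts = Counter(labels)      # label -> number of training texts
    vocabulary = set()
    for text, label in zip(texts, labels):
        words = text.lower().split()
        word_counts[label].update(words)
        vocabulary.update(words)
    return word_counts, label_counts, vocabulary


def predict(model, text):
    """Return the label with the highest log-probability for `text`."""
    word_counts, label_counts, vocabulary = model
    total_texts = sum(label_counts.values())
    best_label, best_score = None, -math.inf
    for label, n_texts in label_counts.items():
        # log prior plus log likelihoods with add-one (Laplace) smoothing
        score = math.log(n_texts / total_texts)
        total_words = sum(word_counts[label].values())
        for word in text.lower().split():
            count = word_counts[label][word]
            score += math.log((count + 1) / (total_words + len(vocabulary)))
        if score > best_score:
            best_label, best_score = label, score
    return best_label


# Toy labeled data, made up for illustration
texts = [
    "win a free prize now",
    "free money click now",
    "meeting moved to monday",
    "see you at lunch on monday",
]
labels = ["spam", "spam", "ham", "ham"]

model = train_naive_bayes(texts, labels)
prediction = predict(model, "claim your free prize")
```

Each input text is paired with one label, mirroring the `input_column` and `target_column` arguments used when training a classifier on this page.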

NLI (Natural Language Inference) models are trained on labeled pairs of text, where each pair consists of a premise and a hypothesis. The model learns to predict the relationship between the premise and the hypothesis, outputting a likelihood score that the hypothesis is supported by the premise.
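
One common way to obtain such a likelihood score, sketched here under assumptions, is to apply a softmax to the three logits an NLI model typically produces (contradiction, neutral, entailment) and read off the entailment probability. The logit order, the example values, and the helper name `entailment_score` are illustrative; the models behind this library may compute their score differently.

```python
import math


def entailment_score(logits, entailment_index=2):
    """Softmax over NLI logits; return the probability of 'entailment'."""
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    return exps[entailment_index] / sum(exps)


# Hypothetical logits in the order (contradiction, neutral, entailment) for
# premise:    "The firm sells home insurance"
# hypothesis: "This company is an insurance company"
logits = [-2.0, 0.5, 3.0]
score = entailment_score(logits)
```

A score close to 1 means the premise strongly supports the hypothesis; a score close to 0 means it does not.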

NLI models are useful on their own, but we also provide tools that help you make better use of their strengths in your applications. These are:

  • Entity Matching: Outputs a likelihood score that a piece of text matches each of the given criteria.
  • Hierarchy Categorisation: Traverses a hierarchy of predefined categories to output the most likely category and subcategory for a given piece of text.
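
The traversal described for Hierarchy Categorisation can be pictured as a greedy top-down walk: score every hypothesis at the current level, descend into the best-scoring category, and repeat. The sketch below is an assumption about the general technique, not the library's implementation. `categorise` and `keyword_overlap` are hypothetical names, and the keyword scorer is only a stand-in for a real NLI model so that the example runs on its own.

```python
def categorise(text, nodes, score_fn):
    """Greedy top-down walk: pick the best-scoring node at each level."""
    path = []
    while nodes:
        best = max(nodes, key=lambda node: score_fn(text, node["hypothesis"]))
        path.append(best["category"])
        nodes = best.get("subcategories", [])
    return path


def keyword_overlap(text, hypothesis):
    """Toy scorer: fraction of hypothesis words that also appear in the text.

    A stand-in for an NLI model, used only to keep the sketch self-contained.
    """
    text_words = set(text.lower().split())
    hypothesis_words = set(hypothesis.lower().split())
    return len(text_words & hypothesis_words) / len(hypothesis_words)


hierarchy = [
    {
        "category": "Technology",
        "hypothesis": "This company is a technology company",
        "subcategories": [
            {"category": "Software", "hypothesis": "This company is a software company"},
            {"category": "Hardware", "hypothesis": "This company is a hardware company"},
        ],
    },
    {
        "category": "Finance",
        "hypothesis": "This company is a finance company",
        "subcategories": [
            {"category": "Banking", "hypothesis": "This company is a banking company"},
            {"category": "Insurance", "hypothesis": "This company is an insurance company"},
        ],
    },
]

path = categorise("an insurance and finance group", hierarchy, keyword_overlap)
```

Here `path` holds the chosen category followed by the chosen subcategory. A greedy walk never revisits an earlier level, so a wrong choice near the top cannot be corrected further down.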

Currently, NLI models are supported for inference only, not for training.

Train a traditional text classifier

from truestate import datasets, models

# get the dataset
dataset = datasets.get(name="my-dataset")

# train the model
model = models.TextClassifier(
    name="my-model",
)

model.train(
    dataset=dataset,
    input_column="text",
    target_column="label",
)

Apply a text classifier to new data

from truestate import models, datasets

# get the model
model = models.get(name="my-model")

# get the dataset
dataset = datasets.get(name="my-dataset")

# apply the model to the dataset
predictions = model.inference(
    dataset=dataset,
    input_column="text",
)

Apply an NLI model in entity matching

from truestate import datasets, models

# get the model
model = models.EntityMatcher()

# load the data
dataset = datasets.get(name="my-dataset")

# define the categories to match
choices = [
    "this text contains financial advice",
    "this text does not contain financial advice",
]

# score each text in the dataset against the choices
predictions = model.inference(
    dataset=dataset,
    input_column="text",
    choices=choices,
)

Apply an NLI model in hierarchy categorisation

from truestate import models

# get the model
model = models.HierarchyCategoriser()

# define the category hierarchy
hierarchy = [
    {
        "category": "Utilities",
        "hypothesis": "This company is a utility company",
        "subcategories": [
            {
                "category": "Energy",
                "hypothesis": "This company is an energy company",
                "subcategories": [
                    {
                        "category": "Solar",
                        "hypothesis": "This company is a solar energy company",
                    },
                    {
                        "category": "Wind",
                        "hypothesis": "This company is a wind energy company",
                    },
                    {
                        "category": "Hydro",
                        "hypothesis": "This company is a hydro energy company",
                    },
                    {
                        "category": "Nuclear",
                        "hypothesis": "This company is a nuclear energy company",
                    },
                    {
                        "category": "Coal",
                        "hypothesis": "This company is a fossil fuel company",
                    },
                ],
            },
        ],
    },
    {
        "category": "Technology",
        "hypothesis": "This company is a technology company",
        "subcategories": [
            {
                "category": "Software",
                "hypothesis": "This company is a software company",
            },
            {
                "category": "Hardware",
                "hypothesis": "This company is a hardware company",
            },
            {
                "category": "Internet",
                "hypothesis": "This company is an internet company",
            },
            {
                "category": "Telecom",
                "hypothesis": "This company is a telecom company",
            },
        ],
    },
    {
        "category": "Finance",
        "hypothesis": "This company is a finance company",
        "subcategories": [
            {
                "category": "Banking",
                "hypothesis": "This company is a banking company",
            },
            {
                "category": "Insurance",
                "hypothesis": "This company is an insurance company",
            },
            {
                "category": "Investment",
                "hypothesis": "This company is an investment company",
            },
            {
                "category": "Real Estate",
                "hypothesis": "This company is a real estate company",
            },
        ],
    },
]

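# load the data to categorise
# (assumption: the categoriser reads from a dataset, like the other models on this page)
from truestate import datasets

dataset = datasets.get(name="my-dataset")

# apply the model to the dataset
# (assumption: the hierarchy is passed through `choices`, as in the original call)
predictions = model.inference(
    dataset=dataset,
    input_column="text",
    choices=hierarchy,
)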
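
Hand-written hierarchies like the one above are easy to get wrong, for example through a missing hypothesis or a misplaced bracket. The helper below is a small sketch for checking a hierarchy before running inference. `validate_hierarchy` is a hypothetical name, not part of the `truestate` library, and it assumes only the node layout shown on this page: a `category`, a `hypothesis`, and an optional `subcategories` list.

```python
def validate_hierarchy(nodes, path=()):
    """Return a list of problems found in a nested category hierarchy."""
    problems = []
    for index, node in enumerate(nodes):
        name = str(node.get("category", f"#{index}"))
        here = path + (name,)
        location = "/".join(here)
        # Every node needs a non-empty category name and hypothesis.
        for key in ("category", "hypothesis"):
            value = node.get(key)
            if not isinstance(value, str) or not value.strip():
                problems.append(f"{location}: missing or empty '{key}'")
        # Recurse into subcategories when they are present.
        children = node.get("subcategories", [])
        if not isinstance(children, list):
            problems.append(f"{location}: 'subcategories' must be a list")
        else:
            problems.extend(validate_hierarchy(children, here))
    return problems


example = [
    {
        "category": "Finance",
        "hypothesis": "This company is a finance company",
        "subcategories": [
            {"category": "Banking", "hypothesis": "This company is a banking company"},
            {"category": "Insurance"},  # hypothesis left out on purpose
        ],
    },
]

problems = validate_hierarchy(example)
```

An empty list means no structural problems were found; otherwise each entry names the offending node by its path through the hierarchy.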