{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Intent Classification in Banking 🏦\n",
"\n",
"The BANKING77 dataset provides a fine-grained set of intents in the banking domain. It comprises 13,083 customer service queries, each labeled with one of 77 intents, and targets fine-grained single-domain intent detection."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer, DataCollatorWithPadding\n",
"import pandas as pd\n",
"\n",
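"# BANKING77 ships with a train and a test split only; the test split serves as the evaluation set here.\n",
"raw_dataset_train = load_dataset(\"banking77\", split=\"train\")\n",
"raw_dataset_val = load_dataset(\"banking77\", split=\"test\")"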
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"raw_dataset_train_df = raw_dataset_train.to_pandas()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" text label\n",
"0 I am still waiting on my card? 11\n",
"1 What can I do if my card still hasn't arrived after 2 weeks? 11\n",
"2 I have been waiting over a week. Is the card still coming? 11\n",
"3 Can I track my card while it is in the process of delivery? 11\n",
"4 How do I know if I will get my card, or if it is lost? 11"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.set_option('display.max_colwidth', None)\n",
"\n",
"raw_dataset_train_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" text \\\n",
"6883 Is it possible for me to change my PIN number? \n",
"5836 I'm not sure why my card didn't work \n",
"8601 I don't think my top up worked \n",
"2545 Can you explain why my payment was charged a fee? \n",
"8697 How long does a transfer from a UK account take? I just made one and it doesn't seem to be working, wondering if everything is okay \n",
"5573 Why am I getting declines when trying to make a purchase online? \n",
"576 What is the $1 transaction on my account? \n",
"6832 It looks like my card payment was sent back. \n",
"7111 Why am I unable to transfer money when I was able to before? \n",
"439 What if there is an error on the exchange rate? \n",
"\n",
" label \n",
"6883 21 \n",
"5836 25 \n",
"8601 59 \n",
"2545 15 \n",
"8697 5 \n",
"5573 27 \n",
"576 34 \n",
"6832 53 \n",
"7111 7 \n",
"439 17 "
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"raw_dataset_train_df.sample(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dataset has many classes. The bar chart below shows the number of training examples per label:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"raw_dataset_train_df.label.value_counts().plot(kind='bar', figsize=(15, 10))"
]
},
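{
"cell_type": "markdown",
"metadata": {},
"source": [
"The bar chart can be complemented with a few summary numbers, such as the number of classes and the ratio between the largest and the smallest class. The cell below is a minimal sketch on a made-up label column: `toy_labels` is a hypothetical stand-in for `raw_dataset_train_df.label`, and its values are not taken from BANKING77."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Made-up labels standing in for raw_dataset_train_df.label (assumption for illustration).\n",
"toy_labels = pd.Series([0, 0, 0, 1, 1, 2])\n",
"\n",
"# value_counts returns the number of rows per distinct label.\n",
"counts = toy_labels.value_counts()\n",
"\n",
"print(\"classes:\", counts.size)\n",
"print(\"largest class:\", counts.max(), \"smallest class:\", counts.min())\n",
"print(\"imbalance ratio:\", counts.max() / counts.min())"
]
},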
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is the list of label ids with their intent names:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| label | intent (category) |\n",
"|---:|:-------------------------------------------------|\n",
"| 0 | activate_my_card |\n",
"| 1 | age_limit |\n",
"| 2 | apple_pay_or_google_pay |\n",
"| 3 | atm_support |\n",
"| 4 | automatic_top_up |\n",
"| 5 | balance_not_updated_after_bank_transfer |\n",
"| 6 | balance_not_updated_after_cheque_or_cash_deposit |\n",
"| 7 | beneficiary_not_allowed |\n",
"| 8 | cancel_transfer |\n",
"| 9 | card_about_to_expire |\n",
"| 10 | card_acceptance |\n",
"| 11 | card_arrival |\n",
"| 12 | card_delivery_estimate |\n",
"| 13 | card_linking |\n",
"| 14 | card_not_working |\n",
"| 15 | card_payment_fee_charged |\n",
"| 16 | card_payment_not_recognised |\n",
"| 17 | card_payment_wrong_exchange_rate |\n",
"| 18 | card_swallowed |\n",
"| 19 | cash_withdrawal_charge |\n",
"| 20 | cash_withdrawal_not_recognised |\n",
"| 21 | change_pin |\n",
"| 22 | compromised_card |\n",
"| 23 | contactless_not_working |\n",
"| 24 | country_support |\n",
"| 25 | declined_card_payment |\n",
"| 26 | declined_cash_withdrawal |\n",
"| 27 | declined_transfer |\n",
"| 28 | direct_debit_payment_not_recognised |\n",
"| 29 | disposable_card_limits |\n",
"| 30 | edit_personal_details |\n",
"| 31 | exchange_charge |\n",
"| 32 | exchange_rate |\n",
"| 33 | exchange_via_app |\n",
"| 34 | extra_charge_on_statement |\n",
"| 35 | failed_transfer |\n",
"| 36 | fiat_currency_support |\n",
"| 37 | get_disposable_virtual_card |\n",
"| 38 | get_physical_card |\n",
"| 39 | getting_spare_card |\n",
"| 40 | getting_virtual_card |\n",
"| 41 | lost_or_stolen_card |\n",
"| 42 | lost_or_stolen_phone |\n",
"| 43 | order_physical_card |\n",
"| 44 | passcode_forgotten |\n",
"| 45 | pending_card_payment |\n",
"| 46 | pending_cash_withdrawal |\n",
"| 47 | pending_top_up |\n",
"| 48 | pending_transfer |\n",
"| 49 | pin_blocked |\n",
"| 50 | receiving_money |\n",
"| 51 | Refund_not_showing_up |\n",
"| 52 | request_refund |\n",
"| 53 | reverted_card_payment? |\n",
"| 54 | supported_cards_and_currencies |\n",
"| 55 | terminate_account |\n",
"| 56 | top_up_by_bank_transfer_charge |\n",
"| 57 | top_up_by_card_charge |\n",
"| 58 | top_up_by_cash_or_cheque |\n",
"| 59 | top_up_failed |\n",
"| 60 | top_up_limits |\n",
"| 61 | top_up_reverted |\n",
"| 62 | topping_up_by_card |\n",
"| 63 | transaction_charged_twice |\n",
"| 64 | transfer_fee_charged |\n",
"| 65 | transfer_into_account |\n",
"| 66 | transfer_not_received_by_recipient |\n",
"| 67 | transfer_timing |\n",
"| 68 | unable_to_verify_identity |\n",
"| 69 | verify_my_identity |\n",
"| 70 | verify_source_of_funds |\n",
"| 71 | verify_top_up |\n",
"| 72 | virtual_card_not_working |\n",
"| 73 | visa_or_mastercard |\n",
"| 74 | why_verify_identity |\n",
"| 75 | wrong_amount_of_cash_received |\n",
"| 76 | wrong_exchange_rate_for_cash_withdrawal |\n",
"\n",
"The dataset statistics are summarized below:\n",
"\n",
"| Dataset statistics | Train | Test |\n",
"| --- | --- | --- |\n",
"| Number of examples | 10,003 | 3,080 |\n",
"| Average character length | 59.5 | 54.2 |\n",
"| Number of intents | 77 | 77 |"
]
},
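{
"cell_type": "markdown",
"metadata": {},
"source": [
"The integer labels can also be mapped back to intent names in code. The cell below is a minimal sketch that assumes the label column behaves like a `datasets.ClassLabel` feature. `toy_label_feature` and its three-name list are illustrative stand-ins for the 77 names in the table above; with the notebook's own data, `raw_dataset_train.features[\"label\"]` would presumably play the same role."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import ClassLabel\n",
"\n",
"# Illustrative subset only: the real label feature holds all 77 intent names.\n",
"toy_label_feature = ClassLabel(names=[\"activate_my_card\", \"age_limit\", \"apple_pay_or_google_pay\"])\n",
"\n",
"# int2str maps a label id to its name; str2int does the reverse.\n",
"print(toy_label_feature.int2str(1))\n",
"print(toy_label_feature.str2int(\"apple_pay_or_google_pay\"))"
]
},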
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"model_name = \"bert-base-uncased\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"\n",
"def tokenize_function(example):\n",
"    return tokenizer(example[\"text\"], truncation=True)\n",
"\n",
"\n",
"tokenized_datasets_train = raw_dataset_train.map(tokenize_function, batched=True)\n",
"tokenized_datasets_val = raw_dataset_val.map(tokenize_function, batched=True)\n",
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrainingArguments\n",
"\n",
"training_args = TrainingArguments(\"intent-banking\", per_device_train_batch_size=16)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=77) # watch out for the number of labels!"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
}
],
"source": [
"from transformers import Trainer\n",
"\n",
"trainer = Trainer(\n",
"    model,\n",
"    training_args,\n",
"    train_dataset=tokenized_datasets_train,\n",
"    eval_dataset=tokenized_datasets_val,\n",
"    data_collator=data_collator,\n",
"    tokenizer=tokenizer,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fcb99442c8174f91a78ddaf8ec9ba4f3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1878 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Checkpoint destination directory test-trainer/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'loss': 1.4679, 'grad_norm': 9.389148712158203, 'learning_rate': 3.668796592119276e-05, 'epoch': 0.8}\n",
"{'loss': 0.4518, 'grad_norm': 7.393991947174072, 'learning_rate': 2.3375931842385517e-05, 'epoch': 1.6}\n",
"{'loss': 0.2117, 'grad_norm': 5.330203056335449, 'learning_rate': 1.0063897763578276e-05, 'epoch': 2.4}\n",
"{'train_runtime': 582.5736, 'train_samples_per_second': 51.511, 'train_steps_per_second': 3.224, 'train_loss': 0.5940346174473706, 'epoch': 3.0}\n"
]
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=1878, training_loss=0.5940346174473706, metrics={'train_runtime': 582.5736, 'train_samples_per_second': 51.511, 'train_steps_per_second': 3.224, 'train_loss': 0.5940346174473706, 'epoch': 3.0})"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.train()\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ce36c1632c1b4ccfbcccd9b165b3c844",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/385 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(3080, 77) (3080,)\n"
]
}
],
"source": [
"predictions = trainer.predict(tokenized_datasets_val)\n",
"print(predictions.predictions.shape, predictions.label_ids.shape)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([41, 11, 11, ..., 24, 24, 24])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions.predictions.argmax(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"raw_dataset_val_df = raw_dataset_val.to_pandas()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"raw_dataset_val_df['prediction'] = predictions.predictions.argmax(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.97 0.99 40\n",
" 1 0.98 1.00 0.99 40\n",
" 2 1.00 1.00 1.00 40\n",
" 3 1.00 0.97 0.99 40\n",
" 4 0.95 0.93 0.94 40\n",
" 5 0.84 0.80 0.82 40\n",
" 6 1.00 0.93 0.96 40\n",
" 7 0.97 0.93 0.95 40\n",
" 8 1.00 0.95 0.97 40\n",
" 9 0.95 1.00 0.98 40\n",
" 10 0.93 0.93 0.93 40\n",
" 11 0.88 0.88 0.88 40\n",
" 12 0.90 0.88 0.89 40\n",
" 13 0.97 0.95 0.96 40\n",
" 14 0.87 0.97 0.92 40\n",
" 15 0.84 0.93 0.88 40\n",
" 16 0.90 0.95 0.93 40\n",
" 17 0.88 0.95 0.92 40\n",
" 18 1.00 0.95 0.97 40\n",
" 19 0.95 0.95 0.95 40\n",
" 20 0.83 0.97 0.90 40\n",
" 21 0.89 1.00 0.94 40\n",
" 22 0.94 0.78 0.85 40\n",
" 23 1.00 0.88 0.93 40\n",
" 24 0.93 0.95 0.94 40\n",
" 25 0.85 0.97 0.91 40\n",
" 26 0.78 1.00 0.88 40\n",
" 27 1.00 0.75 0.86 40\n",
" 28 0.90 0.88 0.89 40\n",
" 29 0.97 0.90 0.94 40\n",
" 30 1.00 1.00 1.00 40\n",
" 31 0.97 0.93 0.95 40\n",
" 32 0.93 0.97 0.95 40\n",
" 33 0.88 0.95 0.92 40\n",
" 34 1.00 0.95 0.97 40\n",
" 35 0.87 0.97 0.92 40\n",
" 36 0.95 0.90 0.92 40\n",
" 37 0.94 0.85 0.89 40\n",
" 38 0.97 0.97 0.97 40\n",
" 39 0.93 0.97 0.95 40\n",
" 40 0.87 0.97 0.92 40\n",
" 41 0.86 0.95 0.90 40\n",
" 42 1.00 0.97 0.99 40\n",
" 43 0.90 0.95 0.93 40\n",
" 44 1.00 1.00 1.00 40\n",
" 45 0.95 0.95 0.95 40\n",
" 46 1.00 0.97 0.99 40\n",
" 47 0.91 0.97 0.94 40\n",
" 48 0.89 0.80 0.84 40\n",
" 49 0.97 0.85 0.91 40\n",
" 50 0.95 0.93 0.94 40\n",
" 51 1.00 1.00 1.00 40\n",
" 52 1.00 0.97 0.99 40\n",
" 53 0.93 0.93 0.93 40\n",
" 54 0.89 0.97 0.93 40\n",
" 55 0.98 1.00 0.99 40\n",
" 56 0.92 0.90 0.91 40\n",
" 57 0.93 0.95 0.94 40\n",
" 58 0.95 0.95 0.95 40\n",
" 59 0.92 0.90 0.91 40\n",
" 60 1.00 0.97 0.99 40\n",
" 61 0.92 0.85 0.88 40\n",
" 62 0.86 0.80 0.83 40\n",
" 63 0.93 1.00 0.96 40\n",
" 64 0.95 0.93 0.94 40\n",
" 65 0.95 0.88 0.91 40\n",
" 66 0.92 0.90 0.91 40\n",
" 67 0.78 0.95 0.85 40\n",
" 68 0.95 0.93 0.94 40\n",
" 69 0.81 0.95 0.87 40\n",
" 70 1.00 1.00 1.00 40\n",
" 71 1.00 1.00 1.00 40\n",
" 72 1.00 0.93 0.96 40\n",
" 73 1.00 0.93 0.96 40\n",
" 74 0.91 0.80 0.85 40\n",
" 75 1.00 0.90 0.95 40\n",
" 76 0.95 0.88 0.91 40\n",
"\n",
" accuracy 0.93 3080\n",
" macro avg 0.94 0.93 0.93 3080\n",
"weighted avg 0.94 0.93 0.93 3080\n",
"\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"print(classification_report(raw_dataset_val_df['label'], raw_dataset_val_df['prediction']))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fine-tuned model reaches about 93% accuracy and a macro-averaged F1-score of 0.93 on the test split. Not bad for a classification problem with 77 classes!"
]
},
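{
"cell_type": "markdown",
"metadata": {},
"source": [
"These figures come from taking the arg-max over the 77 logits of each example and comparing it with the true label. As a minimal sketch, the same computation could be wrapped in a function of the shape that `Trainer` accepts through its `compute_metrics` argument, so that metrics are reported during evaluation. The function name, the metric keys and the toy arrays below are illustrative assumptions, not part of the run above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import accuracy_score, f1_score\n",
"\n",
"\n",
"def compute_metrics(eval_pred):\n",
"    \"\"\"Return accuracy and macro F1 from a (logits, labels) pair.\"\"\"\n",
"    logits, labels = eval_pred\n",
"    preds = np.argmax(logits, axis=-1)  # highest-scoring class per row\n",
"    return {\n",
"        \"accuracy\": accuracy_score(labels, preds),\n",
"        \"macro_f1\": f1_score(labels, preds, average=\"macro\"),\n",
"    }\n",
"\n",
"\n",
"# Toy check with 3 classes and 4 examples (made-up numbers).\n",
"toy_logits = np.array([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.1, 0.2, 0.9], [1.1, 0.9, 0.2]])\n",
"toy_labels = np.array([0, 1, 2, 1])\n",
"print(compute_metrics((toy_logits, toy_labels)))"
]
},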
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Comparing to scikit-learn"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.90 0.95 40\n",
" 1 0.93 0.97 0.95 40\n",
" 2 0.97 0.97 0.97 40\n",
" 3 0.90 0.93 0.91 40\n",
" 4 0.97 0.90 0.94 40\n",
" 5 0.64 0.75 0.69 40\n",
" 6 0.85 0.88 0.86 40\n",
" 7 0.89 0.82 0.86 40\n",
" 8 0.89 0.97 0.93 40\n",
" 9 1.00 0.97 0.99 40\n",
" 10 0.84 0.53 0.65 40\n",
" 11 0.75 0.90 0.82 40\n",
" 12 0.75 0.82 0.79 40\n",
" 13 0.86 0.93 0.89 40\n",
" 14 0.63 0.82 0.72 40\n",
" 15 0.78 0.90 0.84 40\n",
" 16 0.64 0.72 0.68 40\n",
" 17 0.90 0.90 0.90 40\n",
" 18 1.00 0.88 0.93 40\n",
" 19 0.90 0.88 0.89 40\n",
" 20 0.66 0.78 0.71 40\n",
" 21 0.97 0.93 0.95 40\n",
" 22 0.82 0.80 0.81 40\n",
" 23 1.00 0.82 0.90 40\n",
" 24 0.90 0.90 0.90 40\n",
" 25 0.76 0.80 0.78 40\n",
" 26 0.82 0.80 0.81 40\n",
" 27 0.94 0.75 0.83 40\n",
" 28 0.89 0.78 0.83 40\n",
" 29 0.81 0.72 0.76 40\n",
" 30 0.93 0.95 0.94 40\n",
" 31 0.90 0.88 0.89 40\n",
" 32 0.89 1.00 0.94 40\n",
" 33 0.78 0.90 0.84 40\n",
" 34 0.79 0.85 0.82 40\n",
" 35 0.72 0.85 0.78 40\n",
" 36 0.97 0.70 0.81 40\n",
" 37 0.60 0.75 0.67 40\n",
" 38 0.87 1.00 0.93 40\n",
" 39 0.77 0.68 0.72 40\n",
" 40 0.72 0.97 0.83 40\n",
" 41 0.89 0.80 0.84 40\n",
" 42 0.97 0.93 0.95 40\n",
" 43 0.65 0.80 0.72 40\n",
" 44 1.00 1.00 1.00 40\n",
" 45 0.85 0.82 0.84 40\n",
" 46 0.94 0.78 0.85 40\n",
" 47 0.62 0.90 0.73 40\n",
" 48 0.85 0.57 0.69 40\n",
" 49 1.00 0.85 0.92 40\n",
" 50 0.89 0.80 0.84 40\n",
" 51 0.82 0.93 0.87 40\n",
" 52 0.74 0.78 0.76 40\n",
" 53 0.71 0.90 0.79 40\n",
" 54 0.71 0.90 0.79 40\n",
" 55 0.93 0.95 0.94 40\n",
" 56 0.96 0.62 0.76 40\n",
" 57 0.92 0.88 0.90 40\n",
" 58 0.88 0.72 0.79 40\n",
" 59 0.69 0.78 0.73 40\n",
" 60 0.94 0.80 0.86 40\n",
" 61 0.81 0.75 0.78 40\n",
" 62 0.89 0.60 0.72 40\n",
" 63 0.88 0.93 0.90 40\n",
" 64 0.74 0.93 0.82 40\n",
" 65 0.78 0.88 0.82 40\n",
" 66 0.72 0.78 0.75 40\n",
" 67 0.77 0.75 0.76 40\n",
" 68 0.90 0.68 0.77 40\n",
" 69 0.64 0.62 0.63 40\n",
" 70 0.83 1.00 0.91 40\n",
" 71 0.95 1.00 0.98 40\n",
" 72 0.92 0.28 0.42 40\n",
" 73 1.00 0.93 0.96 40\n",
" 74 0.67 0.72 0.70 40\n",
" 75 0.88 0.90 0.89 40\n",
" 76 0.94 0.75 0.83 40\n",
"\n",
" accuracy 0.83 3080\n",
" macro avg 0.84 0.83 0.83 3080\n",
"weighted avg 0.84 0.83 0.83 3080\n",
"\n"
]
}
],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import classification_report\n",
"\n",
"# TF-IDF + logistic regression baseline. A separate name keeps the\n",
"# fine-tuned BERT `model` defined earlier from being overwritten.\n",
"baseline_model = Pipeline([\n",
"    (\"vectorizer\", TfidfVectorizer(stop_words=\"english\")),\n",
"    (\"classifier\", LogisticRegression())\n",
"])\n",
"\n",
"baseline_model.fit(raw_dataset_train['text'], raw_dataset_train['label'])\n",
"\n",
"# Evaluate on the same test split used for the fine-tuned model.\n",
"y_pred = baseline_model.predict(raw_dataset_val['text'])\n",
"\n",
"print(classification_report(raw_dataset_val['label'], y_pred))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}