{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Intent Classification in Banking 🏦\n", "\n", "The BANKING77 dataset provides a very fine-grained set of intents in the banking domain. It comprises 13,083 customer service queries labeled with 77 intents and focuses on fine-grained, single-domain intent detection." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "from transformers import AutoTokenizer, DataCollatorWithPadding\n", "import pandas as pd\n", "\n", "# BANKING77 provides only train and test splits; the test split serves as the validation set here.\n", "raw_dataset_train = load_dataset(\"banking77\", split=\"train\")\n", "raw_dataset_val = load_dataset(\"banking77\", split=\"test\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "raw_dataset_train_df = raw_dataset_train.to_pandas()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>text</th><th>label</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>I am still waiting on my card?</td><td>11</td></tr>\n",
"    <tr><th>1</th><td>What can I do if my card still hasn't arrived after 2 weeks?</td><td>11</td></tr>\n",
"    <tr><th>2</th><td>I have been waiting over a week. Is the card still coming?</td><td>11</td></tr>\n",
"    <tr><th>3</th><td>Can I track my card while it is in the process of delivery?</td><td>11</td></tr>\n",
"    <tr><th>4</th><td>How do I know if I will get my card, or if it is lost?</td><td>11</td></tr>\n",
"  </tbody>\n",
"</table>\n", "
" ], "text/plain": [ " text label\n", "0 I am still waiting on my card? 11\n", "1 What can I do if my card still hasn't arrived after 2 weeks? 11\n", "2 I have been waiting over a week. Is the card still coming? 11\n", "3 Can I track my card while it is in the process of delivery? 11\n", "4 How do I know if I will get my card, or if it is lost? 11" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.set_option('display.max_colwidth', None)\n", "\n", "raw_dataset_train_df.head()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>text</th><th>label</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>6883</th><td>Is it possible for me to change my PIN number?</td><td>21</td></tr>\n",
"    <tr><th>5836</th><td>I'm not sure why my card didn't work</td><td>25</td></tr>\n",
"    <tr><th>8601</th><td>I don't think my top up worked</td><td>59</td></tr>\n",
"    <tr><th>2545</th><td>Can you explain why my payment was charged a fee?</td><td>15</td></tr>\n",
"    <tr><th>8697</th><td>How long does a transfer from a UK account take? I just made one and it doesn't seem to be working, wondering if everything is okay</td><td>5</td></tr>\n",
"    <tr><th>5573</th><td>Why am I getting declines when trying to make a purchase online?</td><td>27</td></tr>\n",
"    <tr><th>576</th><td>What is the $1 transaction on my account?</td><td>34</td></tr>\n",
"    <tr><th>6832</th><td>It looks like my card payment was sent back.</td><td>53</td></tr>\n",
"    <tr><th>7111</th><td>Why am I unable to transfer money when I was able to before?</td><td>7</td></tr>\n",
"    <tr><th>439</th><td>What if there is an error on the exchange rate?</td><td>17</td></tr>\n",
"  </tbody>\n",
"</table>\n", "
" ], "text/plain": [ " text \\\n", "6883 Is it possible for me to change my PIN number? \n", "5836 I'm not sure why my card didn't work \n", "8601 I don't think my top up worked \n", "2545 Can you explain why my payment was charged a fee? \n", "8697 How long does a transfer from a UK account take? I just made one and it doesn't seem to be working, wondering if everything is okay \n", "5573 Why am I getting declines when trying to make a purchase online? \n", "576 What is the $1 transaction on my account? \n", "6832 It looks like my card payment was sent back. \n", "7111 Why am I unable to transfer money when I was able to before? \n", "439 What if there is an error on the exchange rate? \n", "\n", " label \n", "6883 21 \n", "5836 25 \n", "8601 59 \n", "2545 15 \n", "8697 5 \n", "5573 27 \n", "576 34 \n", "6832 53 \n", "7111 7 \n", "439 17 " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_dataset_train_df.sample(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It seems we have a lot of classes..." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "raw_dataset_train_df.label.value_counts().plot(kind='bar', figsize=(15, 10))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is the list of class labels with their intent names:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "| label | intent (category) |\n", "|---:|:-------------------------------------------------|\n", "| 0 | activate_my_card |\n", "| 1 | age_limit |\n", "| 2 | apple_pay_or_google_pay |\n", "| 3 | atm_support |\n", "| 4 | automatic_top_up |\n", "| 5 | balance_not_updated_after_bank_transfer |\n", "| 6 | balance_not_updated_after_cheque_or_cash_deposit |\n", "| 7 | beneficiary_not_allowed |\n", "| 8 | cancel_transfer |\n", "| 9 | card_about_to_expire |\n", "| 10 | card_acceptance |\n", "| 11 | card_arrival |\n", "| 12 | card_delivery_estimate |\n", "| 13 | card_linking |\n", "| 14 | card_not_working |\n", "| 15 | card_payment_fee_charged |\n", "| 16 | card_payment_not_recognised |\n", "| 17 | card_payment_wrong_exchange_rate |\n", "| 18 | card_swallowed |\n", "| 19 | cash_withdrawal_charge |\n", "| 20 | cash_withdrawal_not_recognised |\n", "| 21 | change_pin |\n", "| 22 | compromised_card |\n", "| 23 | contactless_not_working |\n", "| 24 | country_support |\n", "| 25 | declined_card_payment |\n", "| 26 | declined_cash_withdrawal |\n", "| 27 | declined_transfer |\n", "| 28 | direct_debit_payment_not_recognised |\n", "| 29 | disposable_card_limits |\n", "| 30 | edit_personal_details |\n", "| 31 | exchange_charge |\n", "| 32 | exchange_rate |\n", "| 33 | exchange_via_app |\n", "| 34 | extra_charge_on_statement |\n", "| 35 | failed_transfer |\n", "| 36 | fiat_currency_support |\n", "| 37 | get_disposable_virtual_card |\n", "| 38 | get_physical_card |\n", "| 39 | getting_spare_card |\n", "| 40 | getting_virtual_card |\n", "| 41 | lost_or_stolen_card |\n", "| 42 | lost_or_stolen_phone |\n", "| 43 | order_physical_card |\n", "| 44 
| passcode_forgotten |\n", "| 45 | pending_card_payment |\n", "| 46 | pending_cash_withdrawal |\n", "| 47 | pending_top_up |\n", "| 48 | pending_transfer |\n", "| 49 | pin_blocked |\n", "| 50 | receiving_money |\n", "| 51 | Refund_not_showing_up |\n", "| 52 | request_refund |\n", "| 53 | reverted_card_payment? |\n", "| 54 | supported_cards_and_currencies |\n", "| 55 | terminate_account |\n", "| 56 | top_up_by_bank_transfer_charge |\n", "| 57 | top_up_by_card_charge |\n", "| 58 | top_up_by_cash_or_cheque |\n", "| 59 | top_up_failed |\n", "| 60 | top_up_limits |\n", "| 61 | top_up_reverted |\n", "| 62 | topping_up_by_card |\n", "| 63 | transaction_charged_twice |\n", "| 64 | transfer_fee_charged |\n", "| 65 | transfer_into_account |\n", "| 66 | transfer_not_received_by_recipient |\n", "| 67 | transfer_timing |\n", "| 68 | unable_to_verify_identity |\n", "| 69 | verify_my_identity |\n", "| 70 | verify_source_of_funds |\n", "| 71 | verify_top_up |\n", "| 72 | virtual_card_not_working |\n", "| 73 | visa_or_mastercard |\n", "| 74 | why_verify_identity |\n", "| 75 | wrong_amount_of_cash_received |\n", "| 76 | wrong_exchange_rate_for_cash_withdrawal |\n", "\n", "And this is a summary of the dataset:\n", "\n", "| Dataset statistics | Train | Test |\n", "| --- | --- | --- |\n", "| Number of examples | 10,003 | 3,080 |\n", "| Average character length | 59.5 | 54.2 |\n", "| Number of intents | 77 | 77 |" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "model_name = \"bert-base-uncased\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "\n", "\n", "def tokenize_function(example):\n", "    return tokenizer(example[\"text\"], truncation=True)\n", "\n", "\n", "tokenized_datasets_train = raw_dataset_train.map(tokenize_function, batched=True)\n", "tokenized_datasets_val = raw_dataset_val.map(tokenize_function, batched=True)\n", "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)" ] }, { "cell_type": "code", 
"execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from transformers import TrainingArguments\n", "\n", "training_args = TrainingArguments(\"intent-banking\", per_device_train_batch_size=16)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "from transformers import AutoModelForSequenceClassification\n", "\n", "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=77)  # num_labels must match the dataset's 77 intents" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" ] } ], "source": [ "from transformers import Trainer\n", "\n", "trainer = Trainer(\n", "    model,\n", "    training_args,\n", "    train_dataset=tokenized_datasets_train,\n", "    eval_dataset=tokenized_datasets_val,\n", "    data_collator=data_collator,\n", "    tokenizer=tokenizer,\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fcb99442c8174f91a78ddaf8ec9ba4f3", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1878 [00:00