{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Intent Classification in Banking 🏦\n",
"\n",
"The BANKING77 dataset provides a fine-grained set of intents in the banking domain. It comprises 13,083 customer service queries labeled with 77 intents and targets fine-grained, single-domain intent detection."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer, DataCollatorWithPadding\n",
"import pandas as pd\n",
"\n",
"raw_dataset_train = load_dataset(\"banking77\", split=\"train\")\n",
"raw_dataset_val = load_dataset(\"banking77\", split=\"test\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"raw_dataset_train_df = raw_dataset_train.to_pandas()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" text label\n",
"0 I am still waiting on my card? 11\n",
"1 What can I do if my card still hasn't arrived after 2 weeks? 11\n",
"2 I have been waiting over a week. Is the card still coming? 11\n",
"3 Can I track my card while it is in the process of delivery? 11\n",
"4 How do I know if I will get my card, or if it is lost? 11"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.set_option('display.max_colwidth', None)\n",
"\n",
"raw_dataset_train_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" text \\\n",
"6883 Is it possible for me to change my PIN number? \n",
"5836 I'm not sure why my card didn't work \n",
"8601 I don't think my top up worked \n",
"2545 Can you explain why my payment was charged a fee? \n",
"8697 How long does a transfer from a UK account take? I just made one and it doesn't seem to be working, wondering if everything is okay \n",
"5573 Why am I getting declines when trying to make a purchase online? \n",
"576 What is the $1 transaction on my account? \n",
"6832 It looks like my card payment was sent back. \n",
"7111 Why am I unable to transfer money when I was able to before? \n",
"439 What if there is an error on the exchange rate? \n",
"\n",
" label \n",
"6883 21 \n",
"5836 25 \n",
"8601 59 \n",
"2545 15 \n",
"8697 5 \n",
"5573 27 \n",
"576 34 \n",
"6832 53 \n",
"7111 7 \n",
"439 17 "
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"raw_dataset_train_df.sample(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
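"The sampled labels already span a wide range of ids, so there are many classes. The bar chart below shows how many training examples each class has."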
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA20AAAJBCAYAAADV610zAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAAsTAAALEwEAmpwYAAAtfUlEQVR4nO3dfbwtd10f+s+XHEQjEkJyBBTCAQxQtBL0NHALFCrWBtPLg6WWeC8CihEVsJZ7NT7cxlq1uVSgelX6ChAetDyKPNiAgkilWnkIBJJgeIyHEhqSAyhQ6UsJ/O4fM6esrLP2OTN77X327+z9fr9e67Vnzfqtmd+a+f1m5rNm9qxqrQUAAIA+3WqnKwAAAMDGhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0bN9OVyBJzjzzzHbgwIGdrgYAAMCOePe73/2p1tr+Va91EdoOHDiQK664YqerAQAAsCOq6mMbvebySAAAgI4JbQAAAB0T2gAAADomtAEAAHRMaAMAAOiY0AYAANAxoQ0AAKBjQhsAAEDHhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0TGgDAADomNAGAADQMaENAACgY0IbAABAx4Q2AACAjgltAAAAHRPaAAAAOia0AQAAdExoAwAA6JjQBgAA0DGhDQAAoGNCGwAAQMf27XQFlh246PKjxh265PwdqAkAAMDOc6YNAACgY0IbAABAx4Q2AACAjgltAAAAHRPaAAAAOia0AQAAdExoAwAA6JjQBgAA0DGhDQAAoGNCGwAAQMeENgAAgI4JbQAAAB0T2gAAADq2b6crsI4DF11+1LhDl5y/AzUBAADYHs60AQAAdExoAwAA6NhJfXnkHC6lBAAATkbOtAEAAHRMaAMAAOjYnrk8cg6XUgIAAL1wpg0AAKBjQhsAAEDHhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0TGgDAADomN9pW5PfdAMAALaTM20AAAAdO25oq6rLquqmqrpmYdwrquq94+NQVb13HH+gqv7nwmv/YRvrDgAAsOtNuTzyRUl+PclLjoxorf3zI8NV9awkn10o/9HW2jlbVD8AAIA97bihrbX2tqo6sOq1qqok35vkO7a4XgAAAGT9/2l7SJIbW2sfXhh396q6sqr+uKoesub0AQAA9rR17x55QZKXLTy/IclZrbVPV9W3J3ltVX1za+1zy2+sqguTXJgkZ5111prV6N+qu0wm7jQJAAAc26bPtFXVviTfk+QVR8a11v6mtfbpcfjdST6a5F6r3t9au7S1drC1dnD//v2brQYAAMCuts7lkd+Z5AOtteuPjKiq/VV1yjh8jyRnJ7luvSoCAADsXVNu+f+yJH+W5N5VdX1V/eD40uNyy0sjk+QfJLlq/AmA30nylNbaZ7awvgAAAHvKlLtHXrDB+CeuGPfqJK9ev1oAAAAk69+IhG3gpiUAAMAR697yHwAAgG0ktAEAAHTM5ZEnOZdSAgDA7uZMGwAAQMeENgAAgI65PHIPWXUp5UaXUc4pCwAAbB9n2gAAADomtAEAAHRMaAMAAOiY0AYAANAxNyJhbW5aAgAA28eZNgAAgI4JbQAAAB1zeSQnlEspAQBgHmfaAAAAOia0AQAAdMzlkXRp1WWUiUspAQDYe5xpAwAA6JgzbZz0nJUDAGA3c6YNAACgY0IbAABAx4Q2AACAjgltAAAAHRPaAAAAOia0AQAAdExoAwAA6JjQBgAA0DE/rs2e4oe4AQA42TjTBgAA0DGhDQAAoGMuj4QNuJQSAIAeONMGAADQMWfaYAusOivnjBwAAFvBmTYAAICOCW0AAAAdc3kknGAupQQAYA5n2gAAADomtAEAAHTM5ZHQsTmXUrrsEgBgd3KmDQAAoGNCGwAAQMeENgAAgI4JbQAAAB1zIxLYY1bdsCRx0xIAgF450wYAANAxoQ0AAKBjLo8ENuRSSgCAnedMGwAAQMeENgAAgI4JbQAAAB0T2g
AAADomtAEAAHRMaAMAAOiY0AYAANAxv9MGbIk5v+m2qqzffgMAWM2ZNgAAgI4JbQAAAB0T2gAAADomtAEAAHRMaAMAAOjYcUNbVV1WVTdV1TUL436+qj5RVe8dH9+98NpPV9VHquqDVfWPt6viAAAAe8GUM20vSnLeivHPaa2dMz7ekCRVdd8kj0vyzeN7frOqTtmqygIAAOw1xw1trbW3JfnMxOk9KsnLW2t/01r7iyQfSXLuGvUDAADY09b5n7anVtVV4+WTp4/jvjHJxxfKXD+OO0pVXVhVV1TVFYcPH16jGgAAALvXZkPbc5PcM8k5SW5I8qy5E2itXdpaO9haO7h///5NVgMAAGB321Roa63d2Fr7Umvty0mel69cAvmJJHddKHqXcRwAAACbsG8zb6qqO7fWbhifPibJkTtLvj7JS6vq2Um+IcnZSd65di2BPevARZcfNe7QJefvQE0AAHbGcUNbVb0sycOSnFlV1ye5OMnDquqcJC3JoSQ/nCSttfdX1SuT/HmSm5P8WGvtS9tScwAAgD3guKGttXbBitEvOEb5X0ryS+tUCmAznJUDAHajde4eCQAAwDYT2gAAADomtAEAAHRMaAMAAOiY0AYAANCxTf1OG8DJzp0mAYCThTNtAAAAHRPaAAAAOubySIDjcCklALCTnGkDAADomNAGAADQMaENAACgY0IbAABAx9yIBGCLrLphSeKmJQDAepxpAwAA6JjQBgAA0DGhDQAAoGNCGwAAQMeENgAAgI4JbQAAAB0T2gAAADomtAEAAHRMaAMAAOiY0AYAANCxfTtdAYC96MBFl68cf+iS809wTQCA3jnTBgAA0DGhDQAAoGNCGwAAQMeENgAAgI4JbQAAAB1z90iAzrnTJADsbc60AQAAdExoAwAA6JjQBgAA0DGhDQAAoGNuRAKwi6y6aYkblgDAyc2ZNgAAgI4JbQAAAB0T2gAAADomtAEAAHRMaAMAAOiYu0cC7FHuNAkAJwdn2gAAADomtAEAAHRMaAMAAOiY0AYAANAxoQ0AAKBjQhsAAEDHhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0bN9OVwCA/h246PKjxh265Py1yq4qtxVlAWC3caYNAACgY0IbAABAx1weCcCu4lJKAHYbZ9oAAAA65kwbAHuWs3IAnAycaQMAAOiY0AYAANCx414eWVWXJfknSW5qrX3LOO7fJfnfk/xtko8meVJr7a+q6kCSa5N8cHz721trT9mOigPAibQdv1U3tywAe9OUM20vSnLe0rg3J/mW1tq3JvlQkp9eeO2jrbVzxofABgAAsIbjhrbW2tuSfGZp3JtaazePT9+e5C7bUDcAAIA9byv+p+0Hkrxx4fndq+rKqvrjqnrIRm+qqgur6oqquuLw4cNbUA0AAIDdZ63QVlU/m+TmJP9xHHVDkrNaa/dP8i+TvLSqbrfqva21S1trB1trB/fv379ONQAAAHatTYe2qnpihhuU/B+ttZYkrbW/aa19ehx+d4ablNxrC+oJAACwJ20qtFXVeUl+MskjW2tfWBi/v6pOGYfvkeTsJNdtRUUBAAD2oim3/H9ZkoclObOqrk9ycYa7Rd4myZurKvnKrf3/QZJfqKovJvlykqe01j6zcsIAAAAc13FDW2vtghWjX7BB2VcnefW6lQIAAGCwFXePBAAAYJsIbQAAAB077uWRAEAfDlx0+VHjDl1y/g7UBIATyZk2AACAjgltAAAAHXN5JADsMqsuo0xWX0rZa1mXfQJ8hTNtAAAAHRPaAAAAOubySADgpOFSSmAvcqYNAACgY0IbAABAx4Q2AACAjgltAAAAHXMjEgBgV3LTEmC3cKYNAACgY0IbAABAx4Q2AACAjgltAAAAHRPaAAAAOia0AQAAdExoAwAA6JjQBgAA0DE/rg0A7Hl+iBvomTNtAAAAHXOmDQBgolVn5JLVZ+XmlAU4FmfaAAAAOia0AQAAdMzlkQAAO2zdyy5dcg
m7mzNtAAAAHRPaAAAAOubySACAXWrOpZTbVRZYnzNtAAAAHRPaAAAAOubySAAAto1LKWF9zrQBAAB0TGgDAADomMsjAQDYcXN+YBz2GmfaAAAAOia0AQAAdExoAwAA6JjQBgAA0DE3IgEA4KTipiXsNc60AQAAdExoAwAA6JjQBgAA0DGhDQAAoGNCGwAAQMfcPRIAgF3LnSbZDZxpAwAA6JgzbQAAkNVn5ZyRowfOtAEAAHRMaAMAAOiYyyMBAGAml1JyIjnTBgAA0DGhDQAAoGMujwQAgG3kUkrW5UwbAABAx4Q2AACAjrk8EgAAOuFSSlZxpg0AAKBjk0JbVV1WVTdV1TUL4+5QVW+uqg+Pf08fx1dV/VpVfaSqrqqqb9uuygMAAOx2U8+0vSjJeUvjLkryltba2UneMj5PkkckOXt8XJjkuetXEwAAYG+aFNpaa29L8pml0Y9K8uJx+MVJHr0w/iVt8PYkt6+qO29BXQEAAPacdf6n7Y6ttRvG4U8mueM4/I1JPr5Q7vpx3C1U1YVVdUVVXXH48OE1qgEAALB7bcmNSFprLUmb+Z5LW2sHW2sH9+/fvxXVAAAA2HXWCW03Hrnscfx70zj+E0nuulDuLuM4AAAAZlontL0+yRPG4Scked3C+O8f7yL5wCSfXbiMEgAAgBkm/bh2Vb0sycOSnFlV1ye5OMklSV5ZVT+Y5GNJvncs/oYk353kI0m+kORJW1xnAACAPWNSaGutXbDBSw9fUbYl+bF1KgUAAMBgUmgDAAD6ceCiy1eOP3TJ+Se4JpwIW3L3SAAAALaH0AYAANAxl0cCAMAuNudSSpdd9smZNgAAgI4JbQAAAB0T2gAAADomtAEAAHTMjUgAAIDZ3LTkxHGmDQAAoGNCGwAAQMeENgAAgI4JbQAAAB0T2gAAADomtAEAAHRMaAMAAOiY0AYAANAxoQ0AAKBjQhsAAEDHhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0TGgDAADomNAGAADQMaENAACgY0IbAABAx4Q2AACAjgltAAAAHRPaAAAAOrZvpysAAADsbgcuuvyocYcuOX8HanJycqYNAACgY0IbAABAx1weCQAAdMOllEdzpg0AAKBjQhsAAEDHhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0TGgDAADomNAGAADQMaENAACgY0IbAABAx/btdAUAAAA248BFlx817tAl5+9ATbaXM20AAAAdE9oAAAA6JrQBAAB0TGgDAADomNAGAADQMaENAACgY0IbAABAx4Q2AACAjgltAAAAHRPaAAAAOia0AQAAdExoAwAA6JjQBgAA0DGhDQAAoGNCGwAAQMf2bfaNVXXvJK9YGHWPJP8qye2T/FCSw+P4n2mtvWGz8wEAANjLNh3aWmsfTHJOklTVKUk+keQ1SZ6U5DmttV/ZigoCAADsZVt1eeTDk3y0tfaxLZoeAAAA2brQ9rgkL1t4/tSquqqqLquq01e9oaourKorquqKw4cPryoCAACw560d2qrqq5I8MsmrxlHPTXLPDJdO3pDkWave11q7tLV2sLV2cP/+/etWAwAAYFfaijNtj0jyntbajUnSWruxtfal1tqXkzwvyblbMA8AAIA9aStC2wVZuDSyqu688NpjklyzBfMAAADYkzZ998gkqaqvTfKPkvzwwuhnVtU5SVqSQ0uvAQAAMMNaoa219tdJzlga9/i1agQAAMD/slV3jwQAAGAbCG0AAAAdE9oAAAA6JrQBAAB0TGgDAADomNAGAADQMaENAACgY0IbAABAx4Q2AACAjgltAAAAHRPaAAAAOia0AQAAdExoAwAA6JjQBgAA0DGhDQAAoGNCGwAAQMeENgAAgI4JbQAAAB0T2gAAADomtAEAAHRMaAMAAOiY0AYAANAxoQ0AAKBjQhsAAEDHhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0TGgDAADomNAGAADQMaENAACgY0IbAABAx4Q2AA
CAju3b6QoAAABstwMXXX7UuEOXnL8DNZnPmTYAAICOCW0AAAAdE9oAAAA6JrQBAAB0TGgDAADomLtHAgAAjFbdZTLZ2TtNOtMGAADQMaENAACgY0IbAABAx4Q2AACAjgltAAAAHRPaAAAAOia0AQAAdExoAwAA6JjQBgAA0DGhDQAAoGNCGwAAQMeENgAAgI4JbQAAAB0T2gAAADomtAEAAHRMaAMAAOiY0AYAANAxoQ0AAKBjQhsAAEDHhDYAAICO7Vt3AlV1KMnnk3wpyc2ttYNVdYckr0hyIMmhJN/bWvvLdecFAACw12zVmbZ/2Fo7p7V2cHx+UZK3tNbOTvKW8TkAAAAzbdflkY9K8uJx+MVJHr1N8wEAANjVtiK0tSRvqqp3V9WF47g7ttZuGIc/meSOWzAfAACAPWft/2lL8uDW2ieq6uuTvLmqPrD4YmutVVVbftMY8C5MkrPOOmsLqgEAALD7rH2mrbX2ifHvTUlek+TcJDdW1Z2TZPx704r3XdpaO9haO7h///51qwEAALArrRXaquprq+rrjgwn+a4k1yR5fZInjMWekOR168wHAABgr1r38sg7JnlNVR2Z1ktba79fVe9K8sqq+sEkH0vyvWvOBwAAYE9aK7S11q5Lcr8V4z+d5OHrTBsAAIDtu+U/AAAAW0BoAwAA6JjQBgAA0DGhDQAAoGNCGwAAQMeENgAAgI4JbQAAAB0T2gAAADomtAEAAHRMaAMAAOiY0AYAANAxoQ0AAKBjQhsAAEDHhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0TGgDAADomNAGAADQMaENAACgY/t2ugIAAAAnowMXXb5y/KFLzt/S+TjTBgAA0DGhDQAAoGNCGwAAQMeENgAAgI4JbQAAAB0T2gAAADomtAEAAHRMaAMAAOiY0AYAANAxoQ0AAKBjQhsAAEDHhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0bN9OVwAAAGC3O3DR5SvHH7rk/OO+15k2AACAjgltAAAAHRPaAAAAOia0AQAAdExoAwAA6JjQBgAA0DGhDQAAoGNCGwAAQMeENgAAgI4JbQAAAB0T2gAAADomtAEAAHRMaAMAAOiY0AYAANAxoQ0AAKBjQhsAAEDHhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0TGgDAADomNAGAADQMaENAACgY0IbAABAxzYd2qrqrlX11qr686p6f1X9+Dj+56vqE1X13vHx3VtXXQAAgL1l3xrvvTnJM1pr76mqr0vy7qp68/jac1prv7J+9QAAAPa2TYe21toNSW4Yhz9fVdcm+catqhgAAABb9D9tVXUgyf2TvGMc9dSquqqqLquq0zd4z4VVdUVVXXH48OGtqAYAAMCus3Zoq6rbJnl1kn/RWvtckucmuWeSczKciXvWqve11i5trR1srR3cv3//utUAAADYldYKbVV16wyB7T+21n43SVprN7bWvtRa+3KS5yU5d/1qAgAA7E3r3D2ykrwgybWttWcvjL/zQrHHJLlm89UDAADY29a5e+SDkjw+ydVV9d5x3M8kuaCqzknSkhxK8sNrzAMAAGBPW+fukX+SpFa89IbNVwcAAIBFW3L3SAAAALaH0AYAANAxoQ0AAKBjQhsAAEDHhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0TGgDAADomNAGAADQMaENAACgY0IbAABAx4Q2AACAjgltAAAAHRPaAAAAOia0AQAAdExoAwAA6JjQBgAA0DGhDQAAoGNCGwAAQMeENgAAgI4JbQAAAB0T2gAAADomtAEAAHRMaAMAAOiY0AYAANAxoQ0AAKBjQhsAAEDHhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0TGgDAADomNAGAADQMaENAACgY0IbAABAx4Q2AACAjgltAAAAHRPaAAAAOia0AQAAdExoAwAA6JjQBgAA0DGhDQAAoGNCGwAAQMeENgAAgI4JbQAAAB0T2gAAADomtA
EAAHRMaAMAAOiY0AYAANAxoQ0AAKBjQhsAAEDHhDYAAICOCW0AAAAdE9oAAAA6JrQBAAB0TGgDAADomNAGAADQMaENAACgY0IbAABAx7YttFXVeVX1war6SFVdtF3zAQAA2M22JbRV1SlJfiPJI5LcN8kFVXXf7ZgXAADAbrZdZ9rOTfKR1tp1rbW/TfLyJI/apnkBAADsWtVa2/qJVj02yXmttSePzx+f5AGttaculLkwyYXj03sn+eCKSZ2Z5FMTZ7sdZXd6/idb2Z2e/24uu9Pz381ld3r+J1vZnZ7/bi670/PfzWV3ev4nW9mdnv9uLrvT89/NZXd6/ltR9m6ttf0rS7fWtvyR5LFJnr/w/PFJfn0T07liJ8vu9PxPtrI7Pf/dXHan57+by+70/E+2sjs9/91cdqfnv5vL7vT8T7ayOz3/3Vx2p+e/m8vu9Py3s2xrbdsuj/xEkrsuPL/LOA4AAIAZtiu0vSvJ2VV196r6qiSPS/L6bZoXAADArrVvOybaWru5qp6a5A+SnJLkstba+zcxqUt3uOxOz/9kK7vT89/NZXd6/ru57E7P/2Qru9Pz381ld3r+u7nsTs//ZCu70/PfzWV3ev67uexOz387y27PjUgAAADYGtv249oAAACsT2gDAADomNAGAADQsW25EUkvquoBSa5trX2uqr4myUVJvi3Jnyf55dbaZ3e0glusqr6+tXbTFk7vwUnOTXJNa+1NWzVdjq+q7pPkUUm+cRz1iSSvb61du8Y0j9zJ9b+31v6wqr4vyd9Pcm2SS1trX1wo+/Qkr2mtfXyz82Pvqap7JPmeDD/58qUkH0ry0tba5ya8d/L2a6u3deM0X9Ja+/6tnOaEed4nQx9/R2vtfyyMP6+19vsnsi4L897yZQvsbVV1bpLWWntXVd03yXlJPtBae8MOV+2kctKdaauqg1X11qr67aq6a1W9uao+W1Xvqqr7LxW/LMkXxuFfTXJakv93HPfCE1TfO1XVc6vqN6rqjKr6+aq6uqpeWVV3nvD+MzYYf4elxxlJ3llVp1fVHTY5zXcuDP9Qkl9P8nVJLq6qi45X1x5U1YdmlF25HGa8f/K6rarTquqSqvpAVX2mqj5dVdeO426/VPankrw8SSV55/ioJC9bXg9V9a0Lw7euqp+rqtdX1S9X1alLVX5hkvOT/HhV/VaSf5bkHUn+XpLnL5X9N0neUVX/pap+tKr2b24pbU5V/W5V/Z9VddsJZe9RVZdV1S9W1W2r6nlVdU1VvaqqDpyA6qaqnlpVZ47D31RVb6uqv6qqd1TV311jupPX77rbmg3mf6uq+oGquryq3ldV76mql1fVw1aUfXqS/5DkqzO0qdtkCG9vXy4/Z/u1zrbuGJ/r9UuP30vyPUeeL5XdrnX79CSvS/K0JNdU1aMWXv7lNaZ7SlX9cFX9m6p60NJrP7f0fM562DdO9/er6qrx8caqekpV3Xqz9T3G57h06fmpVfWTVfV/V9VXV9UTx/X1zOXtxMyyc9r4nOOPbVET9yVz2sEW1OnrZ5Sdvd+tqifNfc/Ceze9f6gZxxNbYU77mtMfa97+dM6xytS2eHGSX0vy3Kr6txmOLb82yUVV9bObXmDH/hyz7sp4ItRW7KPn/BJ3D48MB7CPSHJBko8neew4/uFJ/myp7LULw+9Zeu29S8/PWxg+LckLklyV5KVJ7rhGfX8/w075onF6P5XhQOZpSV63VPaSJGeOwweTXJfkI0k+luShS2W/nOQvlh5fHP9et8lpXrkw/K4k+8fhr01y9VLZ2yX5t0l+K8n3Lb32mxOWyxlb0BY+n+Rz4+Pz4+NLR8ZvdtkeZ55v3OS6/YPx9TstjLvTOO5NS2U/lOTWK+b9VUk+vDTuPQvDz0ryoiQPTfKcJC9ZKnvV+HdfkhuTnDI+ryOvLbaFDF/qfNfYFw6Pn/cJSb5uu/tOhjOLv5PkM0lemeQxSb5qg7JvS/Ij43q4JskzxvXwg0n+aDPrdn
x+2yS/kOT9ST47LoO3J3niive+f2H48iSPGYcfluRPl8reKclzk/xGkjOS/HySq8fPeec11u+k9pjkPUl+Lsk9JyyTF471e3CSfz8uj3+U5A+TPG2p7NULberUJP95HD4rC9uWcdyk7dcmyk76bGO53x7Xz0PHvzeMww9dKjt53c5sX1cnue04fCDJFUl+/Ej/W2O6z8/Q9/5Fkncnefaq9rSJZfuysd0+MMldxscDx3Gv2EzfSXKHDR5nJLl+qewrM/SB30zylgwHfg9J8u+S/NYaZee08TnHH3P62eRtaCbuS+a0gzmPDdbVoSSnJ7nDUtmt2u/+tzWW16T9Q+YdT8zZN0w+VprZvub0xzn70znHKlPb4tUZfv7r1HH53m4c/zU5+thjzvKas/04mOStGbb7d03y5nHdvSvJ/ZfKnja23Q+My+zTGa5IuiTJ7deo7+Rjxg37wmY77lY/MrET5pbBYrkjX7n0/FVJnjQOvzDJwXH4XknetVR28eDo+Ul+McndkvxEkteu0WGPVd/3Lj2/emH4rUn+3kJ9r1gq+4yxAfzdhXF/sWL+c6b5vgwb3jNWvLa8bF89NuBHZ/jh9Fcnuc3yshyfzwmOc3Z0v5bkJUvt46hlsInl8G0bPL49yQ2bXLcfPMbn+ODS8w8kuduKcndbUXaxDu/NGPayOohdkyH4nZ5hR3SHcfxXZ+ELjg3W4a2TPDLDjuLwZvvOcdbnG5c/V4YN4uOTvCFDP3thku86xjI43jZh0rody74uyRMz7Az/ZZL/J8nZSV6c4fLqleswR29bltfDnLA/Z/1Oao8ZDsh/Jcl/y3CQ8BNJvmGDdbI8j7ePf2+zos1cna9sA07PQr/KcIn17O3XJspO+mwZvpD4iQw77nPGcddtMM0563ZO+3r/0vPbjp/z2Tl6+zFnulctDO/L8DtAvzuus+X+MGfZfmjV+FWvZWLfyXBQfF1uGRqPPP/bVW14bPufzFd+smhVX5hTdk4bP1YfW162c/rZnOOPSfuSOe1gLDPp4DPzgv6c/e5VGzyuTvI3ayyvSess844n5uwb5hwrzWlfc/rjlQvr+Hj70znHKlPb4pWrhsfny9u6OctrzvZjTiCeE1y3av3eYjlsuFynFDoRj6mdMMmfZfj2/59lOOh/9Dj+oTl6I3Bahm+mP5rhMrAvjiv0j5Pc7xjzX25Ey8/ndNj3LQz/4tJryzuLa5PsG4ffvvTaLc50jePukiGYPjvDZYxHHXTMmWaGb8uONPjrMn7rn+Fg4njL5GeT/GmGwLfcUOdsuCfv6Mby357kj5I8PcOB2EYHXnOWw5fGab51xeN/bnLdvinJT+aWO4Q7ZtgI/OFS2fMyhNo3ZtjRXprhwOojWfhyYyx7XYZvzf5pjj7AeN/S858Yy39sXF5vSfK8DDvFi5fKXnmMZX7qGn1naiA+6pvgsW09JUtnzzJ8i3yvDJfkfSpf+XLmm1ash0nrdoPl967x760yXIu/+NovZdjW3CPJz2T4dvtuSZ6U5D9ttGxz/LB/XYb/EZuyfie1x6X19ZAMZyM+OS6DC1cs23surLu3Lbz250tlfzzDQdbzMnzxcOQLs/2L71sof9zt19yycz7b0nR/fXldbHLdzmlff5QxMC6M25fhwPFLa0z3Ays+w8UZts8fPsYyON6yfXuG/e6tFsbdKsk/z/A/ebP7TpIPJzlrg/l9fKO+keSy48xvTtk5bXzO8cecfjZnGzppX7KJdjDp4DPzgv6c/e6NSc7J0LcWHwcy/B/2ZpfXkf3DuTn+/mHq8cScfcNyfY51rDSnfc3pj3P2p3OOVaa2xXdkPG5Yqu9pK5bBnOU1Z/tx5cLw8QLxnOA6p77H2kcfdYy/cv5TCp2Ix9ROmOR+GVLwG5PcJ8P/qv1VhrNef3+Dad9ufN+3Z4PLtZJcnyGAPSPDgVItvLbcsed02F/IeAnM0vhvSvI7S+OeNnaC78hwucavjp31X2fpko6l9z1y7M
CfXPHapqa5NI1Tk9x9ady1i51vHPfEcT18bEXZqRvuWQddC8v96Un+S5Y27ptZDhnOSJ29wXQ+vjA8Z92enuH/KT+Q5C8znHK/dhx3hw0+0wMzHKz/03H4lBXlXpTh27IjjzuO4++U5C0ryn9DxhCc5PZJHpvk3BXl7jWj787pO1MD8VEH+ceY/8OTfHBcng/OcLDx4SQ3ZdzpzV234/P/muTBC33sDxZeO2qjnuEg/h0ZDgw+n/GGR0lOWyo3ecO9tG6PuX6ntses3oGfkuHLghcujf+ODF+gfDjDFyoPGMfvT/LMFdP55rFN3WfG+ttw+zW3bFafQVj52ZbKnJ+lL9yWXn/ixHU7p33dJQvf5C699qA1pvvbWfpyZxz/5CRfXGPZHkjyirFffWh83DSOW94/TOo7SX4sS1+gLry2fGni8zdo3/dM8idrlD3Sxj8ytvEHbtTGM4SK5eOPv8yw31teZ3P62Zxt6KR9ydx2kHkHn1OD/pz97guOtJkV03npGsvrWPuHR62Y15Tjicn7hsw7VrrfjPZ1IEPfO5yhLx75TKv645z96eRjlRlt8TYbzOvMLIT/TSyvOduPOYF4TnCdU9/Jx4wbrp+pK3K7HzM74d8ZO+Jtl8YftYGaMf+Llx5H/p/rTjn6f0fmdNgH5JbX7/7rJL83NurTVtTjYWOnuzLDGZA3JLkwS//jtDTdU5M8M8M1+EdNd+o0Zy6vZyb5zhXjz8vR/3c1Z8M9eUc3vnZuvnLm7iFJ/lWS796gzv9w4rJ9bJJ7bzCNRx9jmTx4bMPftcHr90nynVvZbjeYz0u2cnoT5jen70w++FzncyX5T1nakM5dtxl2oO/MsDP6k4xBNsPB3NMn1GHllyKZueEe+/qRNn7fsY0d1cYzHGjcdUK9Xj5z/f5vU+a/iXZzn4zb8Qzbxm8Zxx+zP4z9/OdW9bO5n21GXRe3M9+cYT+1ah1satsxYf5rTXejvrPcZhbXwwblHzAuizOSPCjJ/7XBcljuO/cexx/Vd5aW7az2deRzZeF44TjTPX+DspXxEv5jLa8N2uIz1m2LmbcNndTP5y7bzDj4XHj9uF+4ZHuOPyYvr4V2u9h/V7bbpffcOcmnN3jtWzNx35B5x0oPyHj8NvbFX8iwL1t5vLjwvjPGx2+v2xaW6nDq1Docrz/MWLeTl9cG799oW3e/TDzhk1uG0c/klmH09HXqm4X93nL5KcvnyHXeO268u8yi32ytHa6qO2X4tuv7x3JPT/KjGRbmORn+cft142vvaa192ybnP/kW5zXc0e35GS6JfH+SH2itfaiGO+xd0Fr7tYWy78/wTcDN491svpDhH0IfPo7/nqVpT7oF9Irp/nWGb5BuMd265c8enJqv/OzB+7OFP3tQx/l5gPFOXD+S4TKFfRmuKX5thktXbl4o9/LW2uMmzvPiDNco78vwvynnJvnPGf6J/A9aa7+0UHbTt7Df6LNV1Ttba+eOwz+U4Vuf12T4Nuf3WmuXLM3/xzJ0/HOyde329StGf0eGs1lprT1yM9PdRD2mttvHZjib9MEV03h0a+214/Dy56oMofuozzVnGdTRPwPy00nunwk/A3KsNr5V66GqntRae+HC8zlt/LMZtgMfzfC/h69qrR3eYD5T19fk+c8xpz+s6Gc/mmHbsaqfbXr9HqOuJ2Q7M6EeU9fZnL6z3GZe2Vr71Abzn9UWqurvjPV9+0b1XTHNB2Q4675q2c75XOtMN9l4+zG5LS47zvbjARmu0vnswn56ZbtdWmcvzdDPj1pnc5bBWP6ZGf5f5w+Xxp+X5P9rrZ29wef6mgyXl16zvP1aUfYh4zK4etVxwmYdZ9lOareb2I8srq+fygY/J7VU9pjbpKnHdZuo75z+MKcOi/3hyRm26a/NhP5wLMfY1j2itfbGYyyDDbcJx5nfMdvtscrOzA5PS/LUrHMcOCXZ7fQj4/9GjMNbcsetFfP4bJL/nuGU+I9m/OZmnbqOz+fcwf
LpGU7jvzbD/5c96hjvnTTdDOHsyKWJl2a469yDM3w79btrLK93Lgz/UIabJFyc4XKKiza7vGYu2zl3JFpcvz+ShW9VN/vZkll329yudntlJt4Fb7seGc6kTmq3U9fvnM81s+xyf/j3G/WHpXbw5GO18cy4G+FxlsHytfZz2viVmXDHzznra878Z7aZyf0h8/rZ5PU7s66b2c5sej+yog5z9g1z+85x28wmlsPTM3yxesz6zpzmnM+1XdOd0xYn7yMzb7s0tZ9vWd/NxP10jt5+LW9Dr9xoGcysz5xt86TlkJl3lJ2xvuaUnXO8OKe+c/rDnDpM7g8z1u2c/dOWHP8st9uZbXzyNj9bcBy4qQ5zoh+LCykz7rg1cx6Td14zV+icO1jOOZCZNN05HXDu8loY3nRnXbOzXLlqeNVnm7N+p362zLvb5na121tl4l3wtuuxFRui5fU753PNLLstO6SZdZhzl7Q5bXzSHT/nrK8585/ZZubcOXFOP9vy7d3MdbAl+5EVdZizzua0xTl3iZ2zHCbVd+Y053yu7ZrunLa4WIfjbT9mHaxPWWdzlsGE9re4bd7s9mtLDuo3sWwnLYeZ7WDO+ppTds7x4nb1hzl1mNwfZqzb7drWzWm3s9p4ph9brn0cuC+dqKqrNnopwz8AHnFjVZ3TWntvkrTW/kdV/ZMMP6S96R86HSbVvpzhf6/eVMOPEx65PeivZLhWeW5dk+FboF+t4ccsP5Xkz6rq4xkuD3zyUtlbtfFUcGvt0HhJ4e9U1d3GaW9muouXLLyvqg621q6oqntluJvmZt2qqk7P0FirjZdhtdb+uqpuXiw4Z3nNXLZ/W1Wntta+kOEmM0emcVqG2xIvmrx+Z3y20zLcmaqStKq6c2vthhp+wHJ5fW1Lux0/03Oq6lXj3xuTE96vJ7fbqet3zueauQzm9IfJbXxmHe6Y5B9n+H+I5WXwX5fGzWnjt1jWrbUvZrgL3Ovrlj/EPWc7M2f+c8zpD3P62XZs77ZrOzPH5HU2sy1ObTPJvOUwtb6Tpznzc23XdOe0xcnbj8xrt1PX2ay+O2PfO2f7NWcZzDFnupOWwzbuR+aUnXy8uF39YU4dMq8/TLVd27o57XZO2Tnb/PWPAzeThLfjkYm3e82MO27NnP+Vx3ht+Rbnk29Nu/CeKXewnHwL6KnTzYyfPZi5vA5l+s8DzLmV75yyc+5INGf9Tv5sG00vR9+9aVva7YppHfMueNvxmNNuN9N35n6uY5Wd0x/WaQfHqcOcu6TNaeOT7vg5c31Nnv/MNrN2f9ign2359m67tjMz6zB73zCxLc65S+yc5TCpvuu0r+N8rm2Z7sy2OHn7MXO7NLWfz1oGmX4MNmf7NXkZzFzec5btptrCcdrXnPU1e5uUCceL290fNlOHY/WHGe/drm3dnHY7p+yVx1oOS8/X3u/1dCOSF2S4O+CfrHjtpa2179vm+d+rtfahiWW3pa5VdZckN7fWPrnitQe11v50M9Md33+7JHfP0Pivb63duNlpHWc+p2bo4H+xMG7y8trGZTt5/R5jGkd9Nua1253u5wvz2nR/ONnbwXZuZ3pxorZ3K+a79nZmg+meVOvsZKvviXSs7cdOtdtx3ids27xd29ATvW2es752ct2eTE62bcd2bfM3nF8voQ0AAICj3WqnKwAAAMDGhDYAAICOCW0AAAAdE9oAAAA69v8DT9AM2IoshGwAAAAASUVORK5CYII=",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"raw_dataset_train_df.label.value_counts().plot(kind='bar', figsize=(15, 10))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is the list of class labels with their intent names:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| label | intent (category) |\n",
"|---:|:-------------------------------------------------|\n",
"| 0 | activate_my_card |\n",
"| 1 | age_limit |\n",
"| 2 | apple_pay_or_google_pay |\n",
"| 3 | atm_support |\n",
"| 4 | automatic_top_up |\n",
"| 5 | balance_not_updated_after_bank_transfer |\n",
"| 6 | balance_not_updated_after_cheque_or_cash_deposit |\n",
"| 7 | beneficiary_not_allowed |\n",
"| 8 | cancel_transfer |\n",
"| 9 | card_about_to_expire |\n",
"| 10 | card_acceptance |\n",
"| 11 | card_arrival |\n",
"| 12 | card_delivery_estimate |\n",
"| 13 | card_linking |\n",
"| 14 | card_not_working |\n",
"| 15 | card_payment_fee_charged |\n",
"| 16 | card_payment_not_recognised |\n",
"| 17 | card_payment_wrong_exchange_rate |\n",
"| 18 | card_swallowed |\n",
"| 19 | cash_withdrawal_charge |\n",
"| 20 | cash_withdrawal_not_recognised |\n",
"| 21 | change_pin |\n",
"| 22 | compromised_card |\n",
"| 23 | contactless_not_working |\n",
"| 24 | country_support |\n",
"| 25 | declined_card_payment |\n",
"| 26 | declined_cash_withdrawal |\n",
"| 27 | declined_transfer |\n",
"| 28 | direct_debit_payment_not_recognised |\n",
"| 29 | disposable_card_limits |\n",
"| 30 | edit_personal_details |\n",
"| 31 | exchange_charge |\n",
"| 32 | exchange_rate |\n",
"| 33 | exchange_via_app |\n",
"| 34 | extra_charge_on_statement |\n",
"| 35 | failed_transfer |\n",
"| 36 | fiat_currency_support |\n",
"| 37 | get_disposable_virtual_card |\n",
"| 38 | get_physical_card |\n",
"| 39 | getting_spare_card |\n",
"| 40 | getting_virtual_card |\n",
"| 41 | lost_or_stolen_card |\n",
"| 42 | lost_or_stolen_phone |\n",
"| 43 | order_physical_card |\n",
"| 44 | passcode_forgotten |\n",
"| 45 | pending_card_payment |\n",
"| 46 | pending_cash_withdrawal |\n",
"| 47 | pending_top_up |\n",
"| 48 | pending_transfer |\n",
"| 49 | pin_blocked |\n",
"| 50 | receiving_money |\n",
"| 51 | Refund_not_showing_up |\n",
"| 52 | request_refund |\n",
"| 53 | reverted_card_payment? |\n",
"| 54 | supported_cards_and_currencies |\n",
"| 55 | terminate_account |\n",
"| 56 | top_up_by_bank_transfer_charge |\n",
"| 57 | top_up_by_card_charge |\n",
"| 58 | top_up_by_cash_or_cheque |\n",
"| 59 | top_up_failed |\n",
"| 60 | top_up_limits |\n",
"| 61 | top_up_reverted |\n",
"| 62 | topping_up_by_card |\n",
"| 63 | transaction_charged_twice |\n",
"| 64 | transfer_fee_charged |\n",
"| 65 | transfer_into_account |\n",
"| 66 | transfer_not_received_by_recipient |\n",
"| 67 | transfer_timing |\n",
"| 68 | unable_to_verify_identity |\n",
"| 69 | verify_my_identity |\n",
"| 70 | verify_source_of_funds |\n",
"| 71 | verify_top_up |\n",
"| 72 | virtual_card_not_working |\n",
"| 73 | visa_or_mastercard |\n",
"| 74 | why_verify_identity |\n",
"| 75 | wrong_amount_of_cash_received |\n",
"| 76 | wrong_exchange_rate_for_cash_withdrawal |\n",
"\n",
"And this is a summary of the dataset:\n",
"\n",
"| Dataset statistics | Train | Test |\n",
"| --- | --- | --- |\n",
"| Number of examples | 10,003 | 3,080 |\n",
"| Average character length | 59.5 | 54.2 |\n",
"| Number of intents | 77 | 77 |"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"model_name = \"bert-base-uncased\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"\n",
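"def tokenize_function(example):\n",
"    # Truncate over-long queries here; padding is left to the data collator, which pads per batch.\n",
"    return tokenizer(example[\"text\"], truncation=True)\n",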
"\n",
"\n",
"tokenized_datasets_train = raw_dataset_train.map(tokenize_function, batched=True)\n",
"tokenized_datasets_val = raw_dataset_val.map(tokenize_function, batched=True)\n",
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
]
},
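{
"cell_type": "markdown",
"metadata": {},
"source": [
"`DataCollatorWithPadding` pads each batch only up to the length of its longest sequence, rather than padding the whole dataset to one fixed length. The following is a minimal pure-Python sketch of that idea, not the library's implementation: `pad_batch` and the toy token ids are hypothetical, and a pad id of 0 is assumed (the `[PAD]` id of `bert-base-uncased`).\n",
"\n",
"```python\n",
"def pad_batch(sequences, pad_id=0):\n",
"    # Pad every sequence to the length of the longest one in this batch.\n",
"    max_len = max(len(seq) for seq in sequences)\n",
"    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]\n",
"    # The attention mask is 1 for real tokens and 0 for padding.\n",
"    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]\n",
"    return {'input_ids': input_ids, 'attention_mask': attention_mask}\n",
"\n",
"\n",
"# Toy token ids (hypothetical): 101 and 102 stand in for [CLS] and [SEP].\n",
"batch = pad_batch([[101, 2023, 102], [101, 2023, 2003, 1037, 102]])\n",
"print(batch['input_ids'])\n",
"print(batch['attention_mask'])\n",
"```\n",
"\n",
"Because the queries are short, padding per batch should waste fewer tokens than padding every example to the model's maximum length."
]
},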
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrainingArguments\n",
"\n",
"training_args = TrainingArguments(\"intent-banking\", per_device_train_batch_size=16)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=77)  # num_labels must match the 77 intents in BANKING77"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
}
],
"source": [
"from transformers import Trainer\n",
"\n",
"trainer = Trainer(\n",
"    model,\n",
"    training_args,\n",
"    train_dataset=tokenized_datasets_train,\n",
"    eval_dataset=tokenized_datasets_val,\n",
"    data_collator=data_collator,\n",
"    tokenizer=tokenizer,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fcb99442c8174f91a78ddaf8ec9ba4f3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1878 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Checkpoint destination directory test-trainer/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'loss': 1.4679, 'grad_norm': 9.389148712158203, 'learning_rate': 3.668796592119276e-05, 'epoch': 0.8}\n",
"{'loss': 0.4518, 'grad_norm': 7.393991947174072, 'learning_rate': 2.3375931842385517e-05, 'epoch': 1.6}\n",
"{'loss': 0.2117, 'grad_norm': 5.330203056335449, 'learning_rate': 1.0063897763578276e-05, 'epoch': 2.4}\n",
"{'train_runtime': 582.5736, 'train_samples_per_second': 51.511, 'train_steps_per_second': 3.224, 'train_loss': 0.5940346174473706, 'epoch': 3.0}\n"
]
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=1878, training_loss=0.5940346174473706, metrics={'train_runtime': 582.5736, 'train_samples_per_second': 51.511, 'train_steps_per_second': 3.224, 'train_loss': 0.5940346174473706, 'epoch': 3.0})"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.train()\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ce36c1632c1b4ccfbcccd9b165b3c844",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/385 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(3080, 77) (3080,)\n"
]
}
],
"source": [
"predictions = trainer.predict(tokenized_datasets_val)\n",
"print(predictions.predictions.shape, predictions.label_ids.shape)"
]
},
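{
"cell_type": "markdown",
"metadata": {},
"source": [
"The progress-bar totals above (1878 training steps, 385 prediction steps) follow from the dataset and batch sizes. A quick sanity check, assuming a single device, no gradient accumulation, and the `TrainingArguments` defaults of 3 epochs and an evaluation batch size of 8 (this notebook only sets the training batch size):\n",
"\n",
"```python\n",
"import math\n",
"\n",
"train_examples, test_examples = 10003, 3080\n",
"train_batch_size = 16  # per_device_train_batch_size set above\n",
"eval_batch_size = 8    # assumed TrainingArguments default\n",
"num_epochs = 3         # assumed TrainingArguments default\n",
"\n",
"# The last, partially filled batch still counts as a step, hence the ceiling.\n",
"steps_per_epoch = math.ceil(train_examples / train_batch_size)\n",
"total_train_steps = steps_per_epoch * num_epochs\n",
"predict_steps = math.ceil(test_examples / eval_batch_size)\n",
"\n",
"print(steps_per_epoch, total_train_steps, predict_steps)\n",
"```\n",
"\n",
"This gives 626 steps per epoch, which is consistent with the training log reporting epoch 0.8 at its first logged entry if logging happens every 500 steps."
]
},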
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([41, 11, 11, ..., 24, 24, 24])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions.predictions.argmax(axis=1)"
]
},
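{
"cell_type": "markdown",
"metadata": {},
"source": [
"`predictions.predictions` holds one row of 77 raw scores (logits) per test example, and `argmax` picks the index of the highest score as the predicted label. A minimal sketch with a hypothetical 3-class logit row, which also shows how a softmax would turn the logits into probabilities if a confidence score is needed:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def softmax(logits):\n",
"    # Subtract the maximum for numerical stability before exponentiating.\n",
"    highest = max(logits)\n",
"    exps = [math.exp(x - highest) for x in logits]\n",
"    total = sum(exps)\n",
"    return [e / total for e in exps]\n",
"\n",
"\n",
"logits = [1.2, 4.5, -0.3]  # hypothetical scores for 3 classes\n",
"probs = softmax(logits)\n",
"# Softmax is monotonic, so the argmax of the logits equals the argmax of the probabilities.\n",
"predicted_label = max(range(len(logits)), key=lambda i: logits[i])\n",
"print(predicted_label, round(probs[predicted_label], 3))\n",
"```"
]
},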
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"raw_dataset_val_df = raw_dataset_val.to_pandas()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"raw_dataset_val_df['prediction'] = predictions.predictions.argmax(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.97 0.99 40\n",
" 1 0.98 1.00 0.99 40\n",
" 2 1.00 1.00 1.00 40\n",
" 3 1.00 0.97 0.99 40\n",
" 4 0.95 0.93 0.94 40\n",
" 5 0.84 0.80 0.82 40\n",
" 6 1.00 0.93 0.96 40\n",
" 7 0.97 0.93 0.95 40\n",
" 8 1.00 0.95 0.97 40\n",
" 9 0.95 1.00 0.98 40\n",
" 10 0.93 0.93 0.93 40\n",
" 11 0.88 0.88 0.88 40\n",
" 12 0.90 0.88 0.89 40\n",
" 13 0.97 0.95 0.96 40\n",
" 14 0.87 0.97 0.92 40\n",
" 15 0.84 0.93 0.88 40\n",
" 16 0.90 0.95 0.93 40\n",
" 17 0.88 0.95 0.92 40\n",
" 18 1.00 0.95 0.97 40\n",
" 19 0.95 0.95 0.95 40\n",
" 20 0.83 0.97 0.90 40\n",
" 21 0.89 1.00 0.94 40\n",
" 22 0.94 0.78 0.85 40\n",
" 23 1.00 0.88 0.93 40\n",
" 24 0.93 0.95 0.94 40\n",
" 25 0.85 0.97 0.91 40\n",
" 26 0.78 1.00 0.88 40\n",
" 27 1.00 0.75 0.86 40\n",
" 28 0.90 0.88 0.89 40\n",
" 29 0.97 0.90 0.94 40\n",
" 30 1.00 1.00 1.00 40\n",
" 31 0.97 0.93 0.95 40\n",
" 32 0.93 0.97 0.95 40\n",
" 33 0.88 0.95 0.92 40\n",
" 34 1.00 0.95 0.97 40\n",
" 35 0.87 0.97 0.92 40\n",
" 36 0.95 0.90 0.92 40\n",
" 37 0.94 0.85 0.89 40\n",
" 38 0.97 0.97 0.97 40\n",
" 39 0.93 0.97 0.95 40\n",
" 40 0.87 0.97 0.92 40\n",
" 41 0.86 0.95 0.90 40\n",
" 42 1.00 0.97 0.99 40\n",
" 43 0.90 0.95 0.93 40\n",
" 44 1.00 1.00 1.00 40\n",
" 45 0.95 0.95 0.95 40\n",
" 46 1.00 0.97 0.99 40\n",
" 47 0.91 0.97 0.94 40\n",
" 48 0.89 0.80 0.84 40\n",
" 49 0.97 0.85 0.91 40\n",
" 50 0.95 0.93 0.94 40\n",
" 51 1.00 1.00 1.00 40\n",
" 52 1.00 0.97 0.99 40\n",
" 53 0.93 0.93 0.93 40\n",
" 54 0.89 0.97 0.93 40\n",
" 55 0.98 1.00 0.99 40\n",
" 56 0.92 0.90 0.91 40\n",
" 57 0.93 0.95 0.94 40\n",
" 58 0.95 0.95 0.95 40\n",
" 59 0.92 0.90 0.91 40\n",
" 60 1.00 0.97 0.99 40\n",
" 61 0.92 0.85 0.88 40\n",
" 62 0.86 0.80 0.83 40\n",
" 63 0.93 1.00 0.96 40\n",
" 64 0.95 0.93 0.94 40\n",
" 65 0.95 0.88 0.91 40\n",
" 66 0.92 0.90 0.91 40\n",
" 67 0.78 0.95 0.85 40\n",
" 68 0.95 0.93 0.94 40\n",
" 69 0.81 0.95 0.87 40\n",
" 70 1.00 1.00 1.00 40\n",
" 71 1.00 1.00 1.00 40\n",
" 72 1.00 0.93 0.96 40\n",
" 73 1.00 0.93 0.96 40\n",
" 74 0.91 0.80 0.85 40\n",
" 75 1.00 0.90 0.95 40\n",
" 76 0.95 0.88 0.91 40\n",
"\n",
" accuracy 0.93 3080\n",
" macro avg 0.94 0.93 0.93 3080\n",
"weighted avg 0.94 0.93 0.93 3080\n",
"\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"print(classification_report(raw_dataset_val_df['label'], raw_dataset_val_df['prediction']))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fine-tuned model reaches about 93% accuracy and a macro-averaged F1-score of 0.93 on the test split. Not bad for a classification problem with 77 classes!"
]
},
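{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the report above, the macro average is the unweighted mean of the per-class scores, while the weighted average weights each class by its support. Every test class has 40 examples here, so the two coincide. A small pure-Python sketch of accuracy and macro F1 on hypothetical toy labels (the helper names are illustrative, not scikit-learn's API):\n",
"\n",
"```python\n",
"def f1_for_class(y_true, y_pred, cls):\n",
"    pairs = list(zip(y_true, y_pred))\n",
"    tp = sum(1 for t, p in pairs if t == cls and p == cls)\n",
"    fp = sum(1 for t, p in pairs if t != cls and p == cls)\n",
"    fn = sum(1 for t, p in pairs if t == cls and p != cls)\n",
"    if tp == 0:\n",
"        return 0.0\n",
"    precision = tp / (tp + fp)\n",
"    recall = tp / (tp + fn)\n",
"    return 2 * precision * recall / (precision + recall)\n",
"\n",
"\n",
"def macro_f1(y_true, y_pred):\n",
"    # Unweighted mean of the per-class F1 scores.\n",
"    classes = sorted(set(y_true))\n",
"    return sum(f1_for_class(y_true, y_pred, c) for c in classes) / len(classes)\n",
"\n",
"\n",
"y_true = [0, 0, 1, 1, 2, 2]  # hypothetical gold labels\n",
"y_pred = [0, 1, 1, 1, 2, 0]  # hypothetical predictions\n",
"accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)\n",
"print(round(accuracy, 3), round(macro_f1(y_true, y_pred), 3))\n",
"```"
]
},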
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Comparing to scikit-learn"
]
},
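{
"cell_type": "markdown",
"metadata": {},
"source": [
"The baseline in this section represents each query as a TF-IDF vector before fitting a logistic regression. A minimal sketch of the weighting on a hypothetical toy corpus, assuming whitespace tokenization and the smoothed IDF formula that scikit-learn documents as its default, `ln((1 + n) / (1 + df)) + 1`, followed by L2 normalization. It omits stop-word removal and the vectorizer's own tokenizer:\n",
"\n",
"```python\n",
"import math\n",
"from collections import Counter\n",
"\n",
"docs = ['my card is lost', 'my card was declined', 'top up failed']  # toy corpus\n",
"tokenized = [doc.split() for doc in docs]\n",
"n_docs = len(tokenized)\n",
"\n",
"# Document frequency: in how many documents each term appears.\n",
"df = Counter(term for tokens in tokenized for term in set(tokens))\n",
"# Smoothed inverse document frequency: rarer terms get larger weights.\n",
"idf = {term: math.log((1 + n_docs) / (1 + count)) + 1 for term, count in df.items()}\n",
"\n",
"\n",
"def tfidf_vector(tokens):\n",
"    tf = Counter(tokens)\n",
"    weights = {term: tf[term] * idf[term] for term in tf}\n",
"    # L2-normalize so that document length does not dominate.\n",
"    norm = math.sqrt(sum(w * w for w in weights.values()))\n",
"    return {term: w / norm for term, w in weights.items()}\n",
"\n",
"\n",
"vec = tfidf_vector(tokenized[0])\n",
"print(sorted(vec, key=vec.get, reverse=True))\n",
"```\n",
"\n",
"In this toy corpus, `lost` appears in one document and `card` in two, so `lost` receives the larger weight."
]
},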
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.90 0.95 40\n",
" 1 0.93 0.97 0.95 40\n",
" 2 0.97 0.97 0.97 40\n",
" 3 0.90 0.93 0.91 40\n",
" 4 0.97 0.90 0.94 40\n",
" 5 0.64 0.75 0.69 40\n",
" 6 0.85 0.88 0.86 40\n",
" 7 0.89 0.82 0.86 40\n",
" 8 0.89 0.97 0.93 40\n",
" 9 1.00 0.97 0.99 40\n",
" 10 0.84 0.53 0.65 40\n",
" 11 0.75 0.90 0.82 40\n",
" 12 0.75 0.82 0.79 40\n",
" 13 0.86 0.93 0.89 40\n",
" 14 0.63 0.82 0.72 40\n",
" 15 0.78 0.90 0.84 40\n",
" 16 0.64 0.72 0.68 40\n",
" 17 0.90 0.90 0.90 40\n",
" 18 1.00 0.88 0.93 40\n",
" 19 0.90 0.88 0.89 40\n",
" 20 0.66 0.78 0.71 40\n",
" 21 0.97 0.93 0.95 40\n",
" 22 0.82 0.80 0.81 40\n",
" 23 1.00 0.82 0.90 40\n",
" 24 0.90 0.90 0.90 40\n",
" 25 0.76 0.80 0.78 40\n",
" 26 0.82 0.80 0.81 40\n",
" 27 0.94 0.75 0.83 40\n",
" 28 0.89 0.78 0.83 40\n",
" 29 0.81 0.72 0.76 40\n",
" 30 0.93 0.95 0.94 40\n",
" 31 0.90 0.88 0.89 40\n",
" 32 0.89 1.00 0.94 40\n",
" 33 0.78 0.90 0.84 40\n",
" 34 0.79 0.85 0.82 40\n",
" 35 0.72 0.85 0.78 40\n",
" 36 0.97 0.70 0.81 40\n",
" 37 0.60 0.75 0.67 40\n",
" 38 0.87 1.00 0.93 40\n",
" 39 0.77 0.68 0.72 40\n",
" 40 0.72 0.97 0.83 40\n",
" 41 0.89 0.80 0.84 40\n",
" 42 0.97 0.93 0.95 40\n",
" 43 0.65 0.80 0.72 40\n",
" 44 1.00 1.00 1.00 40\n",
" 45 0.85 0.82 0.84 40\n",
" 46 0.94 0.78 0.85 40\n",
" 47 0.62 0.90 0.73 40\n",
" 48 0.85 0.57 0.69 40\n",
" 49 1.00 0.85 0.92 40\n",
" 50 0.89 0.80 0.84 40\n",
" 51 0.82 0.93 0.87 40\n",
" 52 0.74 0.78 0.76 40\n",
" 53 0.71 0.90 0.79 40\n",
" 54 0.71 0.90 0.79 40\n",
" 55 0.93 0.95 0.94 40\n",
" 56 0.96 0.62 0.76 40\n",
" 57 0.92 0.88 0.90 40\n",
" 58 0.88 0.72 0.79 40\n",
" 59 0.69 0.78 0.73 40\n",
" 60 0.94 0.80 0.86 40\n",
" 61 0.81 0.75 0.78 40\n",
" 62 0.89 0.60 0.72 40\n",
" 63 0.88 0.93 0.90 40\n",
" 64 0.74 0.93 0.82 40\n",
" 65 0.78 0.88 0.82 40\n",
" 66 0.72 0.78 0.75 40\n",
" 67 0.77 0.75 0.76 40\n",
" 68 0.90 0.68 0.77 40\n",
" 69 0.64 0.62 0.63 40\n",
" 70 0.83 1.00 0.91 40\n",
" 71 0.95 1.00 0.98 40\n",
" 72 0.92 0.28 0.42 40\n",
" 73 1.00 0.93 0.96 40\n",
" 74 0.67 0.72 0.70 40\n",
" 75 0.88 0.90 0.89 40\n",
" 76 0.94 0.75 0.83 40\n",
"\n",
" accuracy 0.83 3080\n",
" macro avg 0.84 0.83 0.83 3080\n",
"weighted avg 0.84 0.83 0.83 3080\n",
"\n"
]
}
],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import classification_report\n",
"\n",
"# Use a separate name so the fine-tuned BERT `model` above is not overwritten.\n",
"baseline = Pipeline([\n",
"    (\"vectorizer\", TfidfVectorizer(stop_words=\"english\")),\n",
"    (\"classifier\", LogisticRegression())\n",
"])\n",
"\n",
"baseline.fit(raw_dataset_train['text'], raw_dataset_train['label'])\n",
"\n",
"y_pred = baseline.predict(raw_dataset_val['text'])\n",
"\n",
"print(classification_report(raw_dataset_val['label'], y_pred))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}