{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4AbgGFUcbiVD"
   },
   "source": [
    "# Fine-Tuning BERT for CoLA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# first we install hub and the huggingface transformers package in this\n",
    "# runtime environment; %pip (rather than !pip) runs pip within the current\n",
    "# kernel, so the packages land where the notebook will import them from\n",
    "\n",
    "%pip install hub\n",
    "%pip install transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note**: Restart the Colab runtime after the installation, because a few packages have been updated; otherwise you may get an error (<font color=\"red\">FileNotFoundError</font>)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "kahzu-bNaL1Q",
    "outputId": "624b7590-0353-4b76-dcaa-7438c671cfc3"
   },
   "outputs": [],
   "source": [
    "# first we import hub\n",
    "import hub\n",
    "\n",
    "# the following lines fetch the dataset using hub and show us the\n",
    "# structure in which the data is stored.\n",
    "ds = hub.Dataset(\"activeloop/CoLA\")\n",
    "print(ds.schema)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EcjcbbyraW58"
   },
   "outputs": [],
   "source": [
    "# since now we have the dataset, let us fetch the BERT model \n",
    "from transformers import BertForSequenceClassification, BertConfig\n",
    "\n",
    "# Loading the pretrained bert-base-uncased weights into a model with a\n",
    "# sequence classification head (2 labels by default, as needed for CoLA)\n",
    "model = BertForSequenceClassification.from_pretrained('bert-base-uncased')\n",
    "\n",
    "# accessing the model configuration to keep it handy\n",
    "configuration = model.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 102,
     "referenced_widgets": [
      "3725718e8b5948f4aeea7ac87b0bc510",
      "6d8d5e0818de46499f6693d0f402d36f",
      "4a1f5b8176eb4e0c8ab3c30495d56768",
      "b12031003a394c338b0ac9ac91bf28cb",
      "20d66405ba5c46958563fc2aa7892973",
      "3979e6fadc6543edb27a8ca005e4df22",
      "1256b714bd7c4fa0b2272133b2444781",
      "d3fa352aa9334a6c865aa69c20e8cde7",
      "bf8b4d2452b44fc7a4bff27acbef5acf",
      "538f183f97a946368926343e00140c4f",
      "9c08bb9ad81146d2a7e22a292b5c7c39"
     ]
    },
    "id": "wVo9QUcujIYL",
    "outputId": "5a7cd30f-a686-4243-c6c4-3f6058de81dc"
   },
   "outputs": [],
   "source": [
    "# Now we have both our data and our model. It is time to set up the \n",
    "# input pipeline for which we shall require appropriate tokenizers\n",
    "from transformers import BertTokenizer\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "\n",
    "# let us look at how input sequences are generated \n",
    "# for BERT using the simple sentence \"Hello World!\"\n",
    "\n",
    "example = \"Hello World!\"\n",
    "encoded_input = tokenizer(example)\n",
    "print(\"Encoded input tokens: \", encoded_input[\"input_ids\"])\n",
    "decoded_input = tokenizer.decode(encoded_input[\"input_ids\"])\n",
    "print(\"Decoded input sequence: \", decoded_input)\n",
    "\n",
    "# one can observe how the tokenizer converts words into ids by mapping them \n",
    "# to its vocabulary and adds special sentinel tokens like [CLS] and [SEP]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PLh6t-0yn5fE"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AdamW\n",
    "from tqdm import tqdm\n",
    "\n",
    "def train_model(dataset, model, batch_size=16, epochs=1):\n",
    "    \"\"\"Fine-tune `model` on `dataset` and print the mean batch loss of each epoch.\n",
    "\n",
    "    Args:\n",
    "        dataset: dataset read through dataset.shape[0], dataset[\"sentence\"][i]\n",
    "            and dataset[\"labels\", start:stop] (each followed by .compute()).\n",
    "        model: called as model(input_ids, attention_mask=..., labels=...);\n",
    "            the first element of its output is used as the loss. It is moved\n",
    "            to the GPU when one is available and updated in place.\n",
    "        batch_size: number of sentences per optimisation step.\n",
    "        epochs: number of passes over the whole dataset.\n",
    "\n",
    "    Note: the sentences are encoded with the notebook-level `tokenizer`, so the\n",
    "    cell that creates it must be run before calling this function.\n",
    "    \"\"\"\n",
    "    # we set up the device so that we can use GPU hardware acceleration.\n",
    "    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "    model.to(device)\n",
    "    model.train()\n",
    "\n",
    "    # now we define our optimizer\n",
    "    optim = AdamW(model.parameters(), lr=5e-5)\n",
    "\n",
    "    # read from the `dataset` argument (not a notebook global) so the function\n",
    "    # trains on whatever the caller passes in\n",
    "    num_samples = dataset.shape[0]\n",
    "    # stepping by batch_size never produces an empty trailing batch, even when\n",
    "    # num_samples is an exact multiple of batch_size\n",
    "    batch_starts = range(0, num_samples, batch_size)\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        print(\"Epoch: \", epoch+1, flush=True)\n",
    "        total_loss = 0.0    # sum of the batch losses of the current epoch\n",
    "\n",
    "        # first we fetch batches of raw data\n",
    "        for start in tqdm(batch_starts):\n",
    "            stop = min(start + batch_size, num_samples)\n",
    "            batch_sentences = [dataset[\"sentence\"][i].compute() for i in range(start, stop)]\n",
    "            batch_labels = dataset[\"labels\", start:stop].compute()\n",
    "\n",
    "            # now we need to preprocess the raw data\n",
    "            encoded = tokenizer(batch_sentences, padding=True, truncation=True, return_tensors=\"pt\")\n",
    "            batch_labels = torch.as_tensor(batch_labels).to(device)\n",
    "\n",
    "            # the following steps train the model\n",
    "            optim.zero_grad()\n",
    "            input_ids = encoded['input_ids'].to(device)\n",
    "            attention_mask = encoded['attention_mask'].to(device)\n",
    "            outputs = model(input_ids, attention_mask=attention_mask, labels=batch_labels)\n",
    "            loss = outputs[0]\n",
    "            loss.backward()\n",
    "            optim.step()\n",
    "\n",
    "            # .item() yields a plain float, so no autograd history is\n",
    "            # accumulated across training steps\n",
    "            total_loss += loss.item()\n",
    "\n",
    "        # mean of the per-batch losses; max() guards against an empty dataset\n",
    "        print(\"Average Loss:\", total_loss / max(len(batch_starts), 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "flwV-wr24AZA",
    "outputId": "31020757-7099-42b5-d226-e765c9fbcf63"
   },
   "outputs": [],
   "source": [
    "train_model(ds, model, batch_size=32, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rdnxEemQ4bxx"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "NLP  using Hub",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "1256b714bd7c4fa0b2272133b2444781": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "20d66405ba5c46958563fc2aa7892973": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_9c08bb9ad81146d2a7e22a292b5c7c39",
      "placeholder": "​",
      "style": "IPY_MODEL_538f183f97a946368926343e00140c4f",
      "value": " 232k/232k [00:00&lt;00:00, 2.65MB/s]"
     }
    },
    "3725718e8b5948f4aeea7ac87b0bc510": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_4a1f5b8176eb4e0c8ab3c30495d56768",
       "IPY_MODEL_b12031003a394c338b0ac9ac91bf28cb",
       "IPY_MODEL_20d66405ba5c46958563fc2aa7892973"
      ],
      "layout": "IPY_MODEL_6d8d5e0818de46499f6693d0f402d36f"
     }
    },
    "3979e6fadc6543edb27a8ca005e4df22": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "4a1f5b8176eb4e0c8ab3c30495d56768": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_1256b714bd7c4fa0b2272133b2444781",
      "placeholder": "​",
      "style": "IPY_MODEL_3979e6fadc6543edb27a8ca005e4df22",
      "value": "Downloading: 100%"
     }
    },
    "538f183f97a946368926343e00140c4f": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "6d8d5e0818de46499f6693d0f402d36f": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "9c08bb9ad81146d2a7e22a292b5c7c39": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b12031003a394c338b0ac9ac91bf28cb": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_bf8b4d2452b44fc7a4bff27acbef5acf",
      "max": 231508,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_d3fa352aa9334a6c865aa69c20e8cde7",
      "value": 231508
     }
    },
    "bf8b4d2452b44fc7a4bff27acbef5acf": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "d3fa352aa9334a6c865aa69c20e8cde7": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
