From 1fd3e0e4161ca07fe66dd9abf3dc947bf448f357 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Thu, 5 Nov 2020 09:24:54 +0530 Subject: [PATCH] Added Example Notebook With reference to this [discussion](https://gitter.im/trax-ml/community?at=5f8d24693d172d78b383d799) --- Deep N-Gram.ipynb | 1194 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1194 insertions(+) create mode 100644 Deep N-Gram.ipynb diff --git a/Deep N-Gram.ipynb b/Deep N-Gram.ipynb new file mode 100644 index 000000000..40e0d9a6b --- /dev/null +++ b/Deep N-Gram.ipynb @@ -0,0 +1,1194 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.024472, + "end_time": "2020-10-19T05:23:45.163806", + "exception": false, + "start_time": "2020-10-19T05:23:45.139334", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Downloading the Trax Package" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.024546, + "end_time": "2020-10-19T05:23:45.211638", + "exception": false, + "start_time": "2020-10-19T05:23:45.187092", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "[Trax](https://trax-ml.readthedocs.io/en/latest/) is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the [Google Brain team](https://research.google/teams/brain/). This notebook ([run it in colab](https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb)) shows how to use Trax and where you can find more information." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "_kg_hide-input": false, + "_kg_hide-output": true, + "execution": { + "iopub.execute_input": "2020-10-19T05:23:45.265606Z", + "iopub.status.busy": "2020-10-19T05:23:45.264326Z", + "iopub.status.idle": "2020-10-19T05:24:40.876515Z", + "shell.execute_reply": "2020-10-19T05:24:40.877287Z" + }, + "papermill": { + "duration": 55.642763, + "end_time": "2020-10-19T05:24:40.877583", + "exception": false, + "start_time": "2020-10-19T05:23:45.234820", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting trax\r\n", + " Downloading trax-1.3.5-py2.py3-none-any.whl (416 kB)\r\n", + "\u001b[K |████████████████████████████████| 416 kB 189 kB/s \r\n", + "\u001b[?25hCollecting gin-config\r\n", + " Downloading gin_config-0.3.0-py3-none-any.whl (44 kB)\r\n", + "\u001b[K |████████████████████████████████| 44 kB 1.2 MB/s \r\n", + "\u001b[?25hRequirement already satisfied: six in /opt/conda/lib/python3.7/site-packages (from trax) (1.14.0)\r\n", + "Collecting tensor2tensor\r\n", + " Downloading tensor2tensor-1.15.7-py2.py3-none-any.whl (1.4 MB)\r\n", + "\u001b[K |████████████████████████████████| 1.4 MB 4.6 MB/s \r\n", + "\u001b[?25hCollecting funcsigs\r\n", + " Downloading funcsigs-1.0.2-py2.py3-none-any.whl (17 kB)\r\n", + "Requirement already satisfied: gym in /opt/conda/lib/python3.7/site-packages (from trax) (0.17.3)\r\n", + "Requirement already satisfied: absl-py in /opt/conda/lib/python3.7/site-packages (from trax) (0.10.0)\r\n", + "Collecting tensorflow-text\r\n", + " Downloading tensorflow_text-2.3.0-cp37-cp37m-manylinux1_x86_64.whl (2.6 MB)\r\n", + "\u001b[K |████████████████████████████████| 2.6 MB 16.7 MB/s \r\n", + "\u001b[?25hRequirement already satisfied: scipy in /opt/conda/lib/python3.7/site-packages (from trax) (1.4.1)\r\n", + "Requirement already satisfied: numpy in /opt/conda/lib/python3.7/site-packages 
(from trax) (1.18.5)\r\n", + "Requirement already satisfied: tensorflow-datasets in /opt/conda/lib/python3.7/site-packages (from trax) (4.0.0)\r\n", + "Collecting jaxlib\r\n", + " Downloading jaxlib-0.1.56-cp37-none-manylinux2010_x86_64.whl (32.1 MB)\r\n", + "\u001b[K |████████████████████████████████| 32.1 MB 382 kB/s \r\n", + "\u001b[?25hCollecting t5\r\n", + " Downloading t5-0.7.0-py3-none-any.whl (171 kB)\r\n", + "\u001b[K |████████████████████████████████| 171 kB 44.6 MB/s \r\n", + "\u001b[?25hCollecting jax\r\n", + " Downloading jax-0.2.3.tar.gz (473 kB)\r\n", + "\u001b[K |████████████████████████████████| 473 kB 39.2 MB/s \r\n", + "\u001b[?25hCollecting tensorflow-gan\r\n", + " Downloading tensorflow_gan-2.0.0-py2.py3-none-any.whl (365 kB)\r\n", + "\u001b[K |████████████████████████████████| 365 kB 46.8 MB/s \r\n", + "\u001b[?25hCollecting tf-slim\r\n", + " Downloading tf_slim-1.1.0-py2.py3-none-any.whl (352 kB)\r\n", + "\u001b[K |████████████████████████████████| 352 kB 46.4 MB/s \r\n", + "\u001b[?25hRequirement already satisfied: sympy in /opt/conda/lib/python3.7/site-packages (from tensor2tensor->trax) (1.5.1)\r\n", + "Requirement already satisfied: gevent in /opt/conda/lib/python3.7/site-packages (from tensor2tensor->trax) (1.5.0)\r\n", + "Requirement already satisfied: h5py in /opt/conda/lib/python3.7/site-packages (from tensor2tensor->trax) (2.10.0)\r\n", + "Collecting gunicorn\r\n", + " Downloading gunicorn-20.0.4-py2.py3-none-any.whl (77 kB)\r\n", + "\u001b[K |████████████████████████████████| 77 kB 3.9 MB/s \r\n", + "\u001b[?25hRequirement already satisfied: tqdm in /opt/conda/lib/python3.7/site-packages (from tensor2tensor->trax) (4.45.0)\r\n", + "Requirement already satisfied: future in /opt/conda/lib/python3.7/site-packages (from tensor2tensor->trax) (0.18.2)\r\n", + "Requirement already satisfied: flask in /opt/conda/lib/python3.7/site-packages (from tensor2tensor->trax) (1.1.2)\r\n", + "Requirement already satisfied: opencv-python in 
/opt/conda/lib/python3.7/site-packages (from tensor2tensor->trax) (4.4.0.44)\r\n", + "Requirement already satisfied: requests in /opt/conda/lib/python3.7/site-packages (from tensor2tensor->trax) (2.23.0)\r\n", + "Collecting mesh-tensorflow\r\n", + " Downloading mesh_tensorflow-0.1.17-py3-none-any.whl (342 kB)\r\n", + "\u001b[K |████████████████████████████████| 342 kB 39.7 MB/s \r\n", + "\u001b[?25hRequirement already satisfied: oauth2client in /opt/conda/lib/python3.7/site-packages (from tensor2tensor->trax) (4.1.3)\r\n", + "Collecting pypng\r\n", + " Downloading pypng-0.0.20.tar.gz (649 kB)\r\n", + "\u001b[K |████████████████████████████████| 649 kB 40.7 MB/s \r\n", + "\u001b[33mWARNING: Retrying (Retry(total=4, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ProtocolError('Connection aborted.', ConnectionResetError(104, 'Connection reset by peer'))': /simple/dopamine-rl/\u001b[0m\r\n", + "\u001b[?25hCollecting dopamine-rl\r\n", + " Downloading dopamine_rl-3.1.8-py3-none-any.whl (117 kB)\r\n", + "\u001b[K |████████████████████████████████| 117 kB 46.8 MB/s \r\n", + "\u001b[?25hCollecting kfac\r\n", + " Downloading kfac-0.2.3-py2.py3-none-any.whl (191 kB)\r\n", + "\u001b[K |████████████████████████████████| 191 kB 44.7 MB/s \r\n", + "\u001b[?25hRequirement already satisfied: tensorflow-addons in /opt/conda/lib/python3.7/site-packages (from tensor2tensor->trax) (0.10.0)\r\n", + "Requirement already satisfied: google-api-python-client in /opt/conda/lib/python3.7/site-packages (from tensor2tensor->trax) (1.8.0)\r\n", + "Requirement already satisfied: Pillow in /opt/conda/lib/python3.7/site-packages (from tensor2tensor->trax) (7.2.0)\r\n", + "Collecting bz2file\r\n", + " Downloading bz2file-0.98.tar.gz (11 kB)\r\n", + "Collecting tensorflow-probability==0.7.0\r\n", + " Downloading tensorflow_probability-0.7.0-py2.py3-none-any.whl (981 kB)\r\n", + "\u001b[K |████████████████████████████████| 981 kB 50.5 MB/s \r\n", + 
"\u001b[?25hRequirement already satisfied: pyglet<=1.5.0,>=1.4.0 in /opt/conda/lib/python3.7/site-packages (from gym->trax) (1.5.0)\r\n", + "Requirement already satisfied: cloudpickle<1.7.0,>=1.2.0 in /opt/conda/lib/python3.7/site-packages (from gym->trax) (1.3.0)\r\n", + "Requirement already satisfied: tensorflow<2.4,>=2.3.0 in /opt/conda/lib/python3.7/site-packages (from tensorflow-text->trax) (2.3.0)\r\n", + "Requirement already satisfied: promise in /opt/conda/lib/python3.7/site-packages (from tensorflow-datasets->trax) (2.3)\r\n", + "Requirement already satisfied: termcolor in /opt/conda/lib/python3.7/site-packages (from tensorflow-datasets->trax) (1.1.0)\r\n", + "Requirement already satisfied: attrs>=18.1.0 in /opt/conda/lib/python3.7/site-packages (from tensorflow-datasets->trax) (19.3.0)\r\n", + "Requirement already satisfied: protobuf>=3.6.1 in /opt/conda/lib/python3.7/site-packages (from tensorflow-datasets->trax) (3.13.0)\r\n", + "Requirement already satisfied: dm-tree in /opt/conda/lib/python3.7/site-packages (from tensorflow-datasets->trax) (0.1.5)\r\n", + "Requirement already satisfied: tensorflow-metadata in /opt/conda/lib/python3.7/site-packages (from tensorflow-datasets->trax) (0.24.0)\r\n", + "Requirement already satisfied: dill in /opt/conda/lib/python3.7/site-packages (from tensorflow-datasets->trax) (0.3.2)\r\n", + "Requirement already satisfied: importlib-resources; python_version < \"3.9\" in /opt/conda/lib/python3.7/site-packages (from tensorflow-datasets->trax) (3.0.0)\r\n", + "Requirement already satisfied: babel in /opt/conda/lib/python3.7/site-packages (from t5->trax) (2.8.0)\r\n", + "Requirement already satisfied: pandas in /opt/conda/lib/python3.7/site-packages (from t5->trax) (1.1.3)\r\n", + "Collecting sacrebleu\r\n", + " Downloading sacrebleu-1.4.14-py3-none-any.whl (64 kB)\r\n", + "\u001b[K |████████████████████████████████| 64 kB 1.8 MB/s \r\n", + "\u001b[?25hRequirement already satisfied: torch in 
/opt/conda/lib/python3.7/site-packages (from t5->trax) (1.6.0)\r\n", + "Collecting rouge-score\r\n", + " Downloading rouge_score-0.0.4-py2.py3-none-any.whl (22 kB)\r\n", + "Requirement already satisfied: sentencepiece in /opt/conda/lib/python3.7/site-packages (from t5->trax) (0.1.91)\r\n", + "Collecting tfds-nightly\r\n", + " Downloading tfds_nightly-4.0.1.dev202010180107-py3-none-any.whl (3.6 MB)\r\n", + "\u001b[K |████████████████████████████████| 3.6 MB 38.7 MB/s \r\n", + "\u001b[?25hRequirement already satisfied: scikit-learn in /opt/conda/lib/python3.7/site-packages (from t5->trax) (0.23.2)\r\n", + "Requirement already satisfied: transformers>=2.7.0 in /opt/conda/lib/python3.7/site-packages (from t5->trax) (3.0.2)\r\n", + "Requirement already satisfied: nltk in /opt/conda/lib/python3.7/site-packages (from t5->trax) (3.2.4)\r\n", + "Requirement already satisfied: opt_einsum in /opt/conda/lib/python3.7/site-packages (from jax->trax) (3.3.0)\r\n", + "Requirement already satisfied: tensorflow-hub>=0.2 in /opt/conda/lib/python3.7/site-packages (from tensorflow-gan->tensor2tensor->trax) (0.9.0)\r\n", + "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.7/site-packages (from sympy->tensor2tensor->trax) (1.1.0)\r\n", + "Requirement already satisfied: greenlet>=0.4.14; platform_python_implementation == \"CPython\" in /opt/conda/lib/python3.7/site-packages (from gevent->tensor2tensor->trax) (0.4.15)\r\n", + "Requirement already satisfied: setuptools>=3.0 in /opt/conda/lib/python3.7/site-packages (from gunicorn->tensor2tensor->trax) (46.1.3.post20200325)\r\n", + "Requirement already satisfied: click>=5.1 in /opt/conda/lib/python3.7/site-packages (from flask->tensor2tensor->trax) (7.1.1)\r\n", + "Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/lib/python3.7/site-packages (from flask->tensor2tensor->trax) (2.11.2)\r\n", + "Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/lib/python3.7/site-packages (from 
flask->tensor2tensor->trax) (1.1.0)\r\n", + "Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/lib/python3.7/site-packages (from flask->tensor2tensor->trax) (1.0.1)\r\n", + "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests->tensor2tensor->trax) (1.24.3)\r\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests->tensor2tensor->trax) (2020.6.20)\r\n", + "Requirement already satisfied: idna<3,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests->tensor2tensor->trax) (2.9)\r\n", + "Requirement already satisfied: chardet<4,>=3.0.2 in /opt/conda/lib/python3.7/site-packages (from requests->tensor2tensor->trax) (3.0.4)\r\n", + "Requirement already satisfied: httplib2>=0.9.1 in /opt/conda/lib/python3.7/site-packages (from oauth2client->tensor2tensor->trax) (0.17.2)\r\n", + "Requirement already satisfied: pyasn1>=0.1.7 in /opt/conda/lib/python3.7/site-packages (from oauth2client->tensor2tensor->trax) (0.4.8)\r\n", + "Requirement already satisfied: pyasn1-modules>=0.0.5 in /opt/conda/lib/python3.7/site-packages (from oauth2client->tensor2tensor->trax) (0.2.7)\r\n", + "Requirement already satisfied: rsa>=3.1.4 in /opt/conda/lib/python3.7/site-packages (from oauth2client->tensor2tensor->trax) (4.0)\r\n", + "Collecting pygame>=1.9.2\r\n", + " Downloading pygame-1.9.6-cp37-cp37m-manylinux1_x86_64.whl (11.4 MB)\r\n", + "\u001b[K |████████████████████████████████| 11.4 MB 36.3 MB/s \r\n", + "\u001b[?25hCollecting flax>=0.2.0\r\n", + " Downloading flax-0.2.2-py3-none-any.whl (148 kB)\r\n", + "\u001b[K |████████████████████████████████| 148 kB 57.7 MB/s \r\n", + "\u001b[?25hRequirement already satisfied: typeguard>=2.7 in /opt/conda/lib/python3.7/site-packages (from tensorflow-addons->tensor2tensor->trax) (2.9.1)\r\n", + "Requirement already satisfied: google-auth-httplib2>=0.0.3 in /opt/conda/lib/python3.7/site-packages (from 
google-api-python-client->tensor2tensor->trax) (0.0.3)\r\n", + "Requirement already satisfied: google-api-core<2dev,>=1.13.0 in /opt/conda/lib/python3.7/site-packages (from google-api-python-client->tensor2tensor->trax) (1.17.0)\r\n", + "Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /opt/conda/lib/python3.7/site-packages (from google-api-python-client->tensor2tensor->trax) (3.0.1)\r\n", + "Requirement already satisfied: google-auth>=1.4.1 in /opt/conda/lib/python3.7/site-packages (from google-api-python-client->tensor2tensor->trax) (1.14.0)\r\n", + "Requirement already satisfied: decorator in /opt/conda/lib/python3.7/site-packages (from tensorflow-probability==0.7.0->tensor2tensor->trax) (4.4.2)\r\n", + "Requirement already satisfied: google-pasta>=0.1.8 in /opt/conda/lib/python3.7/site-packages (from tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (0.2.0)\r\n", + "Requirement already satisfied: gast==0.3.3 in /opt/conda/lib/python3.7/site-packages (from tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (0.3.3)\r\n", + "Requirement already satisfied: tensorflow-estimator<2.4.0,>=2.3.0 in /opt/conda/lib/python3.7/site-packages (from tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (2.3.0)\r\n", + "Requirement already satisfied: wheel>=0.26 in /opt/conda/lib/python3.7/site-packages (from tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (0.34.2)\r\n", + "Requirement already satisfied: wrapt>=1.11.1 in /opt/conda/lib/python3.7/site-packages (from tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (1.11.2)\r\n", + "Requirement already satisfied: grpcio>=1.8.6 in /opt/conda/lib/python3.7/site-packages (from tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (1.32.0)\r\n", + "Collecting tensorboard<3,>=2.3.0\r\n", + " Downloading tensorboard-2.3.0-py3-none-any.whl (6.8 MB)\r\n", + "\u001b[K |████████████████████████████████| 6.8 MB 31.1 MB/s \r\n", + "\u001b[?25hRequirement already satisfied: astunparse==1.6.3 in /opt/conda/lib/python3.7/site-packages (from 
tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (1.6.3)\r\n", + "Requirement already satisfied: keras-preprocessing<1.2,>=1.1.1 in /opt/conda/lib/python3.7/site-packages (from tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (1.1.2)\r\n", + "Requirement already satisfied: googleapis-common-protos<2,>=1.52.0 in /opt/conda/lib/python3.7/site-packages (from tensorflow-metadata->tensorflow-datasets->trax) (1.52.0)\r\n", + "Requirement already satisfied: zipp>=0.4; python_version < \"3.8\" in /opt/conda/lib/python3.7/site-packages (from importlib-resources; python_version < \"3.9\"->tensorflow-datasets->trax) (3.1.0)\r\n", + "Requirement already satisfied: pytz>=2015.7 in /opt/conda/lib/python3.7/site-packages (from babel->t5->trax) (2019.3)\r\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas->t5->trax) (2.8.1)\r\n", + "Requirement already satisfied: portalocker in /opt/conda/lib/python3.7/site-packages (from sacrebleu->t5->trax) (2.0.0)\r\n", + "Requirement already satisfied: typing-extensions; python_version < \"3.8\" in /opt/conda/lib/python3.7/site-packages (from tfds-nightly->t5->trax) (3.7.4.1)\r\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->t5->trax) (2.1.0)\r\n", + "Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->t5->trax) (0.14.1)\r\n", + "Requirement already satisfied: sacremoses in /opt/conda/lib/python3.7/site-packages (from transformers>=2.7.0->t5->trax) (0.0.43)\r\n", + "Requirement already satisfied: packaging in /opt/conda/lib/python3.7/site-packages (from transformers>=2.7.0->t5->trax) (20.1)\r\n", + "Requirement already satisfied: tokenizers==0.8.1.rc1 in /opt/conda/lib/python3.7/site-packages (from transformers>=2.7.0->t5->trax) (0.8.1rc1)\r\n", + "Requirement already satisfied: filelock in /opt/conda/lib/python3.7/site-packages (from 
transformers>=2.7.0->t5->trax) (3.0.10)\r\n", + "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.7/site-packages (from transformers>=2.7.0->t5->trax) (2020.4.4)\r\n", + "Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask->tensor2tensor->trax) (1.1.1)\r\n", + "Requirement already satisfied: matplotlib in /opt/conda/lib/python3.7/site-packages (from flax>=0.2.0->dopamine-rl->tensor2tensor->trax) (3.2.1)\r\n", + "Requirement already satisfied: msgpack in /opt/conda/lib/python3.7/site-packages (from flax>=0.2.0->dopamine-rl->tensor2tensor->trax) (1.0.0)\r\n", + "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from google-auth>=1.4.1->google-api-python-client->tensor2tensor->trax) (3.1.1)\r\n", + "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard<3,>=2.3.0->tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (1.7.0)\r\n", + "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.7/site-packages (from tensorboard<3,>=2.3.0->tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (3.2.1)\r\n", + "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /opt/conda/lib/python3.7/site-packages (from tensorboard<3,>=2.3.0->tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (0.4.1)\r\n", + "Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging->transformers>=2.7.0->t5->trax) (2.4.7)\r\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/lib/python3.7/site-packages (from matplotlib->flax>=0.2.0->dopamine-rl->tensor2tensor->trax) (1.2.0)\r\n", + "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.7/site-packages (from matplotlib->flax>=0.2.0->dopamine-rl->tensor2tensor->trax) (0.10.0)\r\n", + "Requirement already satisfied: requests-oauthlib>=0.7.0 in 
/opt/conda/lib/python3.7/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3,>=2.3.0->tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (1.2.0)\r\n", + "Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/lib/python3.7/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3,>=2.3.0->tensorflow<2.4,>=2.3.0->tensorflow-text->trax) (3.0.1)\r\n", + "Building wheels for collected packages: jax, pypng, bz2file\r\n", + " Building wheel for jax (setup.py) ... \u001b[?25l-\b \b\\\b \b|\b \b/\b \bdone\r\n", + "\u001b[?25h Created wheel for jax: filename=jax-0.2.3-py3-none-any.whl size=542175 sha256=acf65b35147e9c85817bf410212d1a0eb79daf8248f3ad23830ca88af14d719a\r\n", + " Stored in directory: /root/.cache/pip/wheels/93/6f/2d/ee26ee4ada4c80f15c60eacb11a410d20f59e244bfea506111\r\n", + " Building wheel for pypng (setup.py) ... \u001b[?25l-\b \b\\\b \bdone\r\n", + "\u001b[?25h Created wheel for pypng: filename=pypng-0.0.20-py3-none-any.whl size=67162 sha256=460e3eb3bf48c8f0760c1fdb4951d593824fb90568d38564a2530f94d4064b83\r\n", + " Stored in directory: /root/.cache/pip/wheels/54/64/43/dfd10cf95dc1687dc5350e861321ecd9a5d76b7c3d96fa1dc6\r\n", + " Building wheel for bz2file (setup.py) ... 
\u001b[?25l-\b \b\\\b \bdone\r\n", + "\u001b[?25h Created wheel for bz2file: filename=bz2file-0.98-py3-none-any.whl size=6882 sha256=044df879703ac715458b305ea4fff507d5aafeb3da6f696155f39b8f60fbf0cb\r\n", + " Stored in directory: /root/.cache/pip/wheels/85/ce/8d/b5f76b602b16a8a39f2ded74189cf5f09fc4a87bea16c54a8b\r\n", + "Successfully built jax pypng bz2file\r\n", + "Installing collected packages: gin-config, tensorflow-probability, tensorflow-gan, tf-slim, gunicorn, mesh-tensorflow, pypng, pygame, jaxlib, jax, flax, dopamine-rl, kfac, bz2file, tensor2tensor, funcsigs, tensorflow-text, sacrebleu, rouge-score, tfds-nightly, t5, trax, tensorboard\r\n", + " Attempting uninstall: tensorflow-probability\r\n", + " Found existing installation: tensorflow-probability 0.11.1\r\n", + " Uninstalling tensorflow-probability-0.11.1:\r\n", + " Successfully uninstalled tensorflow-probability-0.11.1\r\n", + " Attempting uninstall: tensorboard\r\n", + " Found existing installation: tensorboard 2.2.0\r\n", + " Uninstalling tensorboard-2.2.0:\r\n", + " Successfully uninstalled tensorboard-2.2.0\r\n", + "\u001b[31mERROR: After October 2020 you may experience errors when installing or updating packages. 
This is because pip will change the way that it resolves dependency conflicts.\r\n", + "\r\n", + "We recommend you use --use-feature=2020-resolver to test your packages with the new resolver before it becomes the default.\r\n", + "\r\n", + "pytorch-lightning 0.9.0 requires tensorboard==2.2.0, but you'll have tensorboard 2.3.0 which is incompatible.\r\n", + "kfac 0.2.3 requires tensorflow-probability==0.8, but you'll have tensorflow-probability 0.7.0 which is incompatible.\u001b[0m\r\n", + "Successfully installed bz2file-0.98 dopamine-rl-3.1.8 flax-0.2.2 funcsigs-1.0.2 gin-config-0.3.0 gunicorn-20.0.4 jax-0.2.3 jaxlib-0.1.56 kfac-0.2.3 mesh-tensorflow-0.1.17 pygame-1.9.6 pypng-0.0.20 rouge-score-0.0.4 sacrebleu-1.4.14 t5-0.7.0 tensor2tensor-1.15.7 tensorboard-2.3.0 tensorflow-gan-2.0.0 tensorflow-probability-0.7.0 tensorflow-text-2.3.0 tf-slim-1.1.0 tfds-nightly-4.0.1.dev202010180107 trax-1.3.5\r\n", + "\u001b[33mWARNING: You are using pip version 20.2.3; however, version 20.2.4 is available.\r\n", + "You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.\u001b[0m\r\n" + ] + } + ], + "source": [ + "!pip install trax" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.121469, + "end_time": "2020-10-19T05:24:41.120599", + "exception": false, + "start_time": "2020-10-19T05:24:40.999130", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Importing Packages" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.117117, + "end_time": "2020-10-19T05:24:41.355694", + "exception": false, + "start_time": "2020-10-19T05:24:41.238577", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "In this notebook we will use the following packages:\n", + "\n", + "* [**Pandas**](https://pandas.pydata.org/) is a fast, powerful, flexible and easy to use open-source data analysis and manipulation tool, built on top of the Python programming 
language. It offers a fast and efficient DataFrame object for data manipulation with integrated indexing.\n", + "* [**os**](https://docs.python.org/3/library/os.html) module provides a portable way of using operating system dependent functionality.\n", + "* [**trax**](https://trax-ml.readthedocs.io/en/latest/trax.html) is an end-to-end library for deep learning that focuses on clear code and speed.\n", + "* [**random**](https://docs.python.org/3/library/random.html) module implements pseudo-random number generators for various distributions.\n", + "* [**itertools**](https://docs.python.org/3/library/itertools.html) module implements a number of iterator building blocks inspired by constructs from APL, Haskell, and SML. Each has been recast in a form suitable for Python." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2020-10-19T05:24:41.598509Z", + "iopub.status.busy": "2020-10-19T05:24:41.597670Z", + "iopub.status.idle": "2020-10-19T05:24:54.656423Z", + "shell.execute_reply": "2020-10-19T05:24:54.655287Z" + }, + "papermill": { + "duration": 13.181434, + "end_time": "2020-10-19T05:24:54.656623", + "exception": false, + "start_time": "2020-10-19T05:24:41.475189", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import pandas as pd \n", + "import os\n", + "import trax\n", + "import trax.fastmath.numpy as np\n", + "import random as rnd\n", + "from trax import fastmath\n", + "from trax import layers as tl" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.118759, + "end_time": "2020-10-19T05:24:54.899617", + "exception": false, + "start_time": "2020-10-19T05:24:54.780858", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Loading the Data" + ] + }, + { + "cell_type": "markdown", + "metadata": { + 
"papermill": { + "duration": 0.122704, + "end_time": "2020-10-19T05:24:55.144895", + "exception": false, + "start_time": "2020-10-19T05:24:55.022191", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "For this project, I've used the [gothic-literature](https://www.kaggle.com/charlesaverill/gothic-literature), [shakespeare-plays](https://www.kaggle.com/kingburrito666/shakespeare-plays) and [shakespeareonline](https://www.kaggle.com/kewagbln/shakespeareonline) datasets from the Kaggle library. \n", + "\n", + "We perform the following steps for loading in the data:\n", + "\n", + "* Iterate over all the directories in the `/kaggle/input/` directory\n", + "* Filter out `.txt` files\n", + "* Make a `lines` list containing the individual lines from all the datasets combined" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": true, + "execution": { + "iopub.execute_input": "2020-10-19T05:24:55.385118Z", + "iopub.status.busy": "2020-10-19T05:24:55.384122Z", + "iopub.status.idle": "2020-10-19T05:24:55.716407Z", + "shell.execute_reply": "2020-10-19T05:24:55.715479Z" + }, + "papermill": { + "duration": 0.456359, + "end_time": "2020-10-19T05:24:55.716572", + "exception": false, + "start_time": "2020-10-19T05:24:55.260213", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "directories = os.listdir('/kaggle/input/')\n", + "lines = []\n", + "for directory in directories:\n", + " for filename in os.listdir(os.path.join('/kaggle/input',directory)):\n", + " if filename.endswith(\".txt\"):\n", + " with open(os.path.join(os.path.join('/kaggle/input',directory), filename)) as files:\n", + " for line in files: \n", + " processed_line = line.strip()\n", + " if processed_line:\n", + " lines.append(processed_line)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.113664, + "end_time": "2020-10-19T05:24:55.951966", + "exception": false, + "start_time": 
"2020-10-19T05:24:55.838302", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Pre-Processing" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.119888, + "end_time": "2020-10-19T05:24:56.194726", + "exception": false, + "start_time": "2020-10-19T05:24:56.074838", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Converting to Lowercase\n", + "\n", + "Converting all the characters in the `lines` list to **lowercase**." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:24:56.496346Z", + "iopub.status.busy": "2020-10-19T05:24:56.470575Z", + "iopub.status.idle": "2020-10-19T05:24:56.569027Z", + "shell.execute_reply": "2020-10-19T05:24:56.569637Z" + }, + "papermill": { + "duration": 0.253923, + "end_time": "2020-10-19T05:24:56.569875", + "exception": false, + "start_time": "2020-10-19T05:24:56.315952", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "for i, line in enumerate(lines):\n", + " lines[i] = line.lower()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.11122, + "end_time": "2020-10-19T05:24:56.795120", + "exception": false, + "start_time": "2020-10-19T05:24:56.683900", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Converting into Tensors\n", + "\n", + "Creating a function to convert each line into a tensor by converting each character into its ASCII value. And adding an optional `EOS` (**End of statement**) character." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:24:57.032580Z", + "iopub.status.busy": "2020-10-19T05:24:57.029673Z", + "iopub.status.idle": "2020-10-19T05:24:57.037237Z", + "shell.execute_reply": "2020-10-19T05:24:57.036444Z" + }, + "papermill": { + "duration": 0.131432, + "end_time": "2020-10-19T05:24:57.037392", + "exception": false, + "start_time": "2020-10-19T05:24:56.905960", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def line_to_tensor(line, EOS_int=1):\n", + " \n", + " tensor = []\n", + " for c in line:\n", + " c_int = ord(c)\n", + " tensor.append(c_int)\n", + " \n", + " tensor.append(EOS_int)\n", + "\n", + " return tensor" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.109763, + "end_time": "2020-10-19T05:24:57.259043", + "exception": false, + "start_time": "2020-10-19T05:24:57.149280", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Creating a Batch Generator\n", + "\n", + "Here, we create a `batch_generator()` function to yield a batch and mask generator. 
We perform the following steps:\n", + "\n", + "* Shuffle the lines if not shuffled\n", + "* Convert the lines into a Tensor\n", + "* Pad the lines if it's less than the maximum length\n", + "* Generate a mask " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:24:57.491159Z", + "iopub.status.busy": "2020-10-19T05:24:57.490293Z", + "iopub.status.idle": "2020-10-19T05:24:57.503719Z", + "shell.execute_reply": "2020-10-19T05:24:57.502899Z" + }, + "papermill": { + "duration": 0.134497, + "end_time": "2020-10-19T05:24:57.503870", + "exception": false, + "start_time": "2020-10-19T05:24:57.369373", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def data_generator(batch_size, max_length, data_lines, line_to_tensor=line_to_tensor, shuffle=True):\n", + " \n", + " index = 0 \n", + " cur_batch = [] \n", + " num_lines = len(data_lines) \n", + " lines_index = [*range(num_lines)] \n", + "\n", + " if shuffle:\n", + " rnd.shuffle(lines_index)\n", + " \n", + " while True:\n", + " \n", + " if index >= num_lines:\n", + " index = 0\n", + " if shuffle:\n", + " rnd.shuffle(lines_index)\n", + " \n", + " line = data_lines[lines_index[index]] \n", + " \n", + " if len(line) < max_length:\n", + " cur_batch.append(line)\n", + " \n", + " index += 1\n", + " \n", + " if len(cur_batch) == batch_size:\n", + " \n", + " batch = []\n", + " mask = []\n", + " \n", + " for li in cur_batch:\n", + "\n", + " tensor = line_to_tensor(li)\n", + "\n", + " pad = [0] * (max_length - len(tensor))\n", + " tensor_pad = tensor + pad\n", + " batch.append(tensor_pad)\n", + "\n", + " example_mask = [0 if t == 0 else 1 for t in tensor_pad]\n", + " mask.append(example_mask)\n", + " \n", + " batch_np_arr = np.array(batch)\n", + " mask_np_arr = np.array(mask)\n", + " \n", + " \n", + " yield batch_np_arr, batch_np_arr, mask_np_arr\n", + " \n", + " cur_batch = []\n", + " " + ] + }, + { + "cell_type": "markdown", 
+ "metadata": { + "papermill": { + "duration": 0.113922, + "end_time": "2020-10-19T05:24:57.728762", + "exception": false, + "start_time": "2020-10-19T05:24:57.614840", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Defining the Model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.110544, + "end_time": "2020-10-19T05:24:57.950897", + "exception": false, + "start_time": "2020-10-19T05:24:57.840353", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Gated Recurrent Unit\n", + "\n", + "This function generates a GRU Language Model, consisting of the following layers:\n", + "\n", + "* ShiftRight()\n", + "* Embedding()\n", + "* GRU Units(Number specified by the `n_layers` parameter)\n", + "* Dense() Layer\n", + "* LogSoftmax() Activation" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:24:58.183193Z", + "iopub.status.busy": "2020-10-19T05:24:58.182383Z", + "iopub.status.idle": "2020-10-19T05:24:58.186370Z", + "shell.execute_reply": "2020-10-19T05:24:58.185685Z" + }, + "papermill": { + "duration": 0.124594, + "end_time": "2020-10-19T05:24:58.186525", + "exception": false, + "start_time": "2020-10-19T05:24:58.061931", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def GRULM(vocab_size=256, d_model=512, n_layers=2, mode='train'):\n", + " model = tl.Serial(\n", + " tl.ShiftRight(mode=mode), \n", + " tl.Embedding( vocab_size = vocab_size, d_feature = d_model), \n", + " [tl.GRU(n_units=d_model) for _ in range(n_layers)], \n", + " tl.Dense(n_units = vocab_size), \n", + " tl.LogSoftmax() \n", + " )\n", + " return model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.150132, + "end_time": "2020-10-19T05:24:58.463252", + "exception": false, + "start_time": "2020-10-19T05:24:58.313120", + "status": "completed" + }, + "tags": [] + }, + "source": [ 
+ "## Long Short Term Memory\n", + "\n", + "This function generates a LSTM Language Model, consisting of the following layers:\n", + "\n", + "* ShiftRight()\n", + "* Embedding()\n", + "* LSTM Units(Number specified by the `n_layers` parameter)\n", + "* Dense() Layer\n", + "* LogSoftmax() Activation" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:24:58.713423Z", + "iopub.status.busy": "2020-10-19T05:24:58.712488Z", + "iopub.status.idle": "2020-10-19T05:24:58.717162Z", + "shell.execute_reply": "2020-10-19T05:24:58.716096Z" + }, + "papermill": { + "duration": 0.129976, + "end_time": "2020-10-19T05:24:58.717410", + "exception": false, + "start_time": "2020-10-19T05:24:58.587434", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def LSTMLM(vocab_size=256, d_model=512, n_layers=2, mode='train'):\n", + " model = tl.Serial(\n", + " tl.ShiftRight(mode=mode), \n", + " tl.Embedding( vocab_size = vocab_size, d_feature = d_model), \n", + " [tl.LSTM(n_units=d_model) for _ in range(n_layers)], \n", + " tl.Dense(n_units = vocab_size), \n", + " tl.LogSoftmax() \n", + " )\n", + " return model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.130305, + "end_time": "2020-10-19T05:24:58.971978", + "exception": false, + "start_time": "2020-10-19T05:24:58.841673", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Simple Recurrent Unit\n", + "\n", + "This function generates a SRU Language Model, consisting of the following layers:\n", + "\n", + "* ShiftRight()\n", + "* Embedding()\n", + "* SRU Units(Number specified by the `n_layers` parameter)\n", + "* Dense() Layer\n", + "* LogSoftmax() Activation" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:24:59.219038Z", + "iopub.status.busy": "2020-10-19T05:24:59.218146Z", + 
"iopub.status.idle": "2020-10-19T05:24:59.221200Z", + "shell.execute_reply": "2020-10-19T05:24:59.221764Z" + }, + "papermill": { + "duration": 0.12795, + "end_time": "2020-10-19T05:24:59.221979", + "exception": false, + "start_time": "2020-10-19T05:24:59.094029", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def SRULM(vocab_size=256, d_model=512, n_layers=2, mode='train'):\n", + " model = tl.Serial(\n", + " tl.ShiftRight(mode=mode), \n", + " tl.Embedding( vocab_size = vocab_size, d_feature = d_model), \n", + " [tl.SRU(n_units=d_model) for _ in range(n_layers)], \n", + " tl.Dense(n_units = vocab_size), \n", + " tl.LogSoftmax() \n", + " )\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:24:59.461999Z", + "iopub.status.busy": "2020-10-19T05:24:59.460669Z", + "iopub.status.idle": "2020-10-19T05:24:59.465622Z", + "shell.execute_reply": "2020-10-19T05:24:59.466443Z" + }, + "papermill": { + "duration": 0.132413, + "end_time": "2020-10-19T05:24:59.466681", + "exception": false, + "start_time": "2020-10-19T05:24:59.334268", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Serial[\n", + " ShiftRight(1)\n", + " Embedding_256_512\n", + " GRU_512\n", + " GRU_512\n", + " GRU_512\n", + " GRU_512\n", + " GRU_512\n", + " Dense_256\n", + " LogSoftmax\n", + "]\n", + "Serial[\n", + " ShiftRight(1)\n", + " Embedding_256_512\n", + " LSTM_512\n", + " LSTM_512\n", + " LSTM_512\n", + " LSTM_512\n", + " LSTM_512\n", + " Dense_256\n", + " LogSoftmax\n", + "]\n", + "Serial[\n", + " ShiftRight(1)\n", + " Embedding_256_512\n", + " SRU_512\n", + " SRU_512\n", + " SRU_512\n", + " SRU_512\n", + " SRU_512\n", + " Dense_256\n", + " LogSoftmax\n", + "]\n" + ] + } + ], + "source": [ + "GRUmodel = GRULM(n_layers = 5)\n", + "LSTMmodel = LSTMLM(n_layers = 5)\n", + "SRUmodel = 
SRULM(n_layers = 5)\n", + "print(GRUmodel)\n", + "print(LSTMmodel)\n", + "print(SRUmodel)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.117255, + "end_time": "2020-10-19T05:24:59.712882", + "exception": false, + "start_time": "2020-10-19T05:24:59.595627", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Hyperparameters" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.113458, + "end_time": "2020-10-19T05:24:59.939569", + "exception": false, + "start_time": "2020-10-19T05:24:59.826111", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Here, we declare `the batch_size` and the `max_length` hyperparameters for the model." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:25:00.173212Z", + "iopub.status.busy": "2020-10-19T05:25:00.172118Z", + "iopub.status.idle": "2020-10-19T05:25:00.176348Z", + "shell.execute_reply": "2020-10-19T05:25:00.175587Z" + }, + "papermill": { + "duration": 0.121757, + "end_time": "2020-10-19T05:25:00.176474", + "exception": false, + "start_time": "2020-10-19T05:25:00.054717", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "batch_size = 32\n", + "max_length = 64" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.111425, + "end_time": "2020-10-19T05:25:00.399880", + "exception": false, + "start_time": "2020-10-19T05:25:00.288455", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Creating Evaluation and Training Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:25:00.637648Z", + "iopub.status.busy": "2020-10-19T05:25:00.634400Z", + "iopub.status.idle": "2020-10-19T05:25:00.641032Z", + "shell.execute_reply": "2020-10-19T05:25:00.641698Z" + }, + "papermill": { + 
"duration": 0.130539, + "end_time": "2020-10-19T05:25:00.641885", + "exception": false, + "start_time": "2020-10-19T05:25:00.511346", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "eval_lines = lines[-1000:] # Create a holdout validation set\n", + "lines = lines[:-1000] # Leave the rest for training" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.112994, + "end_time": "2020-10-19T05:25:00.871007", + "exception": false, + "start_time": "2020-10-19T05:25:00.758013", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Training the Models" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.112218, + "end_time": "2020-10-19T05:25:01.096544", + "exception": false, + "start_time": "2020-10-19T05:25:00.984326", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Here, we create a function to train the models. This function does the following:\n", + "\n", + "* Creating a Train and Evaluation Generator that cycles infinetely using the `itertools` module\n", + "* Train the Model using Adam Optimizer\n", + "* Use the Accuracy Metric for Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:25:01.335062Z", + "iopub.status.busy": "2020-10-19T05:25:01.330866Z", + "iopub.status.idle": "2020-10-19T05:25:01.339390Z", + "shell.execute_reply": "2020-10-19T05:25:01.338695Z" + }, + "papermill": { + "duration": 0.130503, + "end_time": "2020-10-19T05:25:01.339549", + "exception": false, + "start_time": "2020-10-19T05:25:01.209046", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from trax.supervised import training\n", + "import itertools\n", + "\n", + "def train_model(model, data_generator, batch_size=32, max_length=64, lines=lines, eval_lines=eval_lines, n_steps=10, output_dir = 'model/'): \n", + "\n", + " \n", + " 
bare_train_generator = data_generator(batch_size, max_length, data_lines=lines)\n", + " infinite_train_generator = itertools.cycle(bare_train_generator)\n", + " \n", + " bare_eval_generator = data_generator(batch_size, max_length, data_lines=eval_lines)\n", + " infinite_eval_generator = itertools.cycle(bare_eval_generator)\n", + " \n", + " train_task = training.TrainTask(\n", + " labeled_data=infinite_train_generator, \n", + " loss_layer=tl.CrossEntropyLoss(), \n", + " optimizer=trax.optimizers.Adam(0.0005) \n", + " )\n", + "\n", + " eval_task = training.EvalTask(\n", + " labeled_data=infinite_eval_generator, \n", + " metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],\n", + " n_eval_batches=3 \n", + " )\n", + " \n", + " training_loop = training.Loop(model,\n", + " train_task,\n", + " eval_tasks=[eval_task],\n", + " output_dir = output_dir\n", + " )\n", + "\n", + " training_loop.run(n_steps=n_steps)\n", + " \n", + " return training_loop\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:25:01.602437Z", + "iopub.status.busy": "2020-10-19T05:25:01.601617Z", + "iopub.status.idle": "2020-10-19T05:26:21.063884Z", + "shell.execute_reply": "2020-10-19T05:26:21.062700Z" + }, + "papermill": { + "duration": 79.597768, + "end_time": "2020-10-19T05:26:21.064134", + "exception": false, + "start_time": "2020-10-19T05:25:01.466366", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.\n", + " warnings.warn('No GPU/TPU found, falling back to CPU.')\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Step 1: Ran 1 train steps in 20.15 secs\n", + "Step 1: train CrossEntropyLoss | 5.54517841\n", + "Step 1: eval CrossEntropyLoss | 5.54224094\n", + "Step 1: eval Accuracy | 
0.20141485\n" + ] + } + ], + "source": [ + "GRU_training_loop = train_model(GRUmodel, data_generator,n_steps=10, output_dir = 'model/GRU')" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:26:21.431594Z", + "iopub.status.busy": "2020-10-19T05:26:21.430465Z", + "iopub.status.idle": "2020-10-19T05:27:55.049767Z", + "shell.execute_reply": "2020-10-19T05:27:55.049034Z" + }, + "papermill": { + "duration": 93.801876, + "end_time": "2020-10-19T05:27:55.049974", + "exception": false, + "start_time": "2020-10-19T05:26:21.248098", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Step 1: Ran 1 train steps in 22.91 secs\n", + "Step 1: train CrossEntropyLoss | 5.76504803\n", + "Step 1: eval CrossEntropyLoss | 4.79372247\n", + "Step 1: eval Accuracy | 0.18692371\n" + ] + } + ], + "source": [ + "LSTM_training_loop = train_model(LSTMmodel, data_generator, n_steps = 10, output_dir = 'model/LSTM')" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "execution": { + "iopub.execute_input": "2020-10-19T05:27:55.406482Z", + "iopub.status.busy": "2020-10-19T05:27:55.405074Z", + "iopub.status.idle": "2020-10-19T05:28:36.239692Z", + "shell.execute_reply": "2020-10-19T05:28:36.238806Z" + }, + "papermill": { + "duration": 41.004194, + "end_time": "2020-10-19T05:28:36.239938", + "exception": false, + "start_time": "2020-10-19T05:27:55.235744", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Step 1: Ran 1 train steps in 11.45 secs\n", + "Step 1: train CrossEntropyLoss | 5.54126787\n", + "Step 1: eval CrossEntropyLoss | 5.51660713\n", + "Step 1: eval Accuracy | 0.08041244\n" + ] + } + ], + "source": [ + "SRU_training_loop = train_model(SRUmodel, data_generator, n_steps = 10, output_dir = 'model/SRU')" + ] + } + 
], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "papermill": { + "duration": 297.094983, + "end_time": "2020-10-19T05:28:36.576660", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2020-10-19T05:23:39.481677", + "version": "2.1.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}