diff --git a/Chatbot/Reformer_Chatbot.ipynb b/Chatbot/Reformer_Chatbot.ipynb new file mode 100644 index 0000000..0e9b500 --- /dev/null +++ b/Chatbot/Reformer_Chatbot.ipynb @@ -0,0 +1,1335 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 451 + }, + "colab_type": "code", + "id": "aV4zpTnSVFIp", + "outputId": "e3a85dd1-e375-4636-ea62-b9b403f0952a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 \n", + "trax 1.3.4\n", + "\u001b[33mWARNING: You are using pip version 20.1.1; however, version 20.2.3 is available.\n", + "You should consider upgrading via the '/opt/conda/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n" + ] + } + ], + "source": [ + "import json\n", + "import random\n", + "import numpy as np\n", + "from termcolor import colored\n", + "\n", + "import trax \n", + "from trax import layers as tl\n", + "from trax.supervised import training\n", + "!pip list | grep trax" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# filename of the MultiWOZ dialogue dataset\n", + "DATA_FILE = 'data.json'\n", + "\n", + "# data directory\n", + "DATA_DIR = './data'\n", + "\n", + "# dictionary where we will load the dialogue dataset\n", + "DIALOGUE_DB = {}\n", + "\n", + "# vocabulary filename\n", + "VOCAB_FILE = 'en_32k.subword'\n", + "\n", + "# vocabulary file directory\n", + "VOCAB_DIR = 'data/vocabs'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 167 + }, + "colab_type": "code", + "id": "K58I5vFB7GlP", + "outputId": "3d086ea4-7898-4870-b52f-f362cb02e118" + }, + "outputs": [], + "source": [ + "# help function to load a JSON file\n", + "def 
load_json(directory, file):\n", + " with open(f'{directory}/{file}') as file: \n", + " db = json.load(file)\n", + " return db\n", + "\n", + "# load the dialogue data set into our dictionary\n", + "DIALOGUE_DB = load_json(DATA_DIR, DATA_FILE)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 215 + }, + "colab_type": "code", + "id": "VGBnUfEk8p9x", + "outputId": "4b364506-1088-4f00-be0c-4fefa892dc4e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The number of dialogues is: 10438\n" + ] + } + ], + "source": [ + "print(f'The number of dialogues is: {len(DIALOGUE_DB)}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The dialogues are composed of multiple files and the filenames are used as keys in our dictionary. Those with multi-domain dialogues have \"MUL\" in their filenames while single domain dialogues have either \"SNG\" or \"WOZ\"." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['SNG01856.json', 'SNG0129.json', 'PMUL1635.json', 'MUL2168.json', 'SNG0073.json', 'SNG01445.json', 'MUL2105.json']\n" + ] + } + ], + "source": [ + "# print 7 keys from the dataset to see the filenames\n", + "print(list(DIALOGUE_DB.keys())[0:7]) " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "5KYeQLnG8p96", + "outputId": "b22f570d-a7b0-4b92-ba68-0b7236e61051" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['goal', 'log'])\n" + ] + } + ], + "source": [ + "# get keys of the fifth file in the list above\n", + "print(DIALOGUE_DB['SNG0073.json'].keys())" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 122 + }, + "colab_type": "code", + "id": "PPPWwQ2s8p9_", + "outputId": "7e8efa2d-821a-44c8-902d-2c722baf5b4c" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'taxi': {'info': {'leaveAt': '17:15',\n", + " 'destination': 'pizza hut fen ditton',\n", + " 'departure': \"saint john's college\"},\n", + " 'reqt': ['car type', 'phone'],\n", + " 'fail_info': {}},\n", + " 'police': {},\n", + " 'hospital': {},\n", + " 'hotel': {},\n", + " 'attraction': {},\n", + " 'train': {},\n", + " 'message': [\"You want to book a taxi. 
The taxi should go to pizza hut fen ditton and should depart from saint john's college\",\n", + " \"The taxi should leave after 17:15\",\n", + " \"Make sure you get car type and contact number\"],\n", + " 'restaurant': {}}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "DIALOGUE_DB['SNG0073.json']['goal']" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "B4N8RtWu8p-C" + }, + "source": [ + "The `log` on the other hand contains the dialog. It is a list of dictionaries and each element of this list contains several descriptions as well. Let's look at an example:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'text': \"I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.\",\n", + " 'metadata': {},\n", + " 'dialog_act': {'Taxi-Inform': [['Dest', 'pizza hut fen ditton'],\n", + " ['Depart', \"saint john 's college\"]]},\n", + " 'span_info': [['Taxi-Inform', 'Dest', 'pizza hut fen ditton', 11, 14],\n", + " ['Taxi-Inform', 'Depart', \"saint john 's college\", 6, 9]]}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# get first element of the log list\n", + "DIALOGUE_DB['SNG0073.json']['log'][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Person 1: I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.\n", + " Person 2: What time do you want to leave and what time do you want to arrive by?\n" + ] + } + ], + "source": [ + "print(' Person 1: ', DIALOGUE_DB['SNG0073.json']['log'][0]['text'])\n", + "print(' Person 2: ',DIALOGUE_DB['SNG0073.json']['log'][1]['text'])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + 
"def get_conversation(file, data_db):\n", + " '''\n", + " Args:\n", + " file (string): filename of the dialogue file saved as json\n", + " data_db (dict): dialogue database\n", + " \n", + " Returns:\n", + " string: A string containing the 'text' fields of data[file]['log'][x]\n", + " '''\n", + " \n", + " # initialize empty string\n", + " result = ''\n", + " \n", + " # get length of file's log list\n", + " len_msg_log = len(data_db[file]['log'])\n", + " \n", + " # set the delimiter strings\n", + " delimiter_1 = ' Person 1: '\n", + " delimiter_2 = ' Person 2: '\n", + " \n", + " # loop over the file's log list\n", + " for i in range(len_msg_log):\n", + " \n", + " \n", + " # get i'th element of file log list\n", + " cur_log = data_db[file]['log'][i]\n", + " \n", + " # check if i is even\n", + " if i%2 == 0: \n", + " # append the 1st delimiter string\n", + " result += delimiter_1\n", + " else: \n", + " # append the 2nd delimiter string\n", + " result += delimiter_2\n", + " \n", + " # append the message text from the log\n", + " result += cur_log['text']\n", + "\n", + " return result\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ugvx0noP8p-G" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Person 1: am looking for a place to to stay that has cheap price range it should be in a type of hotel Person 2: Okay, do you have a specific area you want to stay in? Person 1: no, i just need to make sure it's cheap. oh, and i need parking Person 2: I found 1 cheap hotel for you that includes parking. Do you like me to book it? Person 1: Yes, please. 6 people 3 nights starting on tuesday. Person 2: I am sorry but I wasn't able to book that for you for Tuesday. Is there another day you would like to stay or perhaps a shorter stay? Person 1: how about only 2 nights. Person 2: Booking was successful.\n", + "Reference number is : 7GAWK763. 
Anything else I can do for you? Person 1: No, that will be all. Good bye. Person 2: Thank you for using our services.\n" + ] + } + ], + "source": [ + "file = 'SNG01856.json'\n", + "conversation = get_conversation(file, DIALOGUE_DB)\n", + "\n", + "# print raw output\n", + "print(conversation)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[31mPerson 1: am looking for a place to to stay that has cheap price range it should be in a type of hotel \u001b[0m\n", + "\u001b[32mPerson 2: Okay, do you have a specific area you want to stay in? \u001b[0m\n", + "\u001b[31mPerson 1: no, i just need to make sure it's cheap. oh, and i need parking \u001b[0m\n", + "\u001b[32mPerson 2: I found 1 cheap hotel for you that includes parking. Do you like me to book it? \u001b[0m\n", + "\u001b[31mPerson 1: Yes, please. 6 people 3 nights starting on tuesday. \u001b[0m\n", + "\u001b[32mPerson 2: I am sorry but I wasn't able to book that for you for Tuesday. Is there another day you would like to stay or perhaps a shorter stay? \u001b[0m\n", + "\u001b[31mPerson 1: how about only 2 nights. \u001b[0m\n", + "\u001b[32mPerson 2: Booking was successful.\n", + "Reference number is : 7GAWK763. Anything else I can do for you? \u001b[0m\n", + "\u001b[31mPerson 1: No, that will be all. Good bye. 
\u001b[0m\n", + "\u001b[32mPerson 2: Thank you for using our services.\u001b[0m\n" + ] + } + ], + "source": [ + "def print_conversation(conversation):\n", + " \n", + " delimiter_1 = 'Person 1: '\n", + " delimiter_2 = 'Person 2: '\n", + " \n", + " split_list_d1 = conversation.split(delimiter_1)\n", + " \n", + " for sublist in split_list_d1[1:]:\n", + " split_list_d2 = sublist.split(delimiter_2)\n", + " print(colored(f'Person 1: {split_list_d2[0]}', 'red'))\n", + " \n", + " if len(split_list_d2) > 1:\n", + " print(colored(f'Person 2: {split_list_d2[1]}', 'green'))\n", + "\n", + " \n", + "print_conversation(conversation)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 102 + }, + "colab_type": "code", + "id": "Rs2R8q1d8p-K", + "outputId": "8a2f4e3f-4516-449f-9648-a5970707cfc9" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'text': 'am looking for a place to to stay that has cheap price range it should be in a type of hotel',\n", + " 'metadata': {},\n", + " 'dialog_act': {'Hotel-Inform': [['Type', 'hotel'], ['Price', 'cheap']]},\n", + " 'span_info': [['Hotel-Inform', 'Type', 'hotel', 20, 20],\n", + " ['Hotel-Inform', 'Price', 'cheap', 10, 10]]}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "DIALOGUE_DB['SNG01856.json']['log'][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 54 + }, + "colab_type": "code", + "id": "HQmYUcsi8p-O", + "outputId": "5730c55f-63da-42a8-935e-6eeb17f6f791" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'address': 'pool way, whitehill road, off newmarket road', 'area': 'east', 'entrance fee': '?', 'id': '1', 'location': [52.208789, 0.154883], 'name': 'abbey pool and astroturf pitch', 'openhours': '?', 'phone': '01223902088', 'postcode': 
'cb58nt', 'pricerange': '?', 'type': 'swimmingpool'}\n" + ] + } + ], + "source": [ + "# this is an example of the attractions file\n", + "attraction_file = open('data/attraction_db.json')\n", + "attractions = json.load(attraction_file)\n", + "print(attractions[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "I5kTg4uX8p-R", + "outputId": "3dacc4ff-4f05-4ae6-d099-33d1b3a6fa2a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'department': 'neurosciences critical care unit', 'id': 0, 'phone': '01223216297'}\n" + ] + } + ], + "source": [ + "# this is an example of the hospital file\n", + "hospital_file = open('data/hospital_db.json')\n", + "hospitals = json.load(hospital_file)\n", + "print(hospitals[0]) # feel free to index into other indices" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 54 + }, + "colab_type": "code", + "id": "B5knaAEc8p-U", + "outputId": "ee0110c2-b2c2-4584-bd42-21f75109a579" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'address': '124 tenison road', 'area': 'east', 'internet': 'yes', 'parking': 'no', 'id': '0', 'location': [52.1963733, 0.1987426], 'name': 'a and b guest house', 'phone': '01223315702', 'postcode': 'cb12dp', 'price': {'double': '70', 'family': '90', 'single': '50'}, 'pricerange': 'moderate', 'stars': '4', 'takesbookings': 'yes', 'type': 'guesthouse'}\n" + ] + } + ], + "source": [ + "# this is an example of the hotel file\n", + "hotel_file = open('data/hotel_db.json')\n", + "hotels = json.load(hotel_file)\n", + "print(hotels[0]) # feel free to index into other indices" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + 
"colab_type": "code", + "id": "t-Rk01Mv8p-a", + "outputId": "8977e17e-2fc3-4073-abb8-fcf5cef3cfaf" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'name': 'Parkside Police Station', 'address': 'Parkside, Cambridge', 'id': 0, 'phone': '01223358966'}\n" + ] + } + ], + "source": [ + "# this is an example of the police file\n", + "police_file = open('data/police_db.json')\n", + "police = json.load(police_file)\n", + "print(police[0]) # feel free to index into other indices" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 54 + }, + "colab_type": "code", + "id": "u-G9pD8g8p-d", + "outputId": "1dba6598-b9b6-4fc8-91d2-f844b98e45fa" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'address': 'Regent Street City Centre', 'area': 'centre', 'food': 'italian', 'id': '19210', 'introduction': 'Pizza hut is a large chain with restaurants nationwide offering convenience pizzas pasta and salads to eat in or take away', 'location': [52.20103, 0.126023], 'name': 'pizza hut city centre', 'phone': '01223323737', 'postcode': 'cb21ab', 'pricerange': 'cheap', 'type': 'restaurant'}\n" + ] + } + ], + "source": [ + "# this is an example of a restuarant file\n", + "restaurant_file = open('data/restaurant_db.json')\n", + "restaurants = json.load(restaurant_file)\n", + "print(restaurants[0]) # feel free to index into other indices" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 181 + }, + "colab_type": "code", + "id": "2H8pB_yI8p-g", + "outputId": "aa039a49-3ed3-4f4d-fa4f-c2619de3dc99" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "#####################################################\n", + "#####################################################\n", + "# Copyright Cambridge Dialogue Systems 
Group, 2018 #\n", + "#####################################################\n", + "#####################################################\n", + "\n", + "Dataset contains the following files:\n", + "1. data.json: the woz dialogue dataset, which contains the conversation users and wizards, as well as a set of coarse labels for each user turn. This file contains both system and user dialogue acts annotated at the turn level. Files with multi-domain dialogues have \"MUL\" in their names. Single domain dialogues have either \"SNG\" or \"WOZ\" in their names.\n", + "2. restaurant_db.json: the Cambridge restaurant database file, containing restaurants in the Cambridge UK area and a set of attributes.\n", + "3. attraction_db.json: the Cambridge attraction database file, contining attractions in the Cambridge UK area and a set of attributes.\n", + "4. hotel_db.json: the Cambridge hotel database file, containing hotels in the Cambridge UK area and a set of attributes.\n", + "5. train_db.json: the Cambridge train (with artificial connections) database file, containing trains in the Cambridge UK area and a set of attributes.\n", + "6. hospital_db.json: the Cambridge hospital database file, contatining information about departments.\n", + "7. police_db.json: the Cambridge police station information.\n", + "8. taxi_db.json: slot-value list for taxi domain.\n", + "9. valListFile.txt: list of dialogues for validation.\n", + "10. testListFile.txt: list of dialogues for testing.\n", + "11. system_acts.json:\n", + " There are 6 domains ('Booking', 'Restaurant', 'Hotel', 'Attraction', 'Taxi', 'Train') and 1 dummy domain ('general').\n", + " A domain-dependent dialogue act is defined as a domain token followed by a domain-independent dialogue act, e.g. 'Hotel-inform' means it is an 'inform' act in the Hotel domain.\n", + " Dialogue acts which cannot take slots, e.g., 'good bye', are defined under the 'general' domain.\n", + " A slot-value pair defined as a list with two elements. 
The first element is slot token and the second one is its value.\n", + " If a dialogue act takes no slots, e.g., dialogue act 'offer booking' for an utterance 'would you like to take a reservation?', its slot-value pair is ['none', 'none']\n", + " There are four types of values:\n", + " 1) If a slot takes a binary value, e.g., 'has Internet' or 'has park', the value is either 'yes' or 'no'.\n", + " 2) If a slot is under the act 'request', e.g., 'request' about 'area', the value is expressed as '?'.\n", + " 3) The value that appears in the utterance e.g., the name of a restaurant.\n", + " 4) If for some reason the turn does not have an annotation then it is labeled as \"No Annotation.\"\n", + "12. ontology.json: Data-based ontology containing all the values for the different slots in the domains.\n", + "13. slot_descriptions.json: A collection of human-written slot descriptions for each slot in the dataset. Each slot has at least two descriptions.\n", + "14. tokenization.md: A description of the tokenization preprocessing we had to perform to maintain consistency between the dialogue act annotations of DSTC 8 Track 1 and the existing MultiWOZ 2.0 data. \n", + "\n" + ] + } + ], + "source": [ + "with open('data/README') as file:\n", + " print(file.read())" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 71 + }, + "colab_type": "code", + "id": "IrnQ9eNV8p-k", + "outputId": "2b159dae-78be-4a19-df41-b9e620216d43" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Person 1: am looking for a place to to stay that has cheap price range it should be in a type of hotel Person 2: Okay, do you have a specific area you want to stay in? Person 1: no, i just need to make sure it's cheap. oh, and i need parking Person 2: I found 1 cheap hotel for you that includes parking. Do you like me to book it? Person 1: Yes, please. 
6 people 3 nights starting on tuesday. Person 2: I am sorry but I wasn't able to book that for you for Tuesday. Is there another day you would like to stay or perhaps a shorter stay? Person 1: how about only 2 nights. Person 2: Booking was successful.\n", + "Reference number is : 7GAWK763. Anything else I can do for you? Person 1: No, that will be all. Good bye. Person 2: Thank you for using our services.\n" + ] + } + ], + "source": [ + "# the keys are the file names\n", + "all_files = DIALOGUE_DB.keys()\n", + "\n", + "# initialize empty list\n", + "untokenized_data = []\n", + "\n", + "# loop over all files\n", + "for file in all_files:\n", + " # returns a string delimited by Person 1 and Person 2\n", + " result = get_conversation(file, DIALOGUE_DB)\n", + " \n", + " # append to the list\n", + " untokenized_data.append(result)\n", + "\n", + "print(untokenized_data[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "buE0b8bjx_p_", + "outputId": "cb73a95b-488b-4d1d-9c20-5e98ca71f9d5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "number of conversations in the data set: 10438\n", + "number of conversations in train set: 9917\n", + "number of conversations in eval set: 521\n" + ] + } + ], + "source": [ + "# shuffle the list we generated above\n", + "random.shuffle(untokenized_data)\n", + "\n", + "# define a cutoff \n", + "# convert to int because we will use it as a list index\n", + "cut_off = int(len(untokenized_data) * .05)\n", + "\n", + "# slice the list. the last elements after the cut_off value will be the eval set. the rest is for training. 
\n", + "train_data, eval_data = untokenized_data[:-cut_off], untokenized_data[-cut_off:]\n", + "\n", + "print(f'number of conversations in the data set: {len(untokenized_data)}')\n", + "print(f'number of conversations in train set: {len(train_data)}')\n", + "print(f'number of conversations in eval set: {len(eval_data)}')" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "def stream(data):\n", + " # loop over the entire data\n", + " while True:\n", + " # get a random element\n", + " d = random.choice(data)\n", + " \n", + " # yield a tuple pair of identical values \n", + " # (i.e. our inputs to the model will also be our targets during training)\n", + " yield (d, d)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "uZgK5FAAWwOu" + }, + "outputs": [], + "source": [ + "\n", + "data_pipeline = trax.data.Serial(\n", + " # randomize the stream\n", + " trax.data.Shuffle(),\n", + " \n", + " # tokenize the data\n", + " trax.data.Tokenize(vocab_dir=VOCAB_DIR,\n", + " vocab_file=VOCAB_FILE),\n", + " \n", + " # filter too long sequences\n", + " trax.data.FilterByLength(2048),\n", + " \n", + " # bucket by length\n", + " trax.data.BucketByLength(boundaries=[128, 256, 512, 1024],\n", + " batch_sizes=[16, 8, 4, 2, 1]),\n", + " \n", + " # add loss weights but do not add it to the padding tokens (i.e. 
0)\n", + " trax.data.AddLossWeights(id_to_mask=0)\n", + ")\n", + "\n", + "# apply the data pipeline to our train and eval sets\n", + "train_stream = data_pipeline(stream(train_data))\n", + "eval_stream = data_pipeline(stream(eval_data))" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 88 + }, + "colab_type": "code", + "id": "9iBQEvhLYRot", + "outputId": "78659fd2-4633-47bc-ebe8-3ae3a6e2eab3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input shape: (4, 512)\n", + " Person 1: I need assistance finding a restaurant that is cheap and serves british food, can you help me? Person 2: I am sorry, I do not have any cheaply priced British restaurants in the city. Would you like a different type of food or maybe a different price range? Person 1: Do you have any Asian Oriental restaurants in the same price range? Person 2: Yes I found 2 cheap asian oriental places to eat. The dojo noodle bar and j restaurant. Would you like me to book a table for you? Person 1: Yes. Please book me a table at the dojo noodle bar for 7 people at 19:00 on wednesday. Person 2: I couldn't get you a table at Dojo, I can get you one at J Restaurant, if you want? Person 1: That works. Same parameters, please. I need the reference number too. Person 2: Booked! Your table will be reserved for 15 minutes. Reference number: F10EIGQ3. Person 1: Great thank you for all your help. Person 2: Do you need anything else? Person 1: That was all thank you. Person 2: Okay great. Thanks for calling and enjoy your dinner.\n" + ] + } + ], + "source": [ + "# the stream generators will yield (input, target, weights). \n", + "inp, _, _ = next(train_stream)\n", + "\n", + "# print the shape. 
format is (batch size, token length)\n", + "print(\"input shape: \", inp.shape)\n", + "\n", + "# detokenize the first element\n", + "print(trax.data.detokenize(inp[0], vocab_dir=VOCAB_DIR, vocab_file=VOCAB_FILE))" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "adX2eU762BkF" + }, + "outputs": [], + "source": [ + "def reversible_layer_forward(x, f, g):\n", + " \"\"\"\n", + " Args: \n", + " x (np.array): an input vector or matrix\n", + " f (function): a function which operates on a vector/matrix\n", + " g (function): a function which operates on a vector/matrix\n", + " Returns: \n", + " y (np.array): an output vector or matrix whose form is determined by 'x', f and g\n", + " \"\"\"\n", + " # split the input vector into two (* along the last axis because it is the depth dimension)\n", + " x1, x2 = np.split(x, 2, axis=-1) \n", + " \n", + " y1 = x1 + f(x2)\n", + " \n", + " y2 = x2 + g(y1)\n", + " \n", + " # concatenate y1 and y2 along the depth dimension. 
be sure output is of type np.ndarray\n", + " y = np.concatenate([y1, y2], axis=-1)\n", + " \n", + " return y" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "0xTCG9WlaiiO" + }, + "outputs": [], + "source": [ + "\n", + "def reversible_layer_reverse(y, f, g):\n", + " \"\"\"\n", + " Args: \n", + " y (np.array): an input vector or matrix\n", + " f (function): a function which operates on a vector/matrix of the form of 'y'\n", + " g (function): a function which operates on a vector/matrix of the form of 'y'\n", + " Returns: \n", + " y (np.array): an output vector or matrix whose form is determined by 'y', f and g\n", + " \"\"\"\n", + " \n", + " # split the input vector into two (* along the last axis because it is the depth dimension)\n", + " y1, y2 = np.split(y, 2, axis=-1)\n", + " \n", + " \n", + " x2 = y2 - g(y1)\n", + " \n", + " x1 = y1 - f(x2)\n", + " \n", + " # concatenate x1 and x2 along the depth dimension\n", + " x = np.concatenate([x1, x2], axis=-1)\n", + " \n", + " return x\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "RidbAcoR6duP" + }, + "outputs": [], + "source": [ + "def ReformerLM(vocab_size=33000, n_layers=2, mode='train', attention_type=tl.SelfAttention):\n", + " \"\"\"\n", + " Args: \n", + " vocab_size (int): size of the vocabulary\n", + " n_layers (int): number of decoder layers\n", + " mode (string): setting of the model which can be 'train', 'eval', or 'predict' \n", + " attention_type(class): attention class to use \n", + " Returns: \n", + " model (ReformerLM): a reformer language model implemented in Trax\n", + " \"\"\" \n", + "\n", + " # initialize an instance of Trax's ReformerLM class\n", + " model = trax.models.reformer.ReformerLM( \n", + " # set vocab size\n", + " vocab_size=vocab_size,\n", + " # set number of layers\n", + " n_layers=n_layers,\n", + " # set mode\n", + " 
mode=mode,\n", + " # set attention type\n", + " attention_type=attention_type\n", + " )\n", + " \n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Serial[\n", + " ShiftRight(1)\n", + " Embedding_train_512\n", + " Dropout\n", + " PositionalEncoding\n", + " Dup_out2\n", + " ReversibleSerial_in2_out2[\n", + " ReversibleHalfResidualV2_in2_out2[\n", + " Serial[\n", + " LayerNorm\n", + " ]\n", + " SelfAttention\n", + " ]\n", + " ReversibleSwap_in2_out2\n", + " ReversibleHalfResidualV2_in2_out2[\n", + " Serial[\n", + " LayerNorm\n", + " Dense_2048\n", + " Dropout\n", + " FastGelu\n", + " Dense_512\n", + " Dropout\n", + " ]\n", + " ]\n", + " ReversibleSwap_in2_out2\n", + " ReversibleHalfResidualV2_in2_out2[\n", + " Serial[\n", + " LayerNorm\n", + " ]\n", + " SelfAttention\n", + " ]\n", + " ReversibleSwap_in2_out2\n", + " ReversibleHalfResidualV2_in2_out2[\n", + " Serial[\n", + " LayerNorm\n", + " Dense_2048\n", + " Dropout\n", + " FastGelu\n", + " Dense_512\n", + " Dropout\n", + " ]\n", + " ]\n", + " ReversibleSwap_in2_out2\n", + " ]\n", + " Concatenate_in2\n", + " LayerNorm\n", + " Dropout\n", + " Dense_train\n", + " LogSoftmax\n", + "]\n" + ] + } + ], + "source": [ + "# display the model\n", + "temp_model = ReformerLM('train')\n", + "print(str(temp_model))\n", + "\n", + "# free memory\n", + "del temp_model " + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "tQehGhoD4Psl" + }, + "outputs": [], + "source": [ + "\n", + "def training_loop(ReformerLM, train_gen, eval_gen, output_dir = \"./model/\"):\n", + " \"\"\"\n", + " Args:\n", + " ReformerLM: the Reformer language model you are building\n", + " train_gen (generator): train data generator.\n", + " eval_gen (generator): Validation generator. \n", + " output_dir (string): Path to save the model output. 
Defaults to './model/'.\n", + "\n", + " Returns:\n", + " trax.supervised.training.Loop: Training loop for the model.\n", + " \"\"\"\n", + "\n", + " # use the warmup_and_rsqrt_decay learning rate schedule\n", + " lr_schedule = trax.lr.warmup_and_rsqrt_decay(\n", + " n_warmup_steps=1000, max_value=0.01)\n", + "\n", + " \n", + " # define the train task\n", + " train_task = training.TrainTask( \n", + " # labeled data\n", + " labeled_data=train_gen,\n", + " # loss layer\n", + " loss_layer=tl.CrossEntropyLoss(),\n", + " # optimizer\n", + " optimizer=trax.optimizers.Adam(0.01),\n", + " # lr_schedule\n", + " lr_schedule=lr_schedule,\n", + " # n_steps\n", + " n_steps_per_checkpoint=10\n", + " )\n", + "\n", + " # define the eval task\n", + " eval_task = training.EvalTask( \n", + " # labeled data\n", + " labeled_data=eval_gen,\n", + " # metrics\n", + " metrics=[tl.CrossEntropyLoss(), tl.Accuracy()]\n", + " )\n", + "\n", + " loop = training.Loop(ReformerLM(mode='train'),\n", + " train_task,\n", + " eval_tasks=[eval_task],\n", + " output_dir=output_dir)\n", + " return loop" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n" + ] + } + ], + "source": [ + "\n", + "test_loop = training_loop(ReformerLM, train_stream, eval_stream)\n", + "train_task = test_loop._task\n", + "eval_task = test_loop._eval_task\n", + "\n", + "print(train_task)\n", + "print(eval_task)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "# define the `predict_mem_len` and `predict_drop_len` of tl.SelfAttention\n", + "def attention(*args, **kwargs):\n", + " # number of input positions to remember in a cache when doing fast inference. 
\n", + " kwargs['predict_mem_len'] = 120\n", + " # number of input elements to drop once the fast inference input cache fills up.\n", + " kwargs['predict_drop_len'] = 120\n", + " # return the attention layer with the parameters defined above\n", + " return tl.SelfAttention(*args, **kwargs)\n", + "\n", + "# define the model using the ReformerLM function\n", + "model = ReformerLM(\n", + " vocab_size=33000,\n", + " n_layers=6,\n", + " mode='predict',\n", + " attention_type=attention,\n", + ")\n", + "\n", + "# define an input signature so we can initialize our model. shape will be (1, 1) and the data type is int32.\n", + "shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "# initialize from file\n", + "model.init_from_file('chatbot_model1.pkl.gz',\n", + " weights_only=True, input_signature=shape11)\n", + "\n", + "# save the starting state\n", + "STARTING_STATE = model.state" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "def tokenize(sentence, vocab_file, vocab_dir):\n", + " return list(trax.data.tokenize(iter([sentence]), vocab_file=vocab_file, vocab_dir=vocab_dir))[0]\n", + "\n", + "def detokenize(tokens, vocab_file, vocab_dir):\n", + " return trax.data.detokenize(tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def ReformerLM_output_gen(ReformerLM, start_sentence, vocab_file, vocab_dir, temperature):\n", + " \"\"\"\n", + " Args:\n", + " ReformerLM: the Reformer language model you just trained\n", + " start_sentence (string): starting sentence of the conversation\n", + " vocab_file (string): vocabulary filename\n", + " vocab_dir (string): directory of the vocabulary file\n", + " temperature (float): parameter for sampling ranging from 0.0 to 1.0.\n", + " 0.0: 
same as argmax, always pick the most probable token\n", + " 1.0: sampling from the distribution (can sometimes say random things)\n", + "\n", + " Returns:\n", + " generator: yields the next symbol generated by the model\n", + " \"\"\"\n", + "\n", + " \n", + " # Create input tokens using the the tokenize function\n", + " input_tokens = tokenize(start_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir)\n", + " \n", + " # Add batch dimension to array. Convert from (n,) to (x, n) where \n", + " # x is the batch size. Default is 1. \n", + " input_tokens_with_batch = np.array(input_tokens)[None, :]\n", + " \n", + " # call the autoregressive_sample_stream function from trax\n", + " output_gen = trax.supervised.decoding.autoregressive_sample_stream( \n", + " # model\n", + " ReformerLM,\n", + " # inputs will be the tokens with batch dimension\n", + " inputs=input_tokens_with_batch,\n", + " # temperature\n", + " temperature=temperature\n", + " )\n", + "\n", + " \n", + " return output_gen" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)\n", + "\n", + "def attention(*args, **kwargs):\n", + " kwargs['predict_mem_len'] = 120 # max length for predictions\n", + " kwargs['predict_drop_len'] = 120 # never drop old stuff\n", + " return tl.SelfAttention(*args, **kwargs)\n", + "\n", + "model = ReformerLM(\n", + " vocab_size=33000,\n", + " n_layers=6,\n", + " mode='predict',\n", + " attention_type=attention,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "model.init_from_file('chatbot_model1.pkl.gz',\n", + " weights_only=True, input_signature=shape11)\n", + "\n", + "STARTING_STATE = model.state" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_dialogue(ReformerLM, model_state, start_sentence, vocab_file, vocab_dir, 
max_len, temperature):\n",
+    "    \"\"\"\n",
+    "    Args:\n",
+    "        ReformerLM: the Reformer language model you just trained\n",
+    "        model_state (np.array): initial state of the model before decoding\n",
+    "        start_sentence (string): starting sentence of the conversation\n",
+    "        vocab_file (string): vocabulary filename\n",
+    "        vocab_dir (string): directory of the vocabulary file\n",
+    "        max_len (int): maximum number of tokens to generate \n",
+    "        temperature (float): parameter for sampling ranging from 0.0 to 1.0.\n",
+    "            0.0: same as argmax, always pick the most probable token\n",
+    "            1.0: sampling from the distribution (can sometimes say random things)\n",
+    "\n",
+    "    Returns:\n",
+    "        None: prints the generated dialogue turn by turn; has no return value\n",
+    "    \"\"\" \n",
+    "    \n",
+    "    # define the delimiters we used during training\n",
+    "    delimiter_1 = 'Person 1: ' \n",
+    "    delimiter_2 = 'Person 2: '\n",
+    "    \n",
+    "    # initialize detokenized output\n",
+    "    sentence = ''\n",
+    "    \n",
+    "    # token counter\n",
+    "    counter = 0\n",
+    "    \n",
+    "    # output tokens. we insert a ': ' for formatting\n",
+    "    result = [tokenize(': ', vocab_file=vocab_file, vocab_dir=vocab_dir)]\n",
+    "    \n",
+    "    # reset the model state when starting a new dialogue\n",
+    "    ReformerLM.state = model_state\n",
+    "    \n",
+    "    # call the output generator using this function's vocab_file/vocab_dir args (fix: was globals VOCAB_FILE/VOCAB_DIR)\n",
+    "    output = ReformerLM_output_gen(ReformerLM, start_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir, temperature=temperature)\n",
+    "    \n",
+    "    # print the starting sentence\n",
+    "    print(start_sentence.split(delimiter_2)[0].strip())\n",
+    "    \n",
+    "    # loop below yields the next tokens until max_len is reached. 
the if-elif is just for prettifying the output.\n",
+    "    for o in output:\n",
+    "        \n",
+    "        result.append(o)\n",
+    "        \n",
+    "        sentence = detokenize(np.concatenate(result, axis=0), vocab_file=vocab_file, vocab_dir=vocab_dir)\n",
+    "        \n",
+    "        if sentence.endswith(delimiter_1):\n",
+    "            sentence = sentence.split(delimiter_1)[0]\n",
+    "            print(f'{delimiter_2}{sentence}')\n",
+    "            sentence = ''\n",
+    "            result.clear()\n",
+    "        \n",
+    "        elif sentence.endswith(delimiter_2):\n",
+    "            sentence = sentence.split(delimiter_2)[0]\n",
+    "            print(f'{delimiter_1}{sentence}')\n",
+    "            sentence = ''\n",
+    "            result.clear()\n",
+    "\n",
+    "        counter += 1\n",
+    "        \n",
+    "        if counter > max_len:\n",
+    "            break \n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 47,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Person 1: Are there theatres in town?\n",
+      "Person 2: : There are 4 theatres in town. Do you have a preference? \n",
+      "Person 1: Not really, but I would like the one in the south. \n",
+      "Person 2: I have one theatre, the Junction, and the other is Tenpin. \n",
+      "Person 1: Could I get the address and postcode? \n",
+      "Person 1: The postcode is cb17gx, and the address is Clifton Way, Cambridge Leisure Park, Clifton Way. Is there anything else I can i give for? \n"
+     ]
+    }
+   ],
+   "source": [
+    "sample_sentence = ' Person 1: Are there theatres in town? Person 2: '\n",
+    "generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 48,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Person 1: Is there a hospital nearby?\n",
+      "Person 2: : Addensbrookes Hospital is located at Hills Rd, Cambridge, postcode CB20QQ. Do you need a particular department? \n",
+      "Person 1: No, I just need the phone number, please. 
\n", + "Person 2: The phone number is 01223245151. \n", + "Person 1: Thank you. That's all I need. \n", + "Person 2: Thank you for using our services.Goodbye.\n", + "Person 1: Goodbye. \n" + ] + } + ], + "source": [ + "sample_sentence = ' Person 1: Is there a hospital nearby? Person 2: '\n", + "generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Person 1: Can you book a taxi?\n", + "Person 2: : I sure can. When would you like to leave? \n", + "Person 1: I need to leave after 13:00. \n", + "Person 2: I'd be happy to help with your request, first I will need to know your destination. \n", + "Person 1: I'm going to be going to be from the city stop restaurant. \n", + "Person 2: Booking completed! Booked car type\t:\tgrey volkswagen\n", + "Contact number\t:\t07262372\n", + " \n", + "Person 2: Thank bybybybyby\n" + ] + } + ], + "source": [ + "sample_sentence = ' Person 1: Can you book a taxi? 
Person 2: '\n", + "generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)" + ] + } + ], + "metadata": { + "coursera": { + "schema_names": [ + "NLPC4-4" + ] + }, + "jupytext": { + "encoding": "# -*- coding: utf-8 -*-", + "formats": "ipynb,py:percent", + "text_representation": { + "extension": ".py", + "format_name": "percent", + "format_version": "1.3", + "jupytext_version": "1.5.2" + } + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}