-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmultilingual_chatbot_arena
1 lines (1 loc) · 10.6 KB
/
multilingual_chatbot_arena
1
{
 "metadata": {
  "kernelspec": {"language": "python", "display_name": "Python 3", "name": "python3"},
  "language_info": {"name": "python", "version": "3.10.12", "mimetype": "text/x-python", "codemirror_mode": {"name": "ipython", "version": 3}, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py"},
  "kaggle": {"accelerator": "gpu", "dataSources": [{"sourceId": 86946, "databundleVersionId": 10131489, "sourceType": "competition"}, {"sourceId": 8897601, "sourceType": "datasetVersion", "datasetId": 5297895}, {"sourceId": 166245, "sourceType": "modelInstanceVersion", "isSourceIdPinned": true, "modelInstanceId": 141458, "modelId": 164048}], "dockerImageVersionId": 30822, "isInternetEnabled": false, "language": "python", "sourceType": "notebook", "isGpuEnabled": true}
 },
 "nbformat_minor": 4,
 "nbformat": 4,
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://www.kaggle.com/code/sabra15/multilingual-chatbot-arena-qwen2-5?scriptVersionId=218291072\" target=\"_blank\"><img align=\"left\" alt=\"Kaggle\" title=\"Open in Kaggle\" src=\"https://kaggle.com/static/images/open-in-kaggle.svg\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multilingual Chatbot Arena: zero-shot Qwen2.5 judge\n",
    "\n",
    "This notebook prompts the Qwen2.5 1.5B-instruct checkpoint to pick the better of two chatbot responses (`model_a` or `model_b`) for every record and writes the picks to `submission.csv`. No weights are updated; the training split is only used to measure how well the untuned model does."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {"trusted": true},
   "source": [
    "%pip install peft -U --no-index --find-links /kaggle/input/lmsys-wheel-files --quiet"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Library Imports"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {"trusted": true},
   "source": [
    "import os\n",
    "\n",
    "# Set before torch touches the GPU so the CUDA caching allocator can pick it up.\n",
    "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
    "\n",
    "import time\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import numpy as np  # linear algebra\n",
    "import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)\n",
    "import polars as pl\n",
    "import torch\n",
    "from peft import PeftModel\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configurations"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {"trusted": true},
   "source": [
    "@dataclass\n",
    "class Config:\n",
    "    \"\"\"Tunable settings shared by the cells below.\n",
    "\n",
    "    Every attribute is annotated so that `@dataclass` registers it as a real\n",
    "    field (un-annotated names are plain class attributes and cannot be\n",
    "    overridden through the constructor).\n",
    "    \"\"\"\n",
    "\n",
    "    # Directory holding the Qwen2.5 checkpoint and tokenizer files.\n",
    "    qwen_dir: str = '/kaggle/input/qwen2.5/transformers/1.5b-instruct/1'\n",
    "    # Directory holding the competition parquet files.\n",
    "    data_dir: str = '/kaggle/input/wsdm-cup-multilingual-chatbot-arena'\n",
    "    # Approximate token budget for one fully templated model input.\n",
    "    max_length: int = 2048\n",
    "    # Use the GPU when one is visible, otherwise fall back to the CPU.\n",
    "    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    # Upper bound on the number of tokens generated per answer.\n",
    "    max_new_tokens: int = 5\n",
    "    # Number of leading training rows scored as a sanity check (0 skips it).\n",
    "    train_sample_size: int = 10000\n",
    "    # Label used when generation fails or names neither model.\n",
    "    default_winner: str = 'model_a'\n",
    "\n",
    "cfg = Config()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {"trusted": true},
   "source": [
    "def load_data(file_path):\n",
    "    \"\"\"Read the parquet file at `file_path` and return it as a pandas DataFrame.\"\"\"\n",
    "    return pl.read_parquet(file_path).to_pandas()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preprocess Data"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {"trusted": true},
   "source": [
    "TEXT_COLUMNS = ('prompt', 'response_a', 'response_b')\n",
    "\n",
    "\n",
    "def preprocess_data(df):\n",
    "    \"\"\"Return a new frame in which missing text values are replaced by ''.\n",
    "\n",
    "    The caller's frame is not modified, so the function is safe to call on a\n",
    "    slice such as `df.head(n)` and safe to re-run.\n",
    "    \"\"\"\n",
    "    return df.assign(**{col: df[col].fillna('') for col in TEXT_COLUMNS})"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initialize Model and Tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {"trusted": true},
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(cfg.qwen_dir)\n",
    "model = AutoModelForCausalLM.from_pretrained(cfg.qwen_dir)\n",
    "# Move the weights once here instead of on every inference call.\n",
    "model.to(cfg.device)\n",
    "model.eval();"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tokenise Data"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {"trusted": true},
   "source": [
    "instruction = \"\"\"In the text provided for you below, PROMPT is the question presented; MODEL_A is the response from the first model; MODEL_B is the response from the second model. Please select the best answer from the two responses provided. If the first answer is better, return \"model_a\"; if the second answer is better, return \"model_b\".\"\"\"\n",
    "\n",
    "SECTION_LABELS = ('PROMPT: ', 'MODEL_A: ', 'MODEL_B: ')\n",
    "# Blank line between sections so one section cannot run into the next label.\n",
    "SECTION_SEPARATOR = '\\n\\n'\n",
    "\n",
    "\n",
    "def _render_chat(user_text):\n",
    "    \"\"\"Wrap `user_text` and the system instruction in the tokenizer's chat template.\"\"\"\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": instruction},\n",
    "        {\"role\": \"user\", \"content\": user_text}\n",
    "    ]\n",
    "    return tokenizer.apply_chat_template(\n",
    "        messages,\n",
    "        tokenize=False,\n",
    "        add_generation_prompt=True\n",
    "    )\n",
    "\n",
    "\n",
    "def _clip_to_tokens(text, max_tokens):\n",
    "    \"\"\"Return `text` unchanged if it fits in `max_tokens` tokens, else its leading tokens decoded back to text.\"\"\"\n",
    "    token_ids = tokenizer(text, add_special_tokens=False)['input_ids']\n",
    "    if len(token_ids) <= max_tokens:\n",
    "        return text\n",
    "    return tokenizer.decode(token_ids[:max_tokens])\n",
    "\n",
    "\n",
    "def tokenize_data(df):\n",
    "    \"\"\"Build one tokenised chat prompt per row of `df`.\n",
    "\n",
    "    Reads the `id`, `prompt`, `response_a` and `response_b` columns.\n",
    "\n",
    "    Returns `(tokenised_data, ids)`: a list with one tokenizer output per row\n",
    "    (tensors are left on the CPU) and the list of matching `id` values.\n",
    "    \"\"\"\n",
    "    # Tokens taken up by the chat template, the instruction and the labels.\n",
    "    overhead = len(tokenizer(_render_chat(SECTION_SEPARATOR.join(SECTION_LABELS)))['input_ids'])\n",
    "    # Share the rest of cfg.max_length evenly. Clipping each section on its own\n",
    "    # (rather than truncating the finished prompt) keeps all three sections and\n",
    "    # the trailing generation prompt present in every input.\n",
    "    section_budget = max((cfg.max_length - overhead) // len(SECTION_LABELS), 1)\n",
    "\n",
    "    tokenised_data = []\n",
    "    ids = []\n",
    "\n",
    "    rows = zip(df['id'], df['prompt'], df['response_a'], df['response_b'])\n",
    "    for rec_id, *sections in rows:\n",
    "        text = SECTION_SEPARATOR.join(\n",
    "            label + _clip_to_tokens(section, section_budget)\n",
    "            for label, section in zip(SECTION_LABELS, sections)\n",
    "        )\n",
    "\n",
    "        tokenised_datum = tokenizer(\n",
    "            [_render_chat(text)],\n",
    "            return_tensors=\"pt\"\n",
    "        )\n",
    "\n",
    "        tokenised_data.append(tokenised_datum)\n",
    "        ids.append(rec_id)\n",
    "\n",
    "    return tokenised_data, ids"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Method to Make Prediction"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {"trusted": true},
   "source": [
    "VALID_WINNERS = ('model_a', 'model_b')\n",
    "\n",
    "\n",
    "def parse_winner(response, default=None):\n",
    "    \"\"\"Map generated text onto one of the two submission labels.\n",
    "\n",
    "    Returns whichever of 'model_a' / 'model_b' occurs first in `response`\n",
    "    (case-insensitive). If neither occurs, returns `default`, which itself\n",
    "    defaults to `cfg.default_winner`.\n",
    "    \"\"\"\n",
    "    if default is None:\n",
    "        default = cfg.default_winner\n",
    "    text = response.lower()\n",
    "    hits = [(text.find(label), label) for label in VALID_WINNERS if label in text]\n",
    "    return min(hits)[1] if hits else default\n",
    "\n",
    "\n",
    "def inference(tokenised_datum):\n",
    "    \"\"\"Generate the model's pick for one tokenised prompt.\n",
    "\n",
    "    Returns 'model_a' or 'model_b' (see `parse_winner`).\n",
    "    \"\"\"\n",
    "    # Send copies of the tensors to the device; the stored encoding is left as is.\n",
    "    input_ids = tokenised_datum.input_ids.to(cfg.device)\n",
    "    attention_mask = tokenised_datum.attention_mask.to(cfg.device)\n",
    "\n",
    "    generated_ids = model.generate(\n",
    "        input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        max_new_tokens=cfg.max_new_tokens,\n",
    "        # Greedy decoding, so the same prompt always gives the same label.\n",
    "        do_sample=False\n",
    "    )\n",
    "\n",
    "    # generate() returns prompt + continuation; keep only the continuation.\n",
    "    new_token_ids = generated_ids[:, input_ids.shape[1]:]\n",
    "    response = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]\n",
    "\n",
    "    return parse_winner(response)\n",
    "\n",
    "\n",
    "def predict_all(tokenised_data):\n",
    "    \"\"\"Run `inference` on every item and return the list of labels, in order.\n",
    "\n",
    "    A record whose inference raises is reported and given `cfg.default_winner`,\n",
    "    so a single failure does not abort the whole run.\n",
    "    \"\"\"\n",
    "    predictions = []\n",
    "    for tokenised_datum in tokenised_data:\n",
    "        try:\n",
    "            predictions.append(inference(tokenised_datum))\n",
    "        except Exception as e:\n",
    "            print(f\"An error occurred: {e}\")\n",
    "            predictions.append(cfg.default_winner)\n",
    "    return predictions"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate on a Training Sample\n",
    "\n",
    "Nothing is trained here: the untuned model is run on the first `cfg.train_sample_size` training rows as a sanity check."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {"trusted": true},
   "source": [
    "train = preprocess_data(load_data(f'{cfg.data_dir}/train.parquet').head(cfg.train_sample_size))\n",
    "tokenised_data_train, ids_train = tokenize_data(train)\n",
    "predictions_train = predict_all(tokenised_data_train)\n",
    "\n",
    "# Score the picks only if this split actually carries a `winner` column.\n",
    "if predictions_train and 'winner' in train.columns:\n",
    "    accuracy_train = train['winner'].eq(pd.Series(predictions_train, index=train.index)).mean()\n",
    "    print(f\"Zero-shot accuracy on {len(train)} training rows: {accuracy_train:.4f}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make Prediction"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {"trusted": true},
   "source": [
    "test = preprocess_data(load_data(f'{cfg.data_dir}/test.parquet'))\n",
    "tokenised_data, ids = tokenize_data(test)\n",
    "predictions = predict_all(tokenised_data)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Write Submission"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {"trusted": true},
   "source": [
    "submission = pd.DataFrame({\n",
    "    'id': ids,\n",
    "    'winner': predictions\n",
    "})\n",
    "assert len(submission) == len(test), \"expected exactly one prediction per test row\"\n",
    "assert submission['winner'].isin(VALID_WINNERS).all(), \"winner must be model_a or model_b\"\n",
    "submission.to_csv(\"submission.csv\",index=False)\n",
    "submission.head()"
   ],
   "outputs": [],
   "execution_count": null
  }
 ]
}