Merge pull request #1678 from microsoft/staging_abir_ssept
Added ssept model path & citation
anargyri authored Mar 25, 2022
2 parents 93c9df2 + 19edc07 commit a3627e7
Showing 3 changed files with 159 additions and 68 deletions.
113 changes: 55 additions & 58 deletions examples/00_quick_start/sasrec_amazon.ipynb
@@ -9,13 +9,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# SASRec\n",
"# SASRec & SSEPT\n",
"\n",
"### Self-Attentive Sequential Recommendation Using Transformer \\[1\\]\n",
"### Sequential Recommendation Using Transformer \\[1, 6\\] \n",
"\n",
"![image.png](attachment:image.png)\n",
"\n",
"This is a class of sequential recommendation that uses Transformer \\[2\\] for encoding the users preference represented in terms of a sequence of items purchased/viewed before. Instead of using CNN (Caser \\[3\\]) or RNN (GRU4Rec \\[4\\], SLI-Rec \\[5\\] etc.) the approach relies on Transformer based encoder that generates a new representation of the item sequence. This notebook provides an example of necessary steps to train and test a SASRec model. "
"This is a class of sequential recommendation that uses Transformer \\[2\\] for encoding the users preference represented in terms of a sequence of items purchased/viewed before. Instead of using CNN (Caser \\[3\\]) or RNN (GRU4Rec \\[4\\], SLI-Rec \\[5\\] etc.) the approach relies on Transformer based encoder that generates a new representation of the item sequence. Two variants of this Transformer based approaches are included here, \n",
"\n",
"- Self-Attentive Sequential Recommendation (or SASRec [1]) that is based on vanilla Transformer and models only the item sequence and\n",
"- Stochastic Shared Embedding based Personalized Transformer or SSE-PT [6], that also models the users along with the items. \n",
"\n",
"This notebook provides an example of necessary steps to train and test either a SASRec or a SSE-PT model. "
]
},
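
Before diving into the notebook cells, a minimal NumPy sketch of how the two variants build their Transformer inputs may help. This is an illustration, not the library implementation: the sizes (100-dimensional item embeddings, 10-dimensional user embeddings, matching the model cell later in this notebook) and the SSE replacement probability of 0.1 are assumptions for the example. SASRec feeds the encoder item embeddings plus positional embeddings; SSE-PT concatenates a user embedding onto every step and regularizes it with Stochastic Shared Embeddings, i.e. occasionally swapping in a different user's embedding during training.

```python
import numpy as np

rng = np.random.default_rng(0)

num_items, num_users, maxlen = 1000, 500, 50
d_item, d_user = 100, 10                               # embedding sizes (assumed)

item_emb = rng.normal(size=(num_items + 1, d_item))    # row 0 reserved for padding
user_emb = rng.normal(size=(num_users + 1, d_user))
pos_sasrec = rng.normal(size=(maxlen, d_item))
pos_ssept = rng.normal(size=(maxlen, d_item + d_user))

seq = rng.integers(1, num_items + 1, size=maxlen)      # one user's item history
u = 42                                                 # that user's id

# SASRec: item embedding + positional embedding -> (maxlen, 100)
x_sasrec = item_emb[seq] + pos_sasrec

# SSE-PT: with some probability, replace the user embedding by a random
# user's embedding (Stochastic Shared Embeddings), then concatenate it to
# every step -> the encoder now runs on 100 + 10 = 110 dimensions.
sse_prob = 0.1                                         # illustrative SSE rate
if rng.random() < sse_prob:
    u = int(rng.integers(1, num_users + 1))
x_ssept = np.concatenate(
    [item_emb[seq], np.tile(user_emb[u], (maxlen, 1))], axis=1
) + pos_ssept                                          # (maxlen, 110)
```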
{
@@ -30,16 +35,24 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-22 10:24:04.916252: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/intel/compilers_and_libraries_2018.1.163/linux/tbb/lib/intel64_lin/gcc4.7:/opt/intel/compilers_and_libraries_2018.1.163/linux/compiler/lib/intel64_lin:/opt/intel/compilers_and_libraries_2018.1.163/linux/mkl/lib/intel64_lin::/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64/:/opt/gurobi902/linux64/lib:/opt/gurobi902/linux64/lib\n",
"2022-03-22 10:24:04.916292: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"System version: 3.7.11 (default, Jul 27 2021, 14:32:16) \n",
"[GCC 7.5.0]\n",
"Tensorflow version: 2.7.1\n"
"Tensorflow version: 2.8.0\n"
]
}
],
@@ -58,10 +71,11 @@
"\n",
"from recommenders.utils.timer import Timer\n",
"from recommenders.datasets.amazon_reviews import get_review_data\n",
"from recommenders.datasets.split_utils import min_rating_filter_pandas\n",
"from recommenders.datasets.split_utils import filter_k_core\n",
"\n",
"# Transformer Based Models\n",
"from recommenders.models.sasrec.model import SASREC\n",
"from recommenders.models.sasrec.ssept import SSEPT\n",
"\n",
"# Sampler for sequential prediction\n",
"from recommenders.models.sasrec.sampler import WarpSampler\n",
@@ -80,7 +94,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 3,
"metadata": {
"tags": [
"parameters"
@@ -105,16 +119,16 @@
"num_heads = 1 # number of attention heads\n",
"dropout_rate = 0.1 # dropout rate\n",
"l2_emb = 0.0 # L2 regularization coefficient\n",
"num_neg_test = 100 # number of negative examples per positive example"
"num_neg_test = 100 # number of negative examples per positive example\n",
"model_name = 'ssept' # 'sasrec' or 'ssept'"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"model_name = 'sasrec' # 'sasrec' or 'ssept'\n",
"reviews_name = dataset + '.json'\n",
"outfile = dataset + '.txt'\n",
"\n",
@@ -127,42 +141,22 @@
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def filter_K_core(data, core_num=0, col_user=\"userID\", col_item=\"itemID\"):\n",
" \"\"\"Filter rating dataframe for minimum number of users and items by \n",
" repeatedly applying min_rating_filter until the condition is satisfied. \n",
" \n",
" \"\"\"\n",
" num_users, num_items = len(data[col_user].unique()), len(data[col_item].unique())\n",
" print(f\"Original: {num_users} users and {num_items} items\")\n",
" df = data.copy()\n",
"\n",
" if core_num > 0:\n",
" while True:\n",
" df = min_rating_filter_pandas(df, min_rating=core_num, filter_by=\"item\")\n",
" df = min_rating_filter_pandas(df, min_rating=core_num, filter_by=\"user\")\n",
" count_u = df.groupby(col_user)[col_item].count()\n",
" count_i = df.groupby(col_item)[col_user].count()\n",
" if len(count_i[count_i < core_num]) == 0 and len(count_u[count_u < core_num]) == 0:\n",
" break\n",
" df = df.sort_values(by=[col_user])\n",
" print(f\"Final: {len(df[col_user].unique())} users and {len(df[col_item].unique())} items\")\n",
"\n",
" return df"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 16,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original: 192403 users and 63001 items\n",
"Final: 20247 users and 11589 items\n"
]
}
],
"source": [
"if not os.path.exists(os.path.join(data_dir, outfile)):\n",
" df = pd.read_csv(reviews_output, sep=\"\\t\", names=[\"userID\", \"itemID\", \"time\"])\n",
" df = filter_K_core(df, 10)\n",
" df = filter_k_core(df, 10) # filter for users & items with less than 10 interactions\n",
" \n",
" user_set, item_set = set(df['userID'].unique()), set(df['itemID'].unique())\n",
" user_map = dict()\n",
@@ -226,16 +220,16 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"../../tests/resources/deeprec/sasrec/reviews_Electronics_5.txt\n",
"36262 Users and 35074 items\n",
"average sequence length: 15.01\n"
"20247 Users and 11589 items\n",
"average sequence length: 15.16\n"
]
}
],
@@ -264,7 +258,8 @@
"source": [
"### Model Creation\n",
"\n",
"Model parameters are \n",
"Model parameters are\n",
"\n",
" - number of items\n",
" - maximum sequence length of the user interaction history\n",
" - number of Transformer blocks\n",
@@ -278,7 +273,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -300,12 +295,12 @@
" seq_max_len=maxlen,\n",
" num_blocks=num_blocks,\n",
" # embedding_dim=hidden_units, # optional\n",
" user_embedding_dim=hidden_units,\n",
" user_embedding_dim=10,\n",
" item_embedding_dim=hidden_units,\n",
" attention_dim=hidden_units,\n",
" attention_num_heads=num_heads,\n",
" dropout_rate=dropout_rate,\n",
" conv_dims = [200, 200],\n",
" conv_dims = [110, 110],\n",
" l2_reg=l2_emb,\n",
" num_neg_test=num_neg_test\n",
" )\n",
@@ -326,7 +321,7 @@
},
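
A note on the SSE-PT hyperparameters in the cell above: the user embedding is deliberately small (user_embedding_dim=10) while the item embedding keeps the full hidden size (item_embedding_dim=hidden_units, presumably 100 as in the SASRec defaults). Since SSE-PT concatenates the two, the Transformer runs on 10 + 100 = 110-dimensional vectors, which would explain why this commit also changes the point-wise feed-forward widths from conv_dims = [200, 200] to conv_dims = [110, 110].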
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -348,7 +343,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -363,8 +358,8 @@
"output_type": "stream",
"text": [
"\n",
"epoch: 5, test (NDCG@10: 0.3410638795485906, HR@10: 0.5421)\n",
"Time cost for training is 7.36 mins\n"
"epoch: 5, test (NDCG@10: 0.3099896446332482, HR@10: 0.5142)\n",
"Time cost for training is 7.17 mins\n"
]
},
{
@@ -384,14 +379,14 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'ndcg@10': 0.3410638795485906, 'Hit@10': 0.5421}\n"
"{'ndcg@10': 0.3037326157112286, 'Hit@10': 0.5036}\n"
]
}
],
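
For context on how these numbers are produced: the SASRec evaluation protocol ranks each user's held-out item against num_neg_test = 100 sampled negatives; HR@10 is the fraction of users whose true item lands in the top 10, and NDCG@10 additionally discounts it by position. A minimal sketch of that computation, assuming `ranks` holds the 0-based rank of the true item among the 101 candidates (this mirrors, but is not, the library's evaluate function):

```python
import numpy as np

def hit_and_ndcg_at_10(ranks):
    """ranks: 0-based rank of each user's true item among 1 positive + 100 negatives."""
    ranks = np.asarray(ranks)
    hits = ranks < 10
    hr_at_10 = hits.mean()
    # A single relevant item at 0-based position r contributes 1 / log2(r + 2),
    # and the ideal DCG is 1, so this is already the NDCG.
    ndcg_at_10 = np.where(hits, 1.0 / np.log2(ranks + 2), 0.0).mean()
    return hr_at_10, ndcg_at_10

# Toy usage: true item ranked 0th, 4th, and 57th for three users.
print(hit_and_ndcg_at_10([0, 4, 57]))  # HR@10 = 2/3, NDCG@10 ≈ 0.46
```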
@@ -463,7 +458,9 @@
"\n",
"\\[4\\] Balázs Hidasi, Alexandros Karatzoglou, Linas Baltrunas, and Domonkos Tikk. 2015. Session-based recommendations with recurrent neural networks. arXiv preprint arXiv:1511.06939 (2015)\n",
"\n",
"\\[5\\] Zeping Yu, Jianxun Lian, Ahmad Mahmoody, Gongshen Liu, Xing Xie. Adaptive User Modeling with Long and Short-Term Preferences for Personailzed Recommendation. In Proceedings of the 28th International Joint Conferences on Artificial Intelligence, IJCAI’19, Pages 4213-4219. AAAI Press, 2019."
"\\[5\\] Zeping Yu, Jianxun Lian, Ahmad Mahmoody, Gongshen Liu, Xing Xie. Adaptive User Modeling with Long and Short-Term Preferences for Personailzed Recommendation. In Proceedings of the 28th International Joint Conferences on Artificial Intelligence, IJCAI’19, Pages 4213-4219. AAAI Press, 2019.\n",
"\n",
"\\[6\\] Liwei Wu, Shuqing Li, Cho-Jui Hsieh, James Sharpnack. SSE-PT: Sequential Recommendation Via Personalized Transformer. In Fourteenth ACM Conference on Recommender Systems, RecSys'20:, Pages 328–337, 2020."
]
}
],
@@ -473,9 +470,9 @@
"hash": "adf311e09e3d70e4b770d653e66a69805c21f44d471e9851e226c4ddc6ad9826"
},
"kernelspec": {
"display_name": "Python (reco-new)",
"display_name": "reco_gpu",
"language": "python",
"name": "python3"
"name": "reco_gpu"
},
"language_info": {
"codemirror_mode": {
35 changes: 35 additions & 0 deletions recommenders/datasets/split_utils.py
Expand Up @@ -3,9 +3,12 @@

import numpy as np
import math
import logging

from recommenders.utils.constants import DEFAULT_ITEM_COL, DEFAULT_USER_COL

logger = logging.getLogger(__name__)

try:
from pyspark.sql import functions as F, Window
except ImportError:
@@ -163,3 +166,35 @@ def split_pandas_data_with_ratios(data, ratios, seed=42, shuffle=False):
splits[i]["split_index"] = i

return splits


def filter_k_core(data, core_num=0, col_user="userID", col_item="itemID"):
"""Filter rating dataframe for minimum number of users and items by
repeatedly applying min_rating_filter until the condition is satisfied.
"""
num_users, num_items = len(data[col_user].unique()), len(data[col_item].unique())
logger.info("Original: %d users and %d items", num_users, num_items)
df_inp = data.copy()

if core_num > 0:
while True:
df_inp = min_rating_filter_pandas(
df_inp, min_rating=core_num, filter_by="item"
)
df_inp = min_rating_filter_pandas(
df_inp, min_rating=core_num, filter_by="user"
)
count_u = df_inp.groupby(col_user)[col_item].count()
count_i = df_inp.groupby(col_item)[col_user].count()
if (
len(count_i[count_i < core_num]) == 0
and len(count_u[count_u < core_num]) == 0
):
break
df_inp = df_inp.sort_values(by=[col_user])
num_users = len(df_inp[col_user].unique())
num_items = len(df_inp[col_item].unique())
logger.info("Final: %d users and %d items", num_users, num_items)

return df_inp
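
A quick usage sketch of the new helper on a toy dataframe (the data and the core_num value are made up for illustration; column names follow the function defaults):

```python
import pandas as pd
from recommenders.datasets.split_utils import filter_k_core

ratings = pd.DataFrame({
    "userID": [1, 1, 1, 2, 2, 3],
    "itemID": [10, 11, 12, 10, 11, 99],
})

# Keep only users and items with at least 2 interactions each.
# Items 12 and 99 are dropped first, which leaves user 3 with no rows,
# so the loop repeats until both conditions hold simultaneously.
dense = filter_k_core(ratings, core_num=2)
# -> interactions of users {1, 2} on items {10, 11}
```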