Merge pull request #1690 from microsoft/bug/doc
Fixed readthedocs bug and added SASRec and SSEPT documentation
miguelgfierro authored Mar 31, 2022
2 parents 8e632dd + 88a2210 commit 7721857
Showing 6 changed files with 206 additions and 40 deletions.
1 change: 1 addition & 0 deletions .readthedocs.yaml
@@ -6,6 +6,7 @@ build:
- cmake

# Explicitly set the version of Python and its requirements
# The extra_requirements option "all" is equivalent to: pip install .[all]
python:
version: "3.7"
install:
19 changes: 19 additions & 0 deletions docs/source/models.rst
@@ -213,6 +213,25 @@ SAR
.. automodule:: recommenders.models.sar.sar_singlenode
:members:

SASRec
******************************

.. automodule:: recommenders.models.sasrec.model
:members:

.. automodule:: recommenders.models.sasrec.sampler
:members:

.. automodule:: recommenders.models.sasrec.util
:members:


SSE-PT
******************************

.. automodule:: recommenders.models.sasrec.ssept
:members:


Surprise
******************************
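For context, the modules picked up by the new automodule entries can be imported directly once the package is installed; a minimal sanity-check sketch (the assumption is that recommenders was installed with the extras that pull in TensorFlow):

# Minimal sketch: verify the newly documented modules resolve.
from recommenders.models.sasrec import model, sampler, util, ssept

print(model.SASREC)   # the class documented under "SASRec"
print(ssept)          # the module documented under "SSE-PT"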
166 changes: 147 additions & 19 deletions recommenders/models/sasrec/model.py
@@ -1,9 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import random
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import tensorflow as tf

from recommenders.utils.timer import Timer

@@ -16,6 +16,13 @@ class MultiHeadAttention(tf.keras.layers.Layer):
"""

def __init__(self, attention_dim, num_heads, dropout_rate):
"""Initialize parameters.
Args:
attention_dim (int): Dimension of the attention embeddings.
num_heads (int): Number of heads in the multi-head self-attention module.
dropout_rate (float): Dropout probability.
"""
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.attention_dim = attention_dim
@@ -30,6 +37,15 @@ def __init__(self, attention_dim, num_heads, dropout_rate):
self.dropout = tf.keras.layers.Dropout(self.dropout_rate)

def call(self, queries, keys):
"""Model forward pass.
Args:
queries (tf.Tensor): Tensor of queries.
keys (tf.Tensor): Tensor of keys.
Returns:
tf.Tensor: Output tensor.
"""

# Linear projections
Q = self.Q(queries) # (N, T_q, C)
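As a usage illustration of the layer documented above, a hedged sketch follows; the tensor shapes and hyperparameter values are assumptions, chosen so that attention_dim is divisible by num_heads:

import tensorflow as tf
from recommenders.models.sasrec.model import MultiHeadAttention

# Illustrative shapes: batch of 2 sequences of length 50 with 100-dim embeddings.
queries = tf.random.uniform((2, 50, 100))
keys = tf.random.uniform((2, 50, 100))

mha = MultiHeadAttention(attention_dim=100, num_heads=2, dropout_rate=0.2)
output = mha(queries, keys)   # forward pass described by the docstring above
print(output.shape)           # expected (2, 50, 100) under these assumptions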
@@ -108,6 +124,12 @@ class PointWiseFeedForward(tf.keras.layers.Layer):
"""

def __init__(self, conv_dims, dropout_rate):
"""Initialize parameters.
Args:
conv_dims (list): List of the dimensions of the Feedforward layer.
dropout_rate (float): Dropout probability.
"""
super(PointWiseFeedForward, self).__init__()
self.conv_dims = conv_dims
self.dropout_rate = dropout_rate
@@ -120,6 +142,14 @@ def __init__(self, conv_dims, dropout_rate):
self.dropout_layer = tf.keras.layers.Dropout(self.dropout_rate)

def call(self, x):
"""Model forward pass.
Args:
x (tf.Tensor): Input tensor.
Returns:
tf.Tensor: Output tensor.
"""

output = self.conv_layer1(x)
output = self.dropout_layer(output)
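A similar hedged sketch for the feed-forward block documented above; conv_dims and the input shape are illustrative values, not defaults from the library:

import tensorflow as tf
from recommenders.models.sasrec.model import PointWiseFeedForward

x = tf.random.uniform((2, 50, 100))   # (batch, seq_len, embedding_dim), assumed shape
ffn = PointWiseFeedForward(conv_dims=[100, 100], dropout_rate=0.2)
y = ffn(x)                            # the transformation documented above
print(y.shape)                        # expected (2, 50, 100) when conv_dims[-1] matches the input size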
@@ -148,6 +178,16 @@ def __init__(
conv_dims,
dropout_rate,
):
"""Initialize parameters.
Args:
seq_max_len (int): Maximum sequence length.
embedding_dim (int): Embedding dimension.
attention_dim (int): Dimension of the attention embeddings.
num_heads (int): Number of heads in the multi-head self-attention module.
conv_dims (list): List of the dimensions of the Feedforward layer.
dropout_rate (float): Dropout probability.
"""
super(EncoderLayer, self).__init__()

self.seq_max_len = seq_max_len
@@ -167,6 +207,16 @@ def __init__(
)

def call_(self, x, training, mask):
"""Model forward pass.
Args:
x (tf.Tensor): Input tensor.
training (tf.Tensor): Training tensor.
mask (tf.Tensor): Mask tensor.
Returns:
tf.Tensor: Output tensor.
"""

attn_output = self.mha(queries=self.layer_normalization(x), keys=x)
attn_output = self.dropout1(attn_output, training=training)
@@ -185,6 +235,16 @@ def call_(self, x, training, mask):
return out2

def call(self, x, training, mask):
"""Model forward pass.
Args:
x (tf.Tensor): Input tensor.
training (tf.Tensor): Training tensor.
mask (tf.Tensor): Mask tensor.
Returns:
tf.Tensor: Output tensor.
"""

x_norm = self.layer_normalization(x)
attn_output = self.mha(queries=x_norm, keys=x)
@@ -210,6 +270,17 @@ def __init__(
conv_dims,
dropout_rate,
):
"""Initialize parameters.
Args:
num_layers (int): Number of layers.
seq_max_len (int): Maximum sequence length.
embedding_dim (int): Embedding dimension.
attention_dim (int): Dimension of the attention embeddings.
num_heads (int): Number of heads in the multi-head self-attention module.
conv_dims (list): List of the dimensions of the Feedforward layer.
dropout_rate (float): Dropout probability.
"""
super(Encoder, self).__init__()

self.num_layers = num_layers
@@ -229,6 +300,16 @@ def __init__(
self.dropout = tf.keras.layers.Dropout(dropout_rate)

def call(self, x, training, mask):
"""Model forward pass.
Args:
x (tf.Tensor): Input tensor.
training (tf.Tensor): Training tensor.
mask (tf.Tensor): Mask tensor.
Returns:
tf.Tensor: Output tensor.
"""

for i in range(self.num_layers):
x = self.enc_layers[i](x, training, mask)
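To show how the stacked encoder documented above might be driven, here is a hedged sketch; the mask convention (1.0 for real items, 0.0 for padding) and all hyperparameter values are assumptions:

import tensorflow as tf
from recommenders.models.sasrec.model import Encoder

encoder = Encoder(
    num_layers=2,
    seq_max_len=50,
    embedding_dim=100,
    attention_dim=100,
    num_heads=1,
    conv_dims=[100, 100],
    dropout_rate=0.2,
)
x = tf.random.uniform((2, 50, 100))   # (batch, seq_max_len, embedding_dim)
mask = tf.ones((2, 50, 1))            # assumed convention: 1.0 where a real item is present
training = True
out = encoder(x, training, mask)      # applies each EncoderLayer in turn, as in the loop above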
@@ -243,6 +324,13 @@ class LayerNormalization(tf.keras.layers.Layer):
"""

def __init__(self, seq_max_len, embedding_dim, epsilon):
"""Initialize parameters.
Args:
seq_max_len (int): Maximum sequence length.
embedding_dim (int): Embedding dimension.
epsilon (float): Epsilon value.
"""
super(LayerNormalization, self).__init__()
self.seq_max_len = seq_max_len
self.embedding_dim = embedding_dim
@@ -260,6 +348,14 @@ def __init__(self, seq_max_len, embedding_dim, epsilon):
)

def call(self, x):
"""Model forward pass.
Args:
x (tf.Tensor): Input tensor.
Returns:
tf.Tensor: Output tensor.
"""
mean, variance = tf.nn.moments(x, [-1], keepdims=True)
normalized = (x - mean) / ((variance + self.epsilon) ** 0.5)
output = self.gamma * normalized + self.beta
@@ -279,19 +375,22 @@ class SASREC(tf.keras.Model):
Original source code from nnkkmto/SASRec-tf2,
https://github.com/nnkkmto/SASRec-tf2
Args:
item_num: number of items in the dataset
seq_max_len: maximum number of items in user history
num_blocks: number of Transformer blocks to be used
embedding_dim: item embedding dimension
attention_dim: Transformer attention dimension
conv_dims: list of the dimensions of the Feedforward layer
dropout_rate: dropout rate
l2_reg: coefficient of the L2 regularization
num_neg_test: number of negative examples used in testing
"""

def __init__(self, **kwargs):
"""Model initialization.
Args:
item_num (int): Number of items in the dataset.
seq_max_len (int): Maximum number of items in user history.
num_blocks (int): Number of Transformer blocks to be used.
embedding_dim (int): Item embedding dimension.
attention_dim (int): Transformer attention dimension.
conv_dims (list): List of the dimensions of the Feedforward layer.
dropout_rate (float): Dropout rate.
l2_reg (float): Coefficient of the L2 regularization.
num_neg_test (int): Number of negative examples used in testing.
"""
super(SASREC, self).__init__()

self.item_num = kwargs.get("item_num", None)
Expand Down Expand Up @@ -336,6 +435,16 @@ def __init__(self, **kwargs):
)

def embedding(self, input_seq):
"""Compute the sequence and positional embeddings.
Args:
input_seq (tf.Tensor): Input sequence
Returns:
tf.Tensor, tf.Tensor:
- Sequence embeddings.
- Positional embeddings.
"""

seq_embeddings = self.item_embedding_layer(input_seq)
seq_embeddings = seq_embeddings * (self.embedding_dim ** 0.5)
@@ -348,10 +457,17 @@ def embedding(self, input_seq):
return seq_embeddings, positional_embeddings

def call(self, x, training):
"""
Returns the logits of the positive examples,
logits of the negative examples,
mask for nonzero targets
"""Model forward pass.
Args:
x (tf.Tensor): Input tensor.
training (tf.Tensor): Training tensor.
Returns:
tf.Tensor, tf.Tensor, tf.Tensor:
- Logits of the positive examples.
- Logits of the negative examples.
- Mask for nonzero targets.
"""

input_seq = x["input_seq"]
@@ -409,8 +525,13 @@ def call(self, x, training):
return pos_logits, neg_logits, istarget

def predict(self, inputs):
"""
Returns the logits for the test items
"""Returns the logits for the test items.
Args:
inputs (tf.Tensor): Input tensor.
Returns:
tf.Tensor: Output tensor.
"""
training = False
input_seq = inputs["input_seq"]
Expand Down Expand Up @@ -442,10 +563,17 @@ def predict(self, inputs):
return test_logits

def loss_function(self, pos_logits, neg_logits, istarget):
"""
Losses are calculated separately for the positive and negative
"""Losses are calculated separately for the positive and negative
items based on the corresponding logits. A mask is included to
take care of the zero items (added for padding).
Args:
pos_logits (tf.Tensor): Logits of the positive examples.
neg_logits (tf.Tensor): Logits of the negative examples.
istarget (tf.Tensor): Mask for nonzero targets.
Returns:
float: Loss.
"""

pos_logits = pos_logits[:, 0]
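Taken together, the keyword arguments documented in the new constructor docstring would be passed like this; the values below are illustrative, not the library's defaults:

from recommenders.models.sasrec.model import SASREC

model = SASREC(
    item_num=10000,        # number of items in the dataset
    seq_max_len=50,        # maximum number of items in user history
    num_blocks=2,          # number of Transformer blocks
    embedding_dim=100,     # item embedding dimension
    attention_dim=100,     # Transformer attention dimension
    conv_dims=[100, 100],  # dimensions of the feed-forward layer
    dropout_rate=0.2,
    l2_reg=0.00001,
    num_neg_test=100,      # negative examples used in testing
)

Because the constructor reads its arguments through kwargs.get, any argument left out falls back to the default defined in the class.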
1 change: 1 addition & 0 deletions recommenders/models/sasrec/sampler.py
@@ -18,6 +18,7 @@ def sample_function(
):
"""Batch sampler that creates a sequence of negative items based on the
original sequence of items (positive) that the user has interacted with.
Args:
user_train (dict): dictionary of training examples for each user
usernum (int): number of users
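For the arguments named in the docstring above, the expected data layout might look like the following; the item ids are hypothetical, and the remaining parameters of sample_function are truncated in this diff, so they are not shown:

user_train = {
    1: [4, 17, 23, 88],   # hypothetical: items user 1 interacted with, in order
    2: [5, 42, 7],        # item id 0 is typically reserved for padding
}
usernum = len(user_train)  # number of users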
