Merge pull request #12 from remigenet/beta
Beta
remigenet authored Dec 27, 2024
2 parents 0ef3fbb + 3363ab2 commit 859fb72
Showing 3 changed files with 116 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tkat_ci.yml
@@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
12 changes: 7 additions & 5 deletions pyproject.toml
@@ -4,17 +4,17 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "tkat"
version = "0.2.0"
version = "0.2.3"
description = "Temporal KAN Transformer"
authors = [ "Rémi Genet", "Hugo Inzirillo"]
readme = "README.md"
packages = [{include = "tkat"}]

[tool.poetry.dependencies]
python = ">=3.9,<3.12"
python = ">=3.9,<3.13"
keras = ">=3.0.0,<4.0"
keras_efficient_kan = "^0.1.4"
tkan = "^0.4.1"
keras_efficient_kan = "^0.1.9"
tkan = "^0.4.3"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
@@ -30,5 +30,7 @@ testpaths = ["tests"]
filterwarnings = [
"ignore:Can't initialize NVML:UserWarning",
"ignore:jax.xla_computation is deprecated:DeprecationWarning",
"ignore::DeprecationWarning:jax._src.dtypes"
"ignore::DeprecationWarning:jax._src.dtypes",
"ignore:Type google._upb._message.MessageMapContainer uses PyType_Spec with a metaclass that has custom tp_new:DeprecationWarning:importlib",
"ignore:Type google._upb._message.ScalarMapContainer uses PyType_Spec with a metaclass that has custom tp_new:DeprecationWarning:importlib",
]
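
The loosened Python bound and the bumped keras_efficient_kan / tkan pins can be sanity-checked in an installed environment. A minimal sketch, not part of this commit, assuming the distribution names resolve exactly as written in pyproject.toml:

# Sketch only: confirm the running interpreter and installed packages
# satisfy the constraints declared above (distribution names assumed).
import sys
from importlib.metadata import version

assert (3, 9) <= sys.version_info[:2] <= (3, 12), "tkat targets Python >=3.9,<3.13"

for dist, minimum in [("keras", "3.0.0"), ("keras_efficient_kan", "0.1.9"), ("tkan", "0.4.3")]:
    print(f"{dist}: installed {version(dist)}, requires >= {minimum}")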
129 changes: 108 additions & 21 deletions tkat/tkat.py
@@ -1,14 +1,33 @@


import os
BACKEND = 'jax'
os.environ['KERAS_BACKEND'] = BACKEND

import pytest
import keras
from keras import ops
from keras import backend
from keras import random



import keras
from keras import ops
from keras import Model, Input
from keras.layers import Layer, LSTM, Dense, Input, Add, LayerNormalization, Multiply, Reshape, Activation, TimeDistributed, Flatten, Lambda, MultiHeadAttention, Concatenate
from tkan import TKAN

@keras.utils.register_keras_serializable(name="AddAndNorm")
class AddAndNorm(Layer):
def __init__(self, **kwargs):
super(AddAndNorm, self).__init__(**kwargs)

def build(self, input_shape):
self.add_layer = Add()
self.add_layer.build(input_shape)
self.norm_layer = LayerNormalization()
self.norm_layer.build(self.add_layer.compute_output_shape(input_shape))

def call(self, inputs):
tmp = self.add_layer(inputs)
@@ -18,6 +37,12 @@ def call(self, inputs):
def compute_output_shape(self, input_shape):
return input_shape[0] # Assuming all input shapes are the same

def get_config(self):
config = super().get_config()
return config


@keras.utils.register_keras_serializable(name="GRN")
class Gate(Layer):
def __init__(self, hidden_layer_size = None, **kwargs):
super(Gate, self).__init__(**kwargs)
@@ -29,18 +54,26 @@ def build(self, input_shape):
self.hidden_layer_size = input_shape[-1]
self.dense_layer = Dense(self.hidden_layer_size)
self.gated_layer = Dense(self.hidden_layer_size, activation='sigmoid')
self.multiply = Multiply()
super(Gate, self).build(input_shape)
self.dense_layer.build(input_shape)
self.gated_layer.build(input_shape)

def call(self, inputs):
dense_output = self.dense_layer(inputs)
gated_output = self.gated_layer(inputs)
return self.multiply([dense_output, gated_output])
return ops.multiply(dense_output, gated_output)

def compute_output_shape(self, input_shape):
return input_shape[:-1] + (self.hidden_layer_size,)

def get_config(self):
config = super().get_config()
config.update({
'hidden_layer_size': self.hidden_layer_size,
})
return config


@keras.utils.register_keras_serializable(name="GRN")
class GRN(Layer):
def __init__(self, hidden_layer_size, output_size=None, **kwargs):
super(GRN, self).__init__(**kwargs)
@@ -51,19 +84,19 @@ def build(self, input_shape):
if self.output_size is None:
self.output_size = self.hidden_layer_size
self.skip_layer = Dense(self.output_size)
self.skip_layer.build(input_shape)

self.hidden_layer_1 = Dense(self.hidden_layer_size, activation='elu')
self.hidden_layer_1.build(input_shape)
self.hidden_layer_2 = Dense(self.hidden_layer_size)
self.hidden_layer_2.build((*input_shape[:2], self.hidden_layer_size))
self.gate_layer = Gate(self.output_size)
self.gate_layer.build((*input_shape[:2], self.hidden_layer_size))
self.add_and_norm_layer = AddAndNorm()
super(GRN, self).build(input_shape)
self.add_and_norm_layer.build([(*input_shape[:2], self.output_size),(*input_shape[:2], self.output_size)])

def call(self, inputs):
if self.skip_layer is None:
skip = inputs
else:
skip = self.skip_layer(inputs)

skip = self.skip_layer(inputs)
hidden = self.hidden_layer_1(inputs)
hidden = self.hidden_layer_2(hidden)
gating_output = self.gate_layer(hidden)
@@ -72,26 +105,41 @@ def call(self, inputs):
def compute_output_shape(self, input_shape):
return input_shape[:-1] + (self.output_size,)

def get_config(self):
config = super().get_config()
config.update({
'hidden_layer_size': self.hidden_layer_size,
'output_size': self.output_size,
})
return config


@keras.utils.register_keras_serializable(name="VariableSelectionNetwork")
class VariableSelectionNetwork(Layer):
def __init__(self, num_hidden, **kwargs):
super(VariableSelectionNetwork, self).__init__(**kwargs)
self.num_hidden = num_hidden

def build(self, input_shape):
_, time_steps, embedding_dim, num_inputs = input_shape
batch_size, time_steps, embedding_dim, num_inputs = input_shape
self.softmax = Activation('softmax')
self.num_inputs = num_inputs
self.flatten_dim = time_steps * embedding_dim * num_inputs
self.reshape_layer = Reshape(target_shape=[time_steps, embedding_dim * num_inputs])
self.reshape_layer.build(input_shape)
self.mlp_dense = GRN(hidden_layer_size = self.num_hidden, output_size=num_inputs)
self.mlp_dense.build((batch_size, time_steps, embedding_dim * num_inputs))
self.grn_layers = [GRN(self.num_hidden) for _ in range(num_inputs)]
for i in range(num_inputs):
self.grn_layers[i].build(input_shape[:3])
super(VariableSelectionNetwork, self).build(input_shape)

def call(self, inputs):
_, time_steps, embedding_dim, num_inputs = inputs.shape
flatten = self.reshape_layer(inputs)
# Variable selection weights
mlp_outputs = self.mlp_dense(flatten)
sparse_weights = Activation('softmax')(mlp_outputs)
sparse_weights = ops.softmax(mlp_outputs)
sparse_weights = ops.expand_dims(sparse_weights, axis=2)

# Non-linear Processing & weight application
@@ -101,20 +149,20 @@ def call(self, inputs):
trans_emb_list.append(grn_output)

transformed_embedding = ops.stack(trans_emb_list, axis=-1)
combined = Multiply()([sparse_weights, transformed_embedding])
combined = ops.multiply(sparse_weights, transformed_embedding)
temporal_ctx = ops.sum(combined, axis=-1)

return temporal_ctx

class RecurrentLayer(Layer):
def __init__(self, num_units, return_state=False, use_tkan=False, **kwargs):
super(RecurrentLayer, self).__init__(**kwargs)
layer_cls = TKAN if use_tkan else LSTM
self.layer = layer_cls(num_units, return_sequences=True, return_state=return_state)
def get_config(self):
config = super().get_config()
config.update({
'num_hidden': self.num_hidden,
})
return config

def call(self, inputs, initial_state=None):
return self.layer(inputs, initial_state = initial_state)

@keras.utils.register_keras_serializable(name="EmbeddingLayer")
class EmbeddingLayer(Layer):
def __init__(self, num_hidden, **kwargs):
super(EmbeddingLayer, self).__init__(**kwargs)
@@ -124,6 +172,8 @@ def build(self, input_shape):
self.dense_layers = [
Dense(self.num_hidden) for _ in range(input_shape[-1])
]
for i in range(input_shape[-1]):
self.dense_layers[i].build((*input_shape[:2], 1))
super(EmbeddingLayer, self).build(input_shape)

def call(self, inputs):
@@ -133,6 +183,42 @@ def call(self, inputs):
def compute_output_shape(self, input_shape):
return input_shape[:-1] + (self.num_hidden, input_shape[-1])

def get_config(self):
config = super().get_config()
config.update({
'num_hidden': self.num_hidden,
})
return config

@keras.utils.register_keras_serializable(name="RecurrentLayer")
class RecurrentLayer(Layer):
def __init__(self, num_units, return_state=False, use_tkan=False, **kwargs):
super(RecurrentLayer, self).__init__(**kwargs)
self.num_units = num_units
self.return_state = return_state
self.use_tkan = use_tkan

def build(self, input_shape):
layer_cls = TKAN if self.use_tkan else LSTM
self.layer = layer_cls(self.num_units, return_state=self.return_state, return_sequences=True)
self.layer.build(input_shape)
super(RecurrentLayer, self).build(input_shape)

def call(self, inputs, initial_state=None):
return self.layer(inputs, initial_state = initial_state)

def compute_output_shape(self, input_shape):
return self.layer.compute_output_shape(input_shape)

def get_config(self):
config = super().get_config()
config.update({
'num_units': self.num_units,
'return_state': self.return_state,
'use_tkan': self.use_tkan,
})
return config


def TKAT(sequence_length: int, num_unknow_features: int, num_know_features: int, num_embedding: int, num_hidden: int, num_heads: int, n_ahead: int, use_tkan: bool = True):
"""Temporal Kan Transformer model
@@ -154,7 +240,7 @@ def TKAT(sequence_length: int, num_unknow_features: int, num_know_features: int,
inputs = Input(shape=(sequence_length+n_ahead, num_unknow_features + num_know_features))

embedded_inputs = EmbeddingLayer(num_embedding, name = 'embedding_layer')(inputs)

past_features = Lambda(lambda x: x[:, :sequence_length, :, :], name='past_observed_and_known')(embedded_inputs)
variable_selection_past = VariableSelectionNetwork(num_hidden, name='vsn_past_features')(past_features)

@@ -164,7 +250,7 @@ def TKAT(sequence_length: int, num_unknow_features: int, num_know_features: int,
#recurrent encoder-decoder
encode_out, *encode_states = RecurrentLayer(num_hidden, return_state = True, use_tkan = use_tkan, name='encoder')(variable_selection_past)
decode_out = RecurrentLayer(num_hidden, return_state = False, use_tkan = use_tkan, name='decoder')(variable_selection_future, initial_state = encode_states)

# all encoder-decoder history
history = Concatenate(axis=1)([encode_out, decode_out])

@@ -183,3 +269,4 @@ def TKAT(sequence_length: int, num_unknow_features: int, num_know_features: int,
dense_output = Dense(n_ahead)(flattened_output)

return Model(inputs=inputs, outputs=dense_output)

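For reference, a minimal usage sketch of the TKAT constructor defined above. The dimension values, the package-root import, and the compile/fit settings are illustrative assumptions, not taken from this commit:

import numpy as np
from tkat import TKAT  # assumes TKAT is exposed at the package root

# Illustrative sizes only
sequence_length, n_ahead = 30, 5
num_unknow_features, num_know_features = 3, 2

model = TKAT(
    sequence_length=sequence_length,
    num_unknow_features=num_unknow_features,
    num_know_features=num_know_features,
    num_embedding=8,
    num_hidden=16,
    num_heads=2,
    n_ahead=n_ahead,
    use_tkan=True,
)
model.compile(optimizer='adam', loss='mse')

# Inputs stack unknown and known features over sequence_length + n_ahead steps,
# matching Input(shape=(sequence_length + n_ahead, num_unknow_features + num_know_features));
# targets are the n_ahead values to predict.
X = np.random.rand(64, sequence_length + n_ahead, num_unknow_features + num_know_features).astype('float32')
y = np.random.rand(64, n_ahead).astype('float32')

model.fit(X, y, epochs=1, batch_size=16, verbose=0)
preds = model.predict(X[:4])  # expected shape: (4, n_ahead)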