Skip to content

Commit

Permalink
Renamed parameters for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
daducci committed Dec 4, 2023
1 parent 829f06b commit 280f327
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions commit/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
import warnings
eps = np.finfo(float).eps

# removed, for now, projection_onto_l2_ball
from commit.proximals import non_negativity, omega_group_sparsity, prox_group_sparsity, soft_thresholding
list_regularizers = [None, 'sparsity', 'group_sparsity']
from commit.proximals import non_negativity, omega_group_sparsity, prox_group_sparsity, soft_thresholding
# removed, for now, projection_onto_l2_ball


def init_regularisation(
commit_evaluation,
Expand Down Expand Up @@ -129,41 +130,38 @@ def init_regularisation(


def regularisation2omegaprox(regularisation):
lambdaIC = float(regularisation.get('lambdaIC'))
lambdaEC = float(regularisation.get('lambdaEC'))
lambdaISO = float(regularisation.get('lambdaISO'))
lambdaIC = regularisation.get('lambdaIC')
lambdaEC = regularisation.get('lambdaEC')
lambdaISO = regularisation.get('lambdaISO')
if lambdaIC<0.0 or lambdaEC<0.0 or lambdaISO<0.0:
raise ValueError('Negative regularisation strengths are not allowed')

regIC = regularisation.get('regIC')
regEC = regularisation.get('regEC')
regISO = regularisation.get('regISO')
if not regIC in list_regularizers:
if not regularisation['regIC'] in list_regularizers:
raise ValueError('Regularizer for the IC compartment not implemented')
if not regEC in list_regularizers:
if not regularisation['regEC'] in list_regularizers:
raise ValueError('Regularizer for the EC compartment not implemented')
if not regISO in list_regularizers:
if not regularisation['regISO'] in list_regularizers:
raise ValueError('Regularizer for the ISO compartment not implemented')

# Intracellular Compartment
startIC = regularisation.get('startIC')
sizeIC = regularisation.get('sizeIC')
if regIC is None:
if regularisation['regIC'] is None:
omegaIC = lambda x: 0.0
if regularisation.get('nnIC')==True:
proxIC = lambda x, _: non_negativity(x,startIC,sizeIC)
else:
proxIC = lambda x, _: x
elif regIC == 'sparsity':
elif regularisation['regIC'] == 'sparsity':
omegaIC = lambda x: lambdaIC * np.linalg.norm(x[startIC:sizeIC],1)
if regularisation.get('nnIC'):
proxIC = lambda x, scaling: non_negativity(soft_thresholding(x,scaling*lambdaIC,startIC,sizeIC),startIC,sizeIC)
else:
proxIC = lambda x, _: non_negativity(x,startIC,sizeIC)
# elif regIC == 'l2':
# elif regularisation['regIC'] == 'smoothness':
# omegaIC = lambda x: lambdaIC * np.linalg.norm(x[startIC:sizeIC])
# proxIC = lambda x: projection_onto_l2_ball(x, lambdaIC, startIC, sizeIC)
elif regIC == 'group_sparsity':
elif regularisation['regIC'] == 'group_sparsity':
structureIC = regularisation.get('structureIC')
groupWeightIC = regularisation.get('weightsIC')
if not len(structureIC) == len(groupWeightIC):
Expand All @@ -189,38 +187,38 @@ def regularisation2omegaprox(regularisation):
# Extracellular Compartment
startEC = regularisation.get('startEC')
sizeEC = regularisation.get('sizeEC')
if regEC is None:
if regularisation['regIC'] is None:
omegaEC = lambda x: 0.0
if regularisation.get('nnEC')==True:
proxEC = lambda x, _: non_negativity(x,startEC,sizeEC)
else:
proxEC = lambda x, _: x
elif regEC == 'sparsity':
elif regularisation['regIC'] == 'sparsity':
omegaEC = lambda x: lambdaEC * np.linalg.norm(x[startEC:(startEC+sizeEC)],1)
if regularisation.get('nnEC'):
proxEC = lambda x, scaling: non_negativity(soft_thresholding(x,scaling*lambdaEC,startEC,sizeEC),startEC,sizeEC)
else:
proxEC = lambda x, scaling: soft_thresholding(x,scaling*lambdaEC,startEC,sizeEC)
# elif regEC == 'l2':
# elif regularisation['regIC'] == 'smoothness':
# omegaEC = lambda x: lambdaEC * np.linalg.norm(x[startEC:(startEC+sizeEC)])
# proxEC = lambda x: projection_onto_l2_ball(x, lambdaEC, startEC, sizeEC)

# Isotropic Compartment
startISO = regularisation.get('startISO')
sizeISO = regularisation.get('sizeISO')
if regISO is None:
if regularisation['regISO'] is None:
omegaISO = lambda x: 0.0
if regularisation.get('nnISO')==True:
proxISO = lambda x, _: non_negativity(x,startISO,sizeISO)
else:
proxISO = lambda x, _: x
elif regISO == 'sparsity':
elif regularisation['regISO'] == 'sparsity':
omegaISO = lambda x: lambdaISO * np.linalg.norm(x[startISO:(startISO+sizeISO)],1)
if regularisation.get('nnISO'):
proxISO = lambda x, scaling: non_negativity(soft_thresholding(x,scaling*lambdaISO,startISO,sizeISO),startISO,sizeISO)
else:
proxISO = lambda x, scaling: soft_thresholding(x,scaling*lambdaISO,startISO,sizeISO)
# elif regISO == 'l2':
# elif regularisation['regISO'] == 'l2':
# omegaISO = lambda x: lambdaISO * np.linalg.norm(x[startISO:(startISO+sizeISO)])
# proxISO = lambda x: projection_onto_l2_ball(x, lambdaISO, startISO, sizeISO)

Expand Down

0 comments on commit 280f327

Please sign in to comment.