diff --git a/commit/solvers.py b/commit/solvers.py index 264afbe1..9cf4c6d6 100755 --- a/commit/solvers.py +++ b/commit/solvers.py @@ -11,9 +11,10 @@ import warnings eps = np.finfo(float).eps - # removed, for now, projection_onto_l2_ball -from commit.proximals import non_negativity, omega_group_sparsity, prox_group_sparsity, soft_thresholding list_regularizers = [None, 'sparsity', 'group_sparsity'] +from commit.proximals import non_negativity, omega_group_sparsity, prox_group_sparsity, soft_thresholding +# removed, for now, projection_onto_l2_ball + def init_regularisation( commit_evaluation, @@ -129,41 +130,38 @@ def init_regularisation( def regularisation2omegaprox(regularisation): - lambdaIC = float(regularisation.get('lambdaIC')) - lambdaEC = float(regularisation.get('lambdaEC')) - lambdaISO = float(regularisation.get('lambdaISO')) + lambdaIC = regularisation.get('lambdaIC') + lambdaEC = regularisation.get('lambdaEC') + lambdaISO = regularisation.get('lambdaISO') if lambdaIC<0.0 or lambdaEC<0.0 or lambdaISO<0.0: raise ValueError('Negative regularisation strengths are not allowed') - regIC = regularisation.get('regIC') - regEC = regularisation.get('regEC') - regISO = regularisation.get('regISO') - if not regIC in list_regularizers: + if not regularisation['regIC'] in list_regularizers: raise ValueError('Regularizer for the IC compartment not implemented') - if not regEC in list_regularizers: + if not regularisation['regEC'] in list_regularizers: raise ValueError('Regularizer for the EC compartment not implemented') - if not regISO in list_regularizers: + if not regularisation['regISO'] in list_regularizers: raise ValueError('Regularizer for the ISO compartment not implemented') # Intracellular Compartment startIC = regularisation.get('startIC') sizeIC = regularisation.get('sizeIC') - if regIC is None: + if regularisation['regIC'] is None: omegaIC = lambda x: 0.0 if regularisation.get('nnIC')==True: proxIC = lambda x, _: non_negativity(x,startIC,sizeIC) else: proxIC = lambda x, _: x - elif regIC == 'sparsity': + elif regularisation['regIC'] == 'sparsity': omegaIC = lambda x: lambdaIC * np.linalg.norm(x[startIC:sizeIC],1) if regularisation.get('nnIC'): proxIC = lambda x, scaling: non_negativity(soft_thresholding(x,scaling*lambdaIC,startIC,sizeIC),startIC,sizeIC) else: proxIC = lambda x, _: non_negativity(x,startIC,sizeIC) - # elif regIC == 'l2': + # elif regularisation['regIC'] == 'smoothness': # omegaIC = lambda x: lambdaIC * np.linalg.norm(x[startIC:sizeIC]) # proxIC = lambda x: projection_onto_l2_ball(x, lambdaIC, startIC, sizeIC) - elif regIC == 'group_sparsity': + elif regularisation['regIC'] == 'group_sparsity': structureIC = regularisation.get('structureIC') groupWeightIC = regularisation.get('weightsIC') if not len(structureIC) == len(groupWeightIC): @@ -189,38 +187,38 @@ def regularisation2omegaprox(regularisation): # Extracellular Compartment startEC = regularisation.get('startEC') sizeEC = regularisation.get('sizeEC') - if regEC is None: + if regularisation['regIC'] is None: omegaEC = lambda x: 0.0 if regularisation.get('nnEC')==True: proxEC = lambda x, _: non_negativity(x,startEC,sizeEC) else: proxEC = lambda x, _: x - elif regEC == 'sparsity': + elif regularisation['regIC'] == 'sparsity': omegaEC = lambda x: lambdaEC * np.linalg.norm(x[startEC:(startEC+sizeEC)],1) if regularisation.get('nnEC'): proxEC = lambda x, scaling: non_negativity(soft_thresholding(x,scaling*lambdaEC,startEC,sizeEC),startEC,sizeEC) else: proxEC = lambda x, scaling: soft_thresholding(x,scaling*lambdaEC,startEC,sizeEC) - # elif regEC == 'l2': + # elif regularisation['regIC'] == 'smoothness': # omegaEC = lambda x: lambdaEC * np.linalg.norm(x[startEC:(startEC+sizeEC)]) # proxEC = lambda x: projection_onto_l2_ball(x, lambdaEC, startEC, sizeEC) # Isotropic Compartment startISO = regularisation.get('startISO') sizeISO = regularisation.get('sizeISO') - if regISO is None: + if regularisation['regISO'] is None: omegaISO = lambda x: 0.0 if regularisation.get('nnISO')==True: proxISO = lambda x, _: non_negativity(x,startISO,sizeISO) else: proxISO = lambda x, _: x - elif regISO == 'sparsity': + elif regularisation['regISO'] == 'sparsity': omegaISO = lambda x: lambdaISO * np.linalg.norm(x[startISO:(startISO+sizeISO)],1) if regularisation.get('nnISO'): proxISO = lambda x, scaling: non_negativity(soft_thresholding(x,scaling*lambdaISO,startISO,sizeISO),startISO,sizeISO) else: proxISO = lambda x, scaling: soft_thresholding(x,scaling*lambdaISO,startISO,sizeISO) - # elif regISO == 'l2': + # elif regularisation['regISO'] == 'l2': # omegaISO = lambda x: lambdaISO * np.linalg.norm(x[startISO:(startISO+sizeISO)]) # proxISO = lambda x: projection_onto_l2_ball(x, lambdaISO, startISO, sizeISO)