Skip to content

Commit

Permalink
Merge pull request #129 from daducci/fix_fista_update
Browse files Browse the repository at this point in the history
[FIX] Fista update
  • Loading branch information
nightwnvol authored Dec 6, 2023
2 parents 4edc2f2 + 280f327 commit 103c60d
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 217 deletions.
85 changes: 34 additions & 51 deletions commit/proximals.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,16 @@ cpdef soft_thresholding(double [::1] x, double lam, int compartment_start, int c
cdef:
int i
for i in xrange(compartment_start, compartment_start+compartment_size):
if x[i] <= lam:
x[i] = 0.0
else:
# if x[i] <= lam:
# x[i] = 0.0
# else:
# x[i] = x[i] - lam
if x[i] > lam:
x[i] = x[i] - lam
elif x[i] < -lam:
x[i] = x[i] + lam
else:
x[i] = 0.0
return np.asarray( x )


Expand Down Expand Up @@ -71,29 +77,18 @@ cpdef omega_group_sparsity(double [::1] x, int [::1] group_idx, int [::1] group_
double omega = 0.0, gNorm, x_i

if lam != 0:
if n == 2:
for k in xrange(nG):
N = group_size[k]
gNorm = 0.0
for i in xrange(j,j+N) :
x_i = x[group_idx[i]]
gNorm += x_i*x_i
omega += group_weight[k] * sqrt( gNorm )
j += N
elif n == np.inf:
for k in xrange(nG):
N = group_size[k]
gNorm = x[group_idx[j]]
for i in xrange(j+1,j+N) :
x_i = x[group_idx[i]]
if x_i > gNorm :
gNorm = x_i
omega += group_weight[k] * gNorm
j += N
for k in xrange(nG):
N = group_size[k]
gNorm = 0.0
for i in xrange(j,j+N) :
x_i = x[group_idx[i]]
gNorm += x_i*x_i
omega += group_weight[k] * sqrt( gNorm )
j += N
return lam*omega


cpdef prox_group_sparsity( double [::1] x, int [::1] group_idx, int [::1] group_size, double [::1] group_weight, double lam, double n ) :
cpdef prox_group_sparsity( double [::1] x, int [::1] group_idx, int [::1] group_size, double [::1] group_weight, double lam) :
"""
References:
[1] Jenatton et al. - `Proximal Methods for Hierarchical Sparse Coding`
Expand All @@ -109,33 +104,21 @@ cpdef prox_group_sparsity( double [::1] x, int [::1] group_idx, int [::1] group_
x[i] = 0.0

if lam != 0:
if n == 2 :
for k in xrange(nG) :
N = group_size[k]
gNorm = 0.0
for i in xrange(j,j+N) :
x_i = x[group_idx[i]]
gNorm += x_i*x_i
gNorm = sqrt( gNorm )
for k in xrange(nG) :
N = group_size[k]
gNorm = 0.0
for i in xrange(j,j+N) :
x_i = x[group_idx[i]]
gNorm += x_i*x_i
gNorm = sqrt( gNorm )

wl = group_weight[k] * lam
if gNorm <= wl :
for i in xrange(j,j+N) :
x[ group_idx[i] ] = 0.0
else :
wl = (gNorm-wl)/gNorm
for i in xrange(j,j+N) :
x[ group_idx[i] ] *= wl
j += N
# elif n == np.inf :
# [TODO] TO be correctly implemented
# for k in range(nG) :
# idx = subtree[k]
# # xn = max( v[idx] )
# r = weight[k] * lam
# for i in idx :
# if v[i] <= r:
# v[i] = 0.0
# else :
# v[i] -= r
wl = group_weight[k] * lam
if gNorm <= wl :
for i in xrange(j,j+N) :
x[ group_idx[i] ] = 0.0
else :
wl = (gNorm-wl)/gNorm
for i in xrange(j,j+N) :
x[ group_idx[i] ] *= wl
j += N
return np.asarray( x )
Loading

0 comments on commit 103c60d

Please sign in to comment.