Skip to content

Commit

Permalink
remove SURE
Browse files Browse the repository at this point in the history
  • Loading branch information
AnderBiguri committed Nov 25, 2024
1 parent 8b2878d commit 8806f41
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 209 deletions.
28 changes: 27 additions & 1 deletion LION/losses/SURE.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ def __init__(self, noise_std: float, epsilon: Optional[float] = None) -> None:
warnings.warn(
"SURE expects Gaussian noise, which is not the case in noisy recosntruction of CT, so this may not work as expected"
)
raise NotImplementedError(
"This is not working as expected, it is not implemented for CT reconstruction. See issue #144 to develop this"
)

def forward(self, model, noisy):
def forward(self, noisy, model):

if model.get_input_type() != ModelInputType.IMAGE:
raise NotImplementedError(
Expand All @@ -40,3 +43,26 @@ def forward(self, model, noisy):
@staticmethod
def default_epsilon(y):
return torch.max(y) / 1000

@staticmethod
def cite(cite_format="MLA"):

if cite_format == "MLA":
print("Metzler, Christopher A., et al.")
print('"Unsupervised learning with Steins unbiased risk estimator."')
print("\x1B[3m arXiv preprint arXiv:1805.10531 \x1B[0m")
print("2018")

elif cite_format == "bib":
string = """
@article{metzler2018unsupervised,
title={Unsupervised learning with Stein's unbiased risk estimator},
author={Metzler, Christopher A and Mousavi, Ali and Heckel, Reinhard and Baraniuk, Richard G},
journal={arXiv preprint arXiv:1805.10531},
year={2018}
}"""
print(string)
else:
raise AttributeError(
'cite_format not understood, only "MLA" and "bib" supported'
)
8 changes: 4 additions & 4 deletions LION/optimizers/LIONsolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def check_validation_ready(self, autofill=True, verbose=True):
"validation_fn",
expected_type=callable,
error=False,
autofill=True,
autofill=self.validation_loader is not None,
verbose=verbose,
default=self.loss_fn,
)
Expand All @@ -338,7 +338,7 @@ def check_validation_ready(self, autofill=True, verbose=True):
"validation_freq",
expected_type=int,
error=False,
autofill=autofill,
autofill=self.validation_loader is not None,
verbose=verbose,
default=10,
)
Expand Down Expand Up @@ -699,8 +699,6 @@ def epoch_step(self, epoch):
self.save_validation(epoch)
elif self.verbose:
print(f"Epoch {epoch+1} - Training loss: {self.train_loss[epoch]}")
elif self.validation_freq is not None and self.validation_loss is not None:
self.validation_loss[epoch] = self.validate()

def train(self, n_epochs):
"""
Expand All @@ -719,6 +717,8 @@ def train(self, n_epochs):

if self.check_validation_ready() == 0:
self.validation_loss = np.zeros((n_epochs))
if self.validation_loader is None:
self.validation_loss = None

self.model.train()
# train loop
Expand Down
11 changes: 11 additions & 0 deletions papers_in_LION.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ Pelt, Daniël M., and James A. Sethian. "A mixed-scale dense convolutional neura
[https://doi.org/10.1073/pnas.1715832114](https://doi.org/10.1073/pnas.1715832114)

`LION/models/CNNs/MS-D/` Submodule with the original repo

`LION/models/CNNs/MSD_pytorch.py` the LIONmodel to load the original code

`LION/models/CNNs/MSDNet.py` Our version of the MSD_pytorch model. Uses more memory

#### Learned Primal Dual (LPD)
Expand Down Expand Up @@ -66,6 +68,7 @@ Kiss, Maximilian B., et al. "2DeteCT-A large 2D expandable, trainable, experimen
[https://doi.org/10.1038/s41597-023-02484-6](https://doi.org/10.1038/s41597-023-02484-6)

`LION/data_loaders/2deteCT/` Code to download and pre-process a LION version of the 2deteCT, made with the authors.

`LION/data_loaders/deteCT.py` Pytorch DataSet

#### LIDC-IDRI
Expand All @@ -74,8 +77,16 @@ Armato III, Samuel G., et al. "The lung image database consortium (LIDC) and ima
[https://doi.org/10.1118/1.3528204](https://doi.org/10.1118/1.3528204)

`LION/data_loaders/LIDC_IDRI/` Code to pre-process a LION version of the dataset

`LION/data_loaders/LIDC_IDRI.py` Pytorch DataSet

## Loss functions

#### Steins Unbased risk estimator (SURE)

Metzler, Christopher A., et al. "Unsupervised learning with Stein's unbiased risk estimator." arXiv preprint arXiv:1805.10531 (2018).
[https://doi.org/10.48550/arXiv.1805.10531](https://doi.org/10.48550/arXiv.1805.10531)

`LION/losses/SURE.py` The loss function itself. Use with `SelfSupervisedSolver`

## Misc
Loading

0 comments on commit 8806f41

Please sign in to comment.