Merge pull request #141 from cshoebridge/equivariance
Equivariant Imaging
AnderBiguri authored Nov 15, 2024
2 parents ba5696f + f040067 commit 3f8bbb0
Showing 5 changed files with 313 additions and 54 deletions.
80 changes: 80 additions & 0 deletions LION/optimizers/EquivariantSolver.py
@@ -0,0 +1,80 @@
import random
from typing import Callable
import numpy as np
import torch
import torchvision.transforms.functional as TF
from torch.optim.optimizer import Optimizer
from tomosipo.torch_support import to_autograd
from tqdm import tqdm
from LION.CTtools.ct_geometry import Geometry
from LION.classical_algorithms.fdk import fdk
from LION.exceptions.exceptions import LIONSolverException
from LION.models.LIONmodel import LIONmodel, ModelInputType
from LION.optimizers.LIONsolver import LIONsolver, SolverParams


def get_rotation_matrix(angle: float):
    # 2D rotation matrix for the given angle (in radians)
    theta = torch.tensor(angle)
    s = torch.sin(theta)
    c = torch.cos(theta)
    return torch.tensor([[c, -s], [s, c]])


class EquivariantSolverParams(SolverParams):
    def __init__(
        self, transformation_group: list[Callable], equivariance_strength: float
    ):
        super().__init__()
        self.transformation_group = transformation_group
        self.equivariance_strength = equivariance_strength


class EquivariantSolver(LIONsolver):
    def __init__(
        self,
        model: LIONmodel,
        optimizer: Optimizer,
        loss_fn: Callable,
        geometry: Geometry,
        verbose: bool = True,
        device: torch.device = torch.device(f"cuda:{torch.cuda.current_device()}"),
        solver_params: SolverParams | None = None,
    ) -> None:
        super().__init__(
            model, optimizer, loss_fn, geometry, verbose, device, solver_params
        )
        self.transformation_group = self.solver_params.transformation_group
        self.alpha = self.solver_params.equivariance_strength
        # differentiable forward and back projectors
        self.A = to_autograd(self.op, num_extra_dims=1)
        self.AT = to_autograd(self.op.T, num_extra_dims=1)

    @staticmethod
    def rotation_group(cardinality: int):
        assert 360 % cardinality == 0
        angle_increment = 360 / cardinality
        # bind the angle as a default argument so each lambda keeps its own
        # rotation, instead of all of them closing over the final value of i
        return [
            lambda x, angle=i * angle_increment: TF.rotate(x, angle)
            for i in range(cardinality)
        ]

    @staticmethod
    def default_parameters() -> EquivariantSolverParams:
        return EquivariantSolverParams(EquivariantSolver.rotation_group(360), 100)

    def mini_batch_step(self, sino_batch, target_batch) -> torch.Tensor:
        random_transform = random.choice(self.transformation_group)

        if needs_image := (self.model.get_input_type() == ModelInputType.IMAGE):
            recon1 = self.model(fdk(sino_batch, self.op))
        else:
            recon1 = self.model(sino_batch)

        transformed_recon1 = random_transform(recon1)

        if needs_image:
            recon2 = self.model(fdk(self.A(transformed_recon1), self.op))
        else:
            recon2 = self.model(self.A(transformed_recon1))

        # data consistency + equivariance
        return self.loss_fn(self.A(recon1), sino_batch) + self.alpha * self.loss_fn(
            recon2, transformed_recon1
        )
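
The loss returned by mini_batch_step is the usual equivariant-imaging objective. Writing A for the forward projector, f for the model (preceded by an FDK when the model takes image input), T for the randomly drawn transform and y for the measured sinogram batch:

    loss = loss_fn(A(f(y)), y) + alpha * loss_fn(f(A(T(f(y)))), T(f(y)))

i.e. data consistency on the first reconstruction, plus a penalty asking the re-reconstruction of the transformed image to match the transformed reconstruction. A minimal usage sketch, assuming the standard LION setup (model and geometry are placeholders, and the optimizer and loss choices are illustrative, not prescribed by this commit):

import torch
from torch.nn import MSELoss
from torch.optim import Adam

from LION.optimizers.EquivariantSolver import (
    EquivariantSolver,
    EquivariantSolverParams,
)

# model: a LIONmodel; geometry: a LION Geometry -- both assumed to exist
params = EquivariantSolverParams(
    transformation_group=EquivariantSolver.rotation_group(4),  # 0, 90, 180, 270 deg
    equivariance_strength=100.0,
)
solver = EquivariantSolver(
    model,
    Adam(model.parameters(), lr=1e-4),
    MSELoss(),
    geometry,
    solver_params=params,
)
# data loaders and checkpointing are configured through the LIONsolver base
# class, then training runs for a fixed number of epochs:
# solver.train(n_epochs=50)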
51 changes: 42 additions & 9 deletions LION/optimizers/LIONsolver.py
@@ -637,7 +637,7 @@ def load_checkpoint(self):
         (
             self.model,
             self.optimizer,
-            epoch,
+            self.current_epoch,
             self.train_loss,
             _,
         ) = self.model.load_checkpoint_if_exists(
@@ -648,18 +648,20 @@ def load_checkpoint(self):
         )
         if (
             self.validation_fn is not None
-            and epoch > 0
+            and self.current_epoch > 0
             and self.validation_fname is not None
             and self.validation_loss is not None
         ):
-            self.validation_loss[epoch - 1] = self.model._read_min_validation(
+            self.validation_loss[
+                self.current_epoch - 1
+            ] = self.model._read_min_validation(
                 self.load_folder.joinpath(self.validation_fname)
             )
             if self.verbose:
                 print(
-                    f"Loaded checkpoint at epoch {epoch}. Current min validation loss is {self.validation_loss[epoch-1]}"
+                    f"Loaded checkpoint at epoch {self.current_epoch}. Current min validation loss is {self.validation_loss[self.current_epoch-1]}"
                 )
-        return epoch
+        return self.current_epoch

     def train_step(self):
         """
@@ -688,7 +690,7 @@ def epoch_step(self, epoch):
         """
         self.train_loss[epoch] = self.train_step()
         # actually make sure we're doing validation
-        if (epoch + 1) % self.validation_freq == 0 and self.validation_loss is not None:
+        if self.validation_loss is not None and (epoch + 1) % self.validation_freq == 0:
             self.validation_loss[epoch] = self.validate()
             if self.verbose:
                 print(
@@ -734,12 +736,43 @@ def train(self, n_epochs):
 
             self.current_epoch += 1
 
-    @abstractmethod
     def validate(self):
         """
-        This function should perform a validation step
+        Performs a single pass over the validation set and
+        returns its average loss for this epoch.
         """
-        pass
+        if self.check_validation_ready() != 0:
+            raise LIONSolverException(
+                "Solver not ready for validation. Please call set_validation."
+            )
+
+        # these always hold if the check above passes; they only placate the static type checker
+        assert self.validation_loader is not None
+        assert self.validation_fn is not None
+
+        status = self.model.training
+        self.model.eval()
+
+        with torch.no_grad():
+            validation_loss = np.array([])
+            for data, targets in tqdm(self.validation_loader):
+                if self.model.get_input_type() == ModelInputType.IMAGE:
+                    data = fdk(data, self.op)
+                outputs = self.model(data)
+                validation_loss = np.append(
+                    validation_loss, self.validation_fn(targets, outputs)
+                )
+
+        if self.verbose:
+            print(
+                f"Validation loss: {validation_loss.mean()} - Validation loss std: {validation_loss.std()}"
+            )
+
+        # return the model to training mode if that is how it arrived
+        if status:
+            self.model.train()
+
+        return np.mean(validation_loss)
 
     @abstractmethod
     def mini_batch_step(self, sino_batch, target_batch) -> torch.Tensor:
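With this change, validate() lives on the LIONsolver base class, so every solver inherits it; the remaining contract is that validation_fn(targets, outputs) returns a scalar. A sketch of a compatible metric, assuming only PyTorch (the PSNR definition and the set_validation call in the comment are illustrative; the exact set_validation signature is not shown in this diff):

import torch

def psnr_metric(targets: torch.Tensor, outputs: torch.Tensor) -> float:
    # scalar metric in the (targets, outputs) order that validate() uses
    mse = torch.mean((targets - outputs) ** 2)
    return float(10 * torch.log10(targets.max() ** 2 / mse))

# e.g., when configuring validation on a solver:
# solver.set_validation(validation_loader, validation_freq, psnr_metric)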
2 changes: 1 addition & 1 deletion LION/optimizers/noise2inverse.py
@@ -109,7 +109,7 @@ def default_parameters() -> Noise2InverseParams:
             cali_J,
         )
 
-    def mini_batch_step(self, sinos):
+    def mini_batch_step(self, sinos, targets):
         # sinos batch of sinos
         noisy_sub_recons = self._calculate_noisy_sub_recons(sinos)
         # b, split, c, w, h
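This one-line change brings Noise2Inverse in line with the abstract mini_batch_step(sino_batch, target_batch) hook on LIONsolver; self-supervised solvers simply ignore the targets. A minimal sketch of a custom solver written against the unified interface, mirroring what LION/optimizers/supervised_learning.py already does (hypothetical subclass for illustration, not part of this commit):

import torch

from LION.optimizers.LIONsolver import LIONsolver, SolverParams

class PlainSupervisedSolver(LIONsolver):
    # hypothetical example subclass
    def mini_batch_step(self, sino_batch, target_batch) -> torch.Tensor:
        # assumes a model that takes sinogram input directly
        output = self.model(sino_batch)
        return self.loss_fn(output, target_batch)

    @staticmethod
    def default_parameters() -> SolverParams:
        return SolverParams()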
44 changes: 0 additions & 44 deletions LION/optimizers/supervised_learning.py
@@ -54,50 +54,6 @@ def mini_batch_step(self, sino, target):
         output = self.model(data)
         return self.loss_fn(output, target)
 
-    def validate(self):
-        """
-        This function is responsible for performing a single validation set of the optimization.
-        returns the average loss of the validation set this epoch.
-        """
-        """
-        This function is responsible for performing a single validation set of the optimization.
-        returns the average loss of the validation set this epoch.
-        """
-        if self.check_validation_ready() != 0:
-            raise LIONSolverException(
-                "Solver not ready for validation. Please call set_validation."
-            )
-
-        # these always pass if the above does, this is just to placate static type checker
-        assert self.validation_loader is not None
-        assert self.validation_fn is not None
-
-        status = self.model.training
-        self.model.eval()
-
-        with torch.no_grad():
-            validation_loss = np.array([])
-            for data, targets in tqdm(self.validation_loader):
-                print(self.model.model_parameters.model_input_type)
-                if self.model.get_input_type() == ModelInputType.IMAGE:
-                    data = fdk(data, self.op)
-                print(data.shape)
-                outputs = self.model(data)
-                validation_loss = np.append(
-                    validation_loss, self.validation_fn(targets, outputs)
-                )
-
-        if self.verbose:
-            print(
-                f"Testing loss: {validation_loss.mean()} - Testing loss std: {validation_loss.std()}"
-            )
-
-        # return to train if it was in train
-        if status:
-            self.model.train()
-
-        return np.mean(validation_loss)
 
     @staticmethod
     def default_parameters() -> SolverParams:
         return SolverParams()