Skip to content

Commit d1430be

Browse files
author
Simon Karlsson
committed
add(CycleGAN model and ReadME)
1 parent 7cca835 commit d1430be

10 files changed

+1284
-0
lines changed

CycleGAN/CycleGAN.py

+931
Large diffs are not rendered by default.

CycleGAN/data/ReadMe.md

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#### Directory structure on new dataset needed for training and testing:
2+
3+
```
4+
./Dataset-name/trainA
5+
./Dataset-name/trainB
6+
./Dataset-name/testA
7+
./Dataset-name/testB
8+
```

CycleGAN/generate_images/ReadMe.md

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
### When generating synthetic images from a trained models:
2+
3+
#### Put models in directory:
4+
```
5+
./models/
6+
```
7+
8+
#### With names:
9+
```
10+
G_A2B_model.hdf5
11+
G_B2A_model.hdf5
12+
```
13+
14+
#### Create directories for generated images:
15+
```
16+
./synthetic_images/A
17+
./synthetic_images/B
18+
```
19+
20+
#### Comment row 242:
21+
```
22+
#self.train(…
23+
```
24+
25+
#### Uncomment row 243:
26+
```
27+
self.load_model_and_generate_synthetic_images()
28+
```
29+
30+
#### Then run:
31+
```
32+
python CycleGAN.py
33+
```

CycleGAN/load_data.py

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import os
2+
import numpy as np
3+
from PIL import Image
4+
from keras.utils import Sequence
5+
#from skimage.io import imread
6+
7+
8+
def load_data(nr_of_channels, batch_size=1, nr_A_train_imgs=None, nr_B_train_imgs=None,
9+
nr_A_test_imgs=None, nr_B_test_imgs=None, subfolder='',
10+
generator=False, D_model=None, use_multiscale_discriminator=False, use_supervised_learning=False, REAL_LABEL=1.0):
11+
12+
trainA_path = os.path.join('data', subfolder, 'trainA')
13+
trainB_path = os.path.join('data', subfolder, 'trainB')
14+
testA_path = os.path.join('data', subfolder, 'testA')
15+
testB_path = os.path.join('data', subfolder, 'testB')
16+
17+
trainA_image_names = os.listdir(trainA_path)
18+
if nr_A_train_imgs != None:
19+
trainA_image_names = trainA_image_names[:nr_A_train_imgs]
20+
21+
trainB_image_names = os.listdir(trainB_path)
22+
if nr_B_train_imgs != None:
23+
trainB_image_names = trainB_image_names[:nr_B_train_imgs]
24+
25+
testA_image_names = os.listdir(testA_path)
26+
if nr_A_test_imgs != None:
27+
testA_image_names = testA_image_names[:nr_A_test_imgs]
28+
29+
testB_image_names = os.listdir(testB_path)
30+
if nr_B_test_imgs != None:
31+
testB_image_names = testB_image_names[:nr_B_test_imgs]
32+
33+
if generator:
34+
return data_sequence(trainA_path, trainB_path, trainA_image_names, trainB_image_names, batch_size=batch_size) # D_model, use_multiscale_discriminator, use_supervised_learning, REAL_LABEL)
35+
else:
36+
trainA_images = create_image_array(trainA_image_names, trainA_path, nr_of_channels)
37+
trainB_images = create_image_array(trainB_image_names, trainB_path, nr_of_channels)
38+
testA_images = create_image_array(testA_image_names, testA_path, nr_of_channels)
39+
testB_images = create_image_array(testB_image_names, testB_path, nr_of_channels)
40+
return {"trainA_images": trainA_images, "trainB_images": trainB_images,
41+
"testA_images": testA_images, "testB_images": testB_images,
42+
"trainA_image_names": trainA_image_names,
43+
"trainB_image_names": trainB_image_names,
44+
"testA_image_names": testA_image_names,
45+
"testB_image_names": testB_image_names}
46+
47+
48+
def create_image_array(image_list, image_path, nr_of_channels):
49+
image_array = []
50+
for image_name in image_list:
51+
if image_name[-1].lower() == 'g': # to avoid e.g. thumbs.db files
52+
if nr_of_channels == 1: # Gray scale image -> MR image
53+
image = np.array(Image.open(os.path.join(image_path, image_name)))
54+
image = image[:, :, np.newaxis]
55+
else: # RGB image -> 3 channels
56+
image = np.array(Image.open(os.path.join(image_path, image_name)))
57+
image = normalize_array(image)
58+
image_array.append(image)
59+
60+
return np.array(image_array)
61+
62+
63+
def normalize_array(array):
64+
max_value = max(array.flatten())
65+
array = array / max_value
66+
return array
67+
68+
69+
class data_sequence(Sequence):
70+
71+
def __init__(self, trainA_path, trainB_path, image_list_A, image_list_B, batch_size=1): # , D_model, use_multiscale_discriminator, use_supervised_learning, REAL_LABEL):
72+
self.batch_size = batch_size
73+
self.train_A = []
74+
self.train_B = []
75+
for image_name in image_list_A:
76+
if image_name[-1].lower() == 'g': # to avoid e.g. thumbs.db files
77+
self.train_A.append(os.path.join(trainA_path, image_name))
78+
for image_name in image_list_B:
79+
if image_name[-1].lower() == 'g': # to avoid e.g. thumbs.db files
80+
self.train_B.append(os.path.join(trainB_path, image_name))
81+
82+
def __len__(self):
83+
return int(max(len(self.train_A), len(self.train_B)) / float(self.batch_size))
84+
85+
def __getitem__(self, idx): # , use_multiscale_discriminator, use_supervised_learning):if loop_index + batch_size >= min_nr_imgs:
86+
if idx >= min(len(self.train_A), len(self.train_B)):
87+
# If all images soon are used for one domain,
88+
# randomly pick from this domain
89+
if len(self.train_A) <= len(self.train_B):
90+
indexes_A = np.random.randint(len(self.train_A), size=self.batch_size)
91+
batch_A = []
92+
for i in indexes_A:
93+
batch_A.append(self.train_A[i])
94+
batch_B = self.train_B[idx * self.batch_size:(idx + 1) * self.batch_size]
95+
else:
96+
indexes_B = np.random.randint(len(self.train_B), size=self.batch_size)
97+
batch_B = []
98+
for i in indexes_B:
99+
batch_B.append(self.train_B[i])
100+
batch_A = self.train_A[idx * self.batch_size:(idx + 1) * self.batch_size]
101+
else:
102+
batch_A = self.train_A[idx * self.batch_size:(idx + 1) * self.batch_size]
103+
batch_B = self.train_B[idx * self.batch_size:(idx + 1) * self.batch_size]
104+
105+
real_images_A = create_image_array(batch_A, '', 3)
106+
real_images_B = create_image_array(batch_B, '', 3)
107+
108+
return real_images_A, real_images_B # input_data, target_data
109+
110+
111+
if __name__ == '__main__':
112+
load_data()

CycleGAN/plotCSVfile.py

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import csv
2+
import matplotlib.pyplot as plt
3+
import sys
4+
#from operator import add
5+
import numpy as np
6+
7+
from scipy.signal import butter, lfilter, freqz
8+
9+
10+
def butter_lowpass(cutoff, fs, order=5):
11+
nyq = 0.5 * fs
12+
normal_cutoff = cutoff / nyq
13+
b, a = butter(order, normal_cutoff, btype='low', analog=False)
14+
return b, a
15+
16+
17+
def butter_lowpass_filter(data, cutoff, fs, order=5):
18+
b, a = butter_lowpass(cutoff, fs, order=order)
19+
y = lfilter(b, a, data)
20+
return y
21+
22+
23+
def plotResultfromCSV(datetime, point_gap=1):
24+
DA_losses = []
25+
DB_losses = []
26+
gA_d_losses_synthetic = []
27+
gB_d_losses_synthetic = []
28+
gA_losses_reconstructed = []
29+
gB_losses_reconstructed = []
30+
D_losses = []
31+
G_losses = []
32+
reconstruction_losses = []
33+
34+
with open('images/{}/loss_output.csv'.format(datetime), newline='') as csvfile:
35+
reader = csv.DictReader(csvfile)
36+
for row in reader:
37+
DA_losses.append(float(row['DA_losses']))
38+
DB_losses.append(float(row['DB_losses']))
39+
gA_d_losses_synthetic.append(float(row['gA_d_losses_synthetic']))
40+
gB_d_losses_synthetic.append(float(row['gB_d_losses_synthetic']))
41+
gA_losses_reconstructed.append(float(row['gA_losses_reconstructed']))
42+
gB_losses_reconstructed.append(float(row['gB_losses_reconstructed']))
43+
D_losses.append(float(row['D_losses']))
44+
reconstruction_losses.append(float(row['reconstruction_losses']))
45+
G_loss = row['G_losses']
46+
if G_loss[0] == '[':
47+
G_loss = G_loss.split(',')[0][1:]
48+
G_losses.append(float(G_loss))
49+
csvfile.close()
50+
51+
# Calculate interesting things to plot
52+
DA_losses = np.array(DA_losses)
53+
DB_losses = np.array(DB_losses)
54+
GA_losses = np.add(np.array(gA_d_losses_synthetic), np.array(gA_losses_reconstructed))
55+
GB_losses = np.add(np.array(gB_d_losses_synthetic), np.array(gB_losses_reconstructed))
56+
RA_losses = np.array(gA_losses_reconstructed)
57+
RB_losses = np.array(gB_losses_reconstructed)
58+
59+
G_losses = np.array(G_losses)
60+
D_losses = np.array(D_losses)
61+
reconstruction_losses = np.add(np.array(gA_losses_reconstructed), np.array(gB_losses_reconstructed))
62+
63+
points = range(0, len(G_losses), point_gap)
64+
fs = 1000
65+
cutoff = 2
66+
order = 6
67+
68+
# Lowpass filter
69+
GA = butter_lowpass_filter(GA_losses[points], cutoff, fs, order)
70+
GB = butter_lowpass_filter(GB_losses[points], cutoff, fs, order)
71+
72+
DA = butter_lowpass_filter(DA_losses[points], cutoff, fs, order)
73+
DB = butter_lowpass_filter(DB_losses[points], cutoff, fs, order)
74+
75+
RA = butter_lowpass_filter(RA_losses[points], cutoff, fs, order)
76+
RB = butter_lowpass_filter(RB_losses[points], cutoff, fs, order)
77+
78+
G = butter_lowpass_filter(G_losses[points], cutoff, fs, order)
79+
D = butter_lowpass_filter(D_losses[points], cutoff, fs, order)
80+
R = butter_lowpass_filter(reconstruction_losses[points], cutoff, fs, order)
81+
82+
fig_D = plt.figure(1)
83+
plt.plot(GA, label='GB_losses')
84+
plt.plot(GB, label='GA_losses')
85+
plt.ylabel('Generator losses')
86+
plt.legend()
87+
88+
fig_G = plt.figure(2)
89+
plt.plot(DA, label='DA_losses')
90+
plt.plot(DB, label='DB_losses')
91+
plt.ylabel('Discriminator losses')
92+
plt.legend()
93+
94+
fig_recons = plt.figure(3)
95+
plt.plot(RA, label='Reconstruction_loss_A')
96+
plt.plot(RB, label='Reconstruction_loss_B')
97+
plt.ylabel('Reconstruction losses')
98+
plt.legend()
99+
100+
fig_tots = plt.figure(4)
101+
plt.plot(G, label='G_losses')
102+
plt.plot(D, label='D_losses')
103+
plt.plot(R, label='Reconstruction_losses')
104+
plt.legend()
105+
106+
# Show plots
107+
fig_D.show()
108+
fig_G.show()
109+
fig_recons.show()
110+
fig_tots.show()
111+
112+
plt.pause(0)
113+
114+
115+
if __name__ == '__main__':
116+
datetime = str(sys.argv[1])
117+
points = int(sys.argv[2])
118+
plotResultfromCSV(datetime, points)

README.md

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
2+
3+
## Generative Adversarial Networks for Image-to-Image Translation on Multi-Contrast MR Images - A Comparison of CycleGAN and UNIT
4+
5+
### Article
6+
[Frontiers in Neuroinformatics - article](https://www.frontiersin.org/journals/neuroinformatics)
7+
8+
[(Underlying thesis project - report)](http://liu.diva-portal.org/smash/record.jsf?dswid=-7667&aq2=%5B%5B%5D%5D&af=%5B%5D&searchType=SIMPLE&sortOrder2=title_sort_asc&language=en&pid=diva2%3A1216606&aq=%5B%5B%5D%5D&sf=all&aqe=%5B%5D&sortOrder=author_sort_asc&onlyFullText=false&noOfRows=50&dspwid=-7667)
9+
&nbsp;
10+
&nbsp;
11+
&nbsp;
12+
13+
### Code usage
14+
1. Prepare your dataset under the directory 'data' in the CycleGAN or UNIT folder
15+
* Directory structure on new dataset needed for training and testing:
16+
* data/Dataset-name/trainA
17+
* data/Dataset-name/trainB
18+
* data/Dataset-name/testA
19+
* data/Dataset-name/testB
20+
&nbsp;
21+
2. Train a model by:
22+
```
23+
python CycleGAN.py
24+
```
25+
or
26+
```
27+
python UNIT.py
28+
```
29+
&nbsp;
30+
3. Generate synthetic images by following specifications under:
31+
* CycleGAN/generate_images/ReadMe.rtf
32+
* UNIT/generate_images/ReadMe.rtf
33+
&nbsp;
34+
&nbsp;
35+
&nbsp;
36+
37+
### Result GIFs - 304x256 pixel images
38+
**Left:** Input image. **Middle:** Synthetic images generated during training. **Right:** Ground truth.
39+
Histograms show pixel value distributions for synthetic images (blue) compared to ground truth (brown).
40+
&nbsp;
41+
&nbsp;
42+
&nbsp;
43+
44+
#### CycleGAN - T1 to T2
45+
![](./ReadMe/gifs/CycleGAN_T2_hist.gif?)
46+
---
47+
&nbsp;
48+
&nbsp;
49+
&nbsp;
50+
&nbsp;
51+
&nbsp;
52+
&nbsp;
53+
54+
#### CycleGAN - T2 to T1
55+
![](./ReadMe/gifs/CycleGAN_T1_hist.gif)
56+
---
57+
&nbsp;
58+
&nbsp;
59+
&nbsp;
60+
&nbsp;
61+
&nbsp;
62+
&nbsp;
63+
64+
#### UNIT - T1 to T2
65+
![](./ReadMe/gifs/UNIT_T2_hist.gif)
66+
---
67+
&nbsp;
68+
&nbsp;
69+
&nbsp;
70+
&nbsp;
71+
&nbsp;
72+
&nbsp;
73+
74+
#### UNIT - T2 to T1
75+
![](./ReadMe/gifs/UNIT_T1_hist.gif)
76+
---
77+
&nbsp;
78+
&nbsp;
79+
&nbsp;
80+
&nbsp;
81+
&nbsp;
82+
&nbsp;

ReadMe/gifs/CycleGAN_T1_hist.gif

26.5 MB
Loading

ReadMe/gifs/CycleGAN_T2_hist.gif

26.6 MB
Loading

ReadMe/gifs/UNIT_T1_hist.gif

24.5 MB
Loading

ReadMe/gifs/UNIT_T2_hist.gif

24 MB
Loading

0 commit comments

Comments
 (0)