This project implements a simple Generative Adversarial Network (GAN) to generate MNIST-like images using PyTorch. It includes:
- A `Discriminator` class to distinguish between real and fake images.
- A `Generator` class to generate new images from random noise.
- A `main.py` script to train the GAN and log results to TensorBoard.
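As a rough illustration of how the two classes fit together, here is a minimal sketch of a fully connected `Discriminator` and `Generator`. The layer widths, the noise size `z_dim=64`, the flattened image size `img_dim=784` (28x28), and the activations are assumptions made for this example; the actual definitions in `src/models/` may differ.

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Scores a flattened image with the probability that it is real."""

    def __init__(self, img_dim: int = 784):  # 784 = 28 * 28, assumed
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # output in [0, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class Generator(nn.Module):
    """Maps a random noise vector to a flattened image."""

    def __init__(self, z_dim: int = 64, img_dim: int = 784):  # sizes assumed
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh(),  # output in [-1, 1], matching normalized inputs
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)


# Generate 16 fake images from noise and score them.
noise = torch.randn(16, 64)
fake = Generator()(noise)
scores = Discriminator()(fake)
print(fake.shape, scores.shape)
```

The `Tanh` output range is chosen here so that generated pixels live in the same range as inputs normalized to [-1, 1]; if the data is normalized differently, the final activation should change accordingly.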
Code Source: This implementation is inspired by this YouTube tutorial by the channel.
- Utilizes PyTorch for model implementation and training.
- Trains on the MNIST dataset with normalized input data.
- Logs loss values and visualizes generated images in TensorBoard.
- Supports GPU training for faster performance.
.
├── src/
│   └── models/
│       ├── discriminator.py   # Defines the Discriminator class
│       └── generator.py       # Defines the Generator class
├── main.py                    # Main script for training the GAN
├── dataset/                   # Directory where MNIST data is downloaded
├── runs/                      # TensorBoard logs for real and fake images
├── requirements.txt           # Python package dependencies
└── README.md                  # Project documentation
All dependencies are listed in the `requirements.txt` file. Install them using the following command:
pip install -r requirements.txt
- Clone the repository:

      git clone https://github.com/your-username/simpleGAN.git
      cd simpleGAN

- Create a virtual environment and activate it:

      python -m venv env
      source env/bin/activate   # For Linux/Mac
      env\Scripts\activate      # For Windows

- Install the required Python packages:

      pip install -r requirements.txt

- Run the `main.py` script to start training:

      python main.py

- View training progress and generated images using TensorBoard:

      tensorboard --logdir=runs

  Open the URL provided by TensorBoard to visualize logs and images.
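For orientation, a single GAN training step of the kind `main.py` runs might look like the sketch below. It is not the project's actual loop: the small `nn.Sequential` stand-in models, the Adam optimizer, the learning rate of `3e-4`, the `BCELoss` criterion, and the random "real" batch are all assumptions chosen to keep the example self-contained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

z_dim, img_dim, batch_size = 64, 784, 32  # assumed sizes

# Stand-ins for the Discriminator and Generator classes.
disc = nn.Sequential(
    nn.Linear(img_dim, 128), nn.LeakyReLU(0.1), nn.Linear(128, 1), nn.Sigmoid()
).to(device)
gen = nn.Sequential(
    nn.Linear(z_dim, 256), nn.LeakyReLU(0.1), nn.Linear(256, img_dim), nn.Tanh()
).to(device)

opt_disc = torch.optim.Adam(disc.parameters(), lr=3e-4)
opt_gen = torch.optim.Adam(gen.parameters(), lr=3e-4)
criterion = nn.BCELoss()


def train_step(real: torch.Tensor):
    """Run one discriminator update and one generator update."""
    real = real.view(real.size(0), -1).to(device)
    n = real.size(0)
    ones = torch.ones(n, 1, device=device)
    zeros = torch.zeros(n, 1, device=device)

    fake = gen(torch.randn(n, z_dim, device=device))

    # Discriminator: push real images toward 1 and fake images toward 0.
    # detach() keeps this update from backpropagating into the generator.
    loss_disc = (criterion(disc(real), ones) + criterion(disc(fake.detach()), zeros)) / 2
    opt_disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    # Generator: try to make the updated discriminator output 1 for fakes.
    loss_gen = criterion(disc(fake), ones)
    opt_gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    return loss_disc.item(), loss_gen.item()


# Random tensor standing in for one normalized MNIST batch.
real_batch = torch.rand(batch_size, 1, 28, 28) * 2 - 1
d_loss, g_loss = train_step(real_batch)
print(f"Loss D: {d_loss:.4f}, Loss G: {g_loss:.4f}")
```

In a full run, `train_step` would be called for every batch from the MNIST loader, once per epoch.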
- Loss Values: Discriminator and generator loss values are printed to the console during training.
- Generated Images: Fake images generated at each epoch, logged in TensorBoard.
- Modify hyperparameters such as learning rate, batch size, or the number of epochs in `main.py`.
- Adjust the architecture of the `Discriminator` or `Generator` in their respective files under `src/models/`.
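One way to keep the tunable values in a single place is a small config object, sketched below. The class name `TrainConfig`, its field names, and all default values are hypothetical; `main.py` may define its hyperparameters differently, for example as plain module-level variables.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TrainConfig:
    """Hypothetical container for the hyperparameters mentioned above."""
    lr: float = 3e-4       # learning rate (assumed default)
    batch_size: int = 32   # assumed default
    num_epochs: int = 50   # assumed default
    z_dim: int = 64        # size of the generator's noise vector (assumed)


default_cfg = TrainConfig()

# Derive a variant for a quick experiment without mutating the defaults.
fast_cfg = replace(default_cfg, lr=1e-3, num_epochs=5)
print(fast_cfg)
```

Making the dataclass frozen means each experiment gets its own immutable settings, which helps when comparing runs in TensorBoard.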
This project is inspired by this YouTube tutorial.
This project is licensed under the MIT License. See the LICENSE.md file for details.
- Python Version: 3.8+
- PyTorch Version: 1.12+
- TensorBoard: Enabled
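To check an environment against the versions listed above, something like the following sketch can help. The helper names `parse_version` and `meets_minimum` are hypothetical. Versions are compared as integer tuples because a plain string comparison would wrongly rank `"1.9"` above `"1.12"`.

```python
import sys


def parse_version(text: str) -> tuple:
    """Extract (major, minor) from a string such as '1.12.1+cu113'."""
    major, minor = text.split("+")[0].split(".")[:2]
    return int(major), int(minor)


def meets_minimum(found: str, minimum: tuple) -> bool:
    """Return True if the `found` version is at least `minimum`."""
    return parse_version(found) >= minimum


python_ok = sys.version_info[:2] >= (3, 8)
print(f"Python 3.8+: {python_ok}")

try:
    import torch
    print(f"PyTorch 1.12+: {meets_minimum(torch.__version__, (1, 12))}")
except ImportError:
    print("PyTorch is not installed")
```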