Module memory #12

Merged
merged 8 commits on Jan 2, 2024
Binary file removed .DS_Store
Binary file not shown.
26 changes: 15 additions & 11 deletions README.md
@@ -25,7 +25,7 @@ To use VPRTempo, please follow the instructions below for installation and usage
- Quantization Aware Training (QAT) enabled to train weights in int8 space (see the sketch after this list)
- Addition of Jupyter Notebook tutorials that show how to use VPRTempo and explain the computational logic
- Simplification of weight operations into a single weight tensor, allowing positive and negative connections to change sign during training
- Easier dependency installation with PyPI/pip
- Easier dependency installation with PyPI/pip and conda
- And more!
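
A minimal sketch of the QAT flow that `main.py` uses (the model below is a stand-in with made-up layer sizes; the real `VPRTempoQuantTrain` modules take additional arguments):

```python
import torch.nn as nn
import torch.quantization as quantization

# Stand-in network; layer sizes are illustrative only
model = nn.Sequential(nn.Linear(3136, 6272), nn.ReLU(), nn.Linear(6272, 500))
model.train()
model.qconfig = quantization.get_default_qat_qconfig('fbgemm')  # int8 QAT config
quantization.prepare_qat(model, inplace=True)  # insert fake-quantization observers
# ... training loop runs here, with weights learned in simulated int8 ...
model.eval()
quantization.convert(model, inplace=True)  # produce the final int8 model
```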

## License & Citation
@@ -64,22 +64,28 @@ If you wish to enable CUDA, please follow the instructions on the [PyTorch - Get
Dependencies can be installed through our provided `requirements.txt` file.

```console
pip3 install -r requirements.txt
pip install -r requirements.txt
```
As above, if you wish to install CUDA, please visit [PyTorch - Get Started](https://pytorch.org/get-started/locally/).
### Option 3: Conda install
>**:heavy_exclamation_mark: Recommended:**
> Use [Mambaforge](https://mamba.readthedocs.io/en/latest/installation.html) instead of conda.
> Use [Mambaforge](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) instead of conda.

Requirements for VPRTempo may be installed using our [conda-forge package](https://anaconda.org/conda-forge/vprtempo).

```console
# Windows/Linux - CUDA enabled
conda create -n vprtempo -c pytorch -c nvidia python torchvision torchaudio pytorch-cuda=11.7 cudatoolkit prettytable tqdm numpy pandas scikit-learn
# Linux/OS X
conda create -n vprtempo -c conda-forge vprtempo

# Linux CUDA enabled
conda create -n vprtempo -c conda-forge -c pytorch -c nvidia vprtempo pytorch-cuda cudatoolkit

# Windows/Linux - CPU only
conda create -n vprtempo python pytorch torchvision torchaudio cpuonly prettytable tqdm numpy pandas scikit-learn -c pytorch
# Windows
conda create -n vprtempo -c pytorch python pytorch torchvision torchaudio cpuonly prettytable tqdm numpy pandas scikit-learn

# Windows CUDA enabled
conda create -n vprtempo -c pytorch -c nvidia python torchvision torchaudio pytorch-cuda=11.7 cudatoolkit prettytable tqdm numpy pandas scikit-learn

# MacOS
conda create -n vprtempo -c conda-forge python prettytable tqdm numpy pandas scikit-learn -c pytorch pytorch::pytorch torchvision torchaudio
```
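
Once an environment is created, activate it before running the network (a sketch, assuming the environment name used above and a clone of this repository):

```console
conda activate vprtempo
python main.py
```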

## Datasets
@@ -142,8 +148,6 @@ python main.py --quantize
<img src="./assets/mainquant_example.gif" alt="Example of the quantized VPRTempo networking running"/>
</p>

#### IDE
You can also run VPRTempo through your IDE by running `main.py`. Change the `bool` flag for `use_quantize` to `True` if you wish to run VPRTempoQuant.

### Train new network
If you do not wish to use the pretrained models, or you would like to train your own, you can pass the `--train_new_model` flag to `main.py`. Note: if a pretrained model already exists, you will be prompted to confirm whether you would like to retrain it.
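
For example (a sketch using flags defined in `main.py`'s argument parser):

```console
# Train a new standard network
python main.py --train_new_model

# Train a new quantized network (QAT)
python main.py --train_new_model --quantize
```
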
5 changes: 5 additions & 0 deletions docs/.gitignore
@@ -0,0 +1,5 @@
_site
.sass-cache
.jekyll-cache
.jekyll-metadata
vendor
180 changes: 145 additions & 35 deletions main.py
@@ -23,79 +23,179 @@
'''
Imports
'''
import os
import sys
import torch
import argparse

import torch.quantization as quantization

from tqdm import tqdm
from vprtempo.VPRTempo import VPRTempo, run_inference
from vprtempo.VPRTempoTrain import VPRTempoTrain, train_new_model
from vprtempo.src.loggers import model_logger, model_logger_quant
from vprtempo.VPRTempoQuant import VPRTempoQuant, run_inference_quant
from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, generate_model_name_quant, train_new_model_quant
from vprtempo.VPRTempoTrain import VPRTempoTrain, generate_model_name, check_pretrained_model, train_new_model
from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, train_new_model_quant

def generate_model_name(model, quant=False):
    """
    Generate the model name based on its parameters.
    """
    if quant:
        model_name = (''.join(model.database_dirs) + "_" +
                      "VPRTempoQuant_" +
                      "IN" + str(model.input) + "_" +
                      "FN" + str(model.feature) + "_" +
                      "DB" + str(model.database_places) +
                      ".pth")
    else:
        model_name = (''.join(model.database_dirs) + "_" +
                      "VPRTempo_" +
                      "IN" + str(model.input) + "_" +
                      "FN" + str(model.feature) + "_" +
                      "DB" + str(model.database_places) +
                      ".pth")
    return model_name
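
# Illustrative result (hypothetical dims): a model trained on the 'spring' and 'fall'
# traverses with input=3136, feature=6272, and database_places=500 would be named
# "springfall_VPRTempo_IN3136_FN6272_DB500.pth".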

def check_pretrained_model(model_name):
    """
    Check if a pre-trained model exists and prompt the user to retrain if desired.
    """
    if os.path.exists(os.path.join('./vprtempo/models', model_name)):
        prompt = "A network with these parameters exists, re-train network? (y/n):\n"
        retrain = input(prompt).strip().lower()
        if retrain == 'y':
            return True
        elif retrain == 'n':
            print('Training new model cancelled')
            sys.exit()
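    # Note: any other answer falls through and returns None; the callers below do not
    # check the return value, so training simply proceeds in that case.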

def initialize_and_run_model(args, dims):
    """
    Run the VPRTempo/VPRTempoQuant training or inference models.

    :param args: Arguments set for the network
    :param dims: Dimensions of the network
    """
    # Determine number of modules to generate based on user input
    places = args.database_places # Copy out number of database places

    # Calculate number of modules
    num_modules = 1
    while places > args.max_module:
        places -= args.max_module
        num_modules += 1

    # If the final module has fewer than max_module places, reduce the dim of its output layer
    remainder = args.database_places % args.max_module
    if remainder != 0: # There are remainders, adjust output neuron count in final module
        out_dim = int((args.database_places - remainder) / (num_modules - 1))
        final_out_dim = remainder
    else: # No remainders, all modules are even
        out_dim = int(args.database_places / num_modules)
        final_out_dim = out_dim
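
    # Worked example (illustrative): with database_places=1250 and max_module=500,
    # the loop above yields num_modules=3; remainder=250, so the first two modules
    # use out_dim=(1250-250)/2=500 and the final module uses final_out_dim=250.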

    # If user wants to train a new network
    if args.train_new_model:
        # If using quantization aware training
        if args.quantize:
            models = []
            logger = model_logger_quant()
            # Get the quantization config
            logger = model_logger_quant() # Initialize the logger
            qconfig = quantization.get_default_qat_qconfig('fbgemm')
            for _ in range(args.num_modules):
                # Initialize the model
                model = VPRTempoQuantTrain(args, dims, logger)
            # Create the modules
            final_out = None
            for mod in tqdm(range(num_modules), desc="Initializing modules"):
                model = VPRTempoQuantTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
                model.train()
                model.qconfig = qconfig
                models.append(model)
                quantization.prepare_qat(model, inplace=True)
                models.append(model) # Create module list
                if mod == num_modules - 2:
                    final_out = final_out_dim
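                # Note: setting final_out on the second-to-last iteration means only
                # the last module is constructed with the remainder output dimension.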
            # Generate the model name
            model_name = generate_model_name_quant(model)
            model_name = generate_model_name(model, args.quantize)
            # Check if the model has been trained before
            check_pretrained_model(model_name)
            # Get the quantization config
            qconfig = quantization.get_default_qat_qconfig('fbgemm')
            # Train the model
            train_new_model_quant(models, model_name, qconfig)
        else: # Normal model
            train_new_model_quant(models, model_name)

        # Base model
        else:
            models = []
            logger = model_logger()
            for _ in range(args.num_modules):
                # Initialize the model
                model = VPRTempoTrain(args, dims, logger)
                models.append(model)
            logger = model_logger() # Initialize the logger

            # Create the modules
            final_out = None
            for mod in tqdm(range(num_modules), desc="Initializing modules"):
                model = VPRTempoTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
                model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models)
                models.append(model) # Create module list
                if mod == num_modules - 2:
                    final_out = final_out_dim

            # Generate the model name
            model_name = generate_model_name(model)
            print(f"Model name: {model_name}")
            # Check if the model has been trained before
            check_pretrained_model(model_name)
            # Train the model
            train_new_model(models, model_name)

    # Run the inference network
    else:
        # Set the quantization configuration
        if args.quantize:
            models = []
            logger = model_logger_quant()
            logger, output_folder = model_logger_quant()
            qconfig = quantization.get_default_qat_qconfig('fbgemm')
            for _ in range(args.num_modules):
            final_out = None
            for _ in tqdm(range(num_modules), desc="Initializing modules"):
                # Initialize the model
                model = VPRTempoQuant(dims, args, logger)
                model = VPRTempoQuant(
                    args,
                    dims,
                    logger,
                    num_modules,
                    output_folder,
                    out_dim,
                    out_dim_remainder=final_out
                )
                model.eval()
                model.qconfig = qconfig
                model = quantization.prepare(model, inplace=False)
                model = quantization.convert(model, inplace=False)
                quantization.prepare(model, inplace=True)
                quantization.convert(model, inplace=True)
                models.append(model)
            # Generate the model name
            model_name = generate_model_name_quant(model)
            model_name = generate_model_name(model, args.quantize)
            # Run the quantized inference model
            run_inference_quant(models, model_name, qconfig)
            run_inference_quant(models, model_name)
        else:
            models = []
            logger = model_logger()
            for _ in range(args.num_modules):
                # Initialize the model
                model = VPRTempo(dims, args, logger)
                models.append(model)
            logger, output_folder = model_logger() # Initialize the logger
            places = args.database_places # Copy out number of database places

            # Create the modules
            final_out = None
            for mod in tqdm(range(num_modules), desc="Initializing modules"):
                model = VPRTempo(
                    args,
                    dims,
                    logger,
                    num_modules,
                    output_folder,
                    out_dim,
                    out_dim_remainder=final_out
                )
                model.eval()
                model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models)
                models.append(model) # Create module list
                if mod == num_modules - 2:
                    final_out = final_out_dim
            # Generate the model name
            model_name = generate_model_name(model)
            print(f"Model name: {model_name}")
            # Run the inference model
            run_inference(models, model_name)
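
A minimal driver sketch for context (hedged: assuming `parse_network`, shown below, returns the parsed `args` and the `dims` list used above):

```python
# Hypothetical entry point mirroring how main.py wires things together
if __name__ == "__main__":
    args, dims = parse_network()          # parse the CLI flags defined below
    initialize_and_run_model(args, dims)  # train or infer based on --train_new_model
```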

@@ -110,19 +210,23 @@ def parse_network(use_quantize=False, train_new_model=False):
help="Dataset to use for training and/or inferencing")
parser.add_argument('--data_dir', type=str, default='./vprtempo/dataset/',
help="Directory where dataset files are stored")
parser.add_argument('--num_places', type=int, default=500,
help="Number of places to use for training and/or inferencing")
parser.add_argument('--num_modules', type=int, default=1,
help="Number of expert modules to use split images into")
parser.add_argument('--database_places', type=int, default=500,
help="Number of places to use for training")
parser.add_argument('--query_places', type=int, default=500,
help="Number of places to use for inferencing")
parser.add_argument('--max_module', type=int, default=500,
help="Maximum number of images per module")
parser.add_argument('--database_dirs', nargs='+', default=['spring', 'fall'],
parser.add_argument('--database_dirs', type=str, default='spring, fall',
help="Directories to use for training")
parser.add_argument('--query_dir', nargs='+', default=['summer'],
parser.add_argument('--query_dir', type=str, default='summer',
help="Directories to use for testing")
parser.add_argument('--shuffle', action='store_true',
help="Shuffle input images during query")
parser.add_argument('--GT_tolerance', type=int, default=1,
help="Ground truth tolerance for matching")

# Define training parameters
parser.add_argument('--filter', type=int, default=8,
parser.add_argument('--filter', type=int, default=1,
help="Images to skip for training and/or inferencing")
parser.add_argument('--epoch', type=int, default=4,
help="Number of epochs to train the model")
@@ -139,6 +243,12 @@ def parse_network(use_quantize=False, train_new_model=False):
    parser.add_argument('--quantize', action='store_true',
                        help="Enable/disable quantization for the model")

    # Define metrics functionality
    parser.add_argument('--PR_curve', action='store_true',
                        help="Flag to generate a Precision-Recall curve")
    parser.add_argument('--sim_mat', action='store_true',
                        help="Flag to plot the similarity matrix, GT, and GTsoft")

    # If the function is called with specific arguments, override sys.argv
    if use_quantize or train_new_model:
        sys.argv = ['']
2 changes: 1 addition & 1 deletion setup.py
@@ -21,7 +21,7 @@
# define the setup
setup(
name="VPRTempo",
version="1.1.4",
version="1.1.5",
description='VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition',
long_description=long_description,
long_description_content_type='text/markdown',
Binary file removed tutorials/.DS_Store
Binary file not shown.
Binary file removed tutorials/mats/.DS_Store
Binary file not shown.