Online/Stream Training Design #49
The PyTorch design looks good to me. It is a little difficult to tell whether we will incur performance penalties when running at scale; I guess we will have to run some tests on real-world examples, or at least on some SOTA topology and dataset (e.g. we could populate the ImageNet dataset while some large ResNet instance is being trained). My thoughts about the open questions:
Here is what the streaming approach looks like in Keras/TF:

import numpy as np
from os import environ

from smartredis import Client
from tensorflow import keras


class SmartSimDataGenerator(keras.utils.Sequence):
    def __init__(self, batch_size=32,
                 n_classes=1000, shuffle=True,
                 producer_prefix="", sample_id="batch_",
                 label_id="labels_"):
        self.sample_id = sample_id
        self.label_id = label_id
        self.client = Client(None, False)
        self.next_index = {}
        for entity_name in environ["SSKEYIN"].split(','):
            print(entity_name)
            if entity_name.startswith(producer_prefix):
                self.next_index[entity_name] = 0
        self.samples = None
        self.labels = None
        self.indices = None
        self.n_classes = n_classes
        self.shuffle = shuffle
        while self.samples is None:
            self.on_epoch_end()
        self.batch_size = batch_size

    def __len__(self):
        return int(np.floor(len(self.samples) / self.batch_size))

    def __getitem__(self, index):
        # Generate indices of the batch
        indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]
        # Generate data
        x, y = self.__data_generation(indices)
        return x, y

    def on_epoch_end(self):
        for entity in self.next_index:
            self.client.set_data_source(entity)
            batch_name = self.sample_id + str(self.next_index[entity])
            label_name = self.label_id + str(self.next_index[entity])
            print(f"Retrieving {batch_name} and {label_name} from {entity}...")
            while self.client.tensor_exists(batch_name) and self.client.tensor_exists(label_name):
                if self.samples is None:
                    self.samples = self.client.get_tensor(batch_name)
                    self.labels = self.client.get_tensor(label_name)
                    self.dim = self.samples.shape[1:]
                else:
                    self.samples = np.concatenate((self.samples, self.client.get_tensor(batch_name)))
                    self.labels = np.concatenate((self.labels, self.client.get_tensor(label_name)))
                print("Success!")
                self.next_index[entity] += 1
                batch_name = self.sample_id + str(self.next_index[entity])
                label_name = self.label_id + str(self.next_index[entity])
                print(f"Retrieving {batch_name} and {label_name}...")
        self.indices = np.arange(start=0, stop=len(self.samples))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __data_generation(self, indices):
        x = self.samples[indices]
        y = self.labels[indices]
        return x, keras.utils.to_categorical(y, num_classes=self.n_classes)

At the end of each epoch, all available batches are added to the samples.
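For reference, a Sequence like this can be passed directly to Keras's fit loop; a minimal usage sketch with a hypothetical toy model and producer prefix (the layer shapes, prefix, and hyperparameters are placeholders, not part of the design above):

from tensorflow import keras

# Hypothetical toy model; the real input shape depends on what the producer streams.
model = keras.Sequential([
    keras.layers.Flatten(),
    keras.layers.Dense(1000, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy")

generator = SmartSimDataGenerator(batch_size=32, n_classes=1000,
                                  producer_prefix="producer")

# keras.utils.Sequence instances are accepted directly by fit();
# on_epoch_end() pulls any newly streamed batches between epochs.
model.fit(generator, epochs=10)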
Now: (1) is the simplest, but it requires N calls per producer and is fairly static (N and M are known at the beginning and cannot change); (2) requires an intermediate step; (3) requires some initial setup. (2) and (3) could be adapted at runtime if M or N changes, with more work needed for case (2) (regroup, split) and less for (3) (change the mapping and keep training, though this means each worker needs to know the mapping every time it downloads a batch). (2) is something a separate thread could facilitate: it could run on any node, or possibly as a Redis module. I think we could start by including the data loaders for Torch and TF, as we have designed them, in the experimental features, and then see how to proceed from there.
Description
SmartSim was designed in large part so that multiple processes, written in any of the supported SmartRedis languages, can share data at runtime. One specific use case is training ML models that consume data from a simulation or distributed HPC workload.
Justification
Simulation output can be very large in total; however, the data needed for training is often much smaller than what is actually saved to file. Often, only a few fields need to be pulled at each timestep (or every few timesteps) to train a model. In addition, simulations used for training often need to be run multiple times, either to produce the needed output data or to sweep ranges of free parameters.
SmartSim can be used to execute the simulation(s), and SmartRedis can be embedded into them so that only the data points needed for training are streamed to Redis, where they are consumed by a training process.
The entire pipeline looks like:
Simulation(s) ----> Redis server(s) ----> training process
Optionally, once a model has been trained for a sufficient number of epochs/rounds/timesteps, it can be serialized and sent to Redis to be consumed either by the original data producer or by a third application.
Simulation(s) <====> Redis Server(s) <====> training process
The second workflow opens the door to dynamic model updates based on new simulation data or conditions. This needs more research, but the point of this issue is largely to gather feedback and determine functionality needed to support the first use case.
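For that return path, one option is to serialize the trained model with TorchScript and push it to the database with SmartRedis's set_model; a rough sketch, with a hypothetical stand-in model in place of the one produced by the training process described below:

import io

import torch
from torch import nn
from smartredis import Client

client = Client(None, False)

# Stand-in for the trained model (hypothetical); in practice this would be
# the model produced by the online training process.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

# Serialize with TorchScript and send the bytes to the Orchestrator, where
# the simulation (or a third application) can execute it via run_model.
scripted = torch.jit.script(model)
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
client.set_model("trained_model", buffer.getvalue(), "TORCH", device="CPU")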
Implementation Strategy
The only pieces that need to be created in order to leverage SmartSim for online/stream training are a dataset and a dataloader for the training process, which use the SmartRedis client to pull samples/batches for training.
A sample implementation with PyTorch is provided as the basis for discussion.
The dataset follows the PyTorch IterableDataset convention. Every timestep, the producer application (not shown here) uses SmartRedis to stream a set of batches and targets (assuming supervised learning) to the Orchestrator (Redis), where they can be consumed by a training process. The dataset waits for those indices to be populated before consuming them for training. The dataset is fairly flexible: in the case where the simulation domain has been decomposed into multiple keys and the entire domain is needed for training, domain-reconstruction strategies can be inserted into the __iter__ function.

The next piece is the dataloader that consumes the StreamRedisDataset. The dataloader can use multiple workers (multiple SmartRedis clients) to consume data in parallel. We initialize the SmartRedis client within worker_init_fn because the client cannot be pickled, and PyTorch uses Python multiprocessing, which requires all objects passed to a worker process to be pickle-able. We use persistent workers so that the clients we initialized are not destroyed at the end of each epoch, mostly because client initialization is expensive. Note that persistent workers are only important if there is a need to perform multiple epochs on each set of batches being iterated on (more on this later).
Given the above dataloader and dataset, a sample training process performs multiple epochs on multiple batches for multiple "rounds". The concept of training rounds allows the training process to perform multiple epochs on each set of batches streamed from the data producer: the rounds are used to set and increment the indices of the batches to be consumed for many epochs. One could just as easily drop the concept of rounds and use epochs in a similar fashion if only one epoch per set of batches is wanted, which would be closer to a scikit-learn partial_fit approach.

The example does not use shuffling, because we want the indices to be consumed in order, and it does not use batching, because the producer in this case already sends the samples and targets in batches (batching could easily be turned on). Lastly, we delete the dataset at the end of each round so that the worker processes are killed and the clients are disconnected.
This particular example waits for 100 batches to be populated for each round of training, so the process looks like: training starts, the dataset blocks until the next 100 batches are available, the model trains on them for the configured number of epochs, and the round index then advances to the next 100 batches.
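Putting the pieces together, here is a rough sketch of that round/epoch structure; the model, optimizer, and hyperparameters (100 batches per round, 5 epochs per round, 2 workers) are hypothetical placeholders, and only the structure mirrors the description above:

import torch
from torch import nn
from torch.utils.data import DataLoader

# Hypothetical model, optimizer, and hyperparameters; the round/epoch
# structure is the point of this sketch, not the model itself.
model = nn.Sequential(nn.Flatten(), nn.LazyLinear(1000))
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

batches_per_round = 100
epochs_per_round = 5
num_rounds = 10

for round_idx in range(num_rounds):
    # Each round consumes the next 100 batches streamed by the producer.
    start = round_idx * batches_per_round
    dataset = StreamRedisDataset(list(range(start, start + batches_per_round)))
    loader = DataLoader(dataset,
                        batch_size=None,           # producer already batches
                        num_workers=2,
                        worker_init_fn=worker_init_fn,
                        persistent_workers=True)   # keep clients alive across epochs

    for epoch in range(epochs_per_round):
        for x, y in loader:
            optimizer.zero_grad()
            # Assumes integer class labels streamed alongside the samples.
            loss = loss_fn(model(x.float()), y.long())
            loss.backward()
            optimizer.step()

    # Deleting the loader and dataset at the end of the round kills the
    # persistent workers and disconnects their SmartRedis clients.
    del loader, dataset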
Database behaviors for training
One other piece of note is that by using this (or a similar) approach with SmartSim for online training, users can train on datasets of theoretically infinite size being produced by a simulation. The reason is that users can control how keys behave in Redis: for example, the producer can cycle through a fixed pool of key names, e.g. key = "name_" + str(i % 5), so that older batches are overwritten and the database stays bounded in size.
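As a purely hypothetical illustration of that cyclic-key idea on the producer side (random arrays stand in for simulation output, and put_tensor is the SmartRedis write call):

import numpy as np
from smartredis import Client

client = Client(None, False)

# Only five batch/label slots ever exist in the database, so memory stays
# bounded even though the stream of training data is effectively infinite.
# A consumer would then re-read these slot names rather than ever-growing indices.
for i in range(1_000_000):
    x = np.random.rand(32, 28, 28)                      # stand-in simulation fields
    y = np.random.randint(0, 10, size=32).astype(np.float32)
    slot = i % 5
    client.put_tensor(f"batch_{slot}", x)
    client.put_tensor(f"labels_{slot}", y)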
Overall, this is a decent sketch of how we are doing online training right now, but I think it could be improved. I would appreciate feedback on how users think these dataloaders/datasets should be implemented.
Open Questions
Should we cache batches on the training server that are pulled for each round?
How should the dataloaders/datasets be included in the project? For example, a smartsim.data subpackage could hold the dataloaders for each ML framework we support, i.e. smartsim.data.torch and smartsim.data.tf.
TODO
(more will be added to this issue soon)