
Online/Stream Training Design #49

Closed
Spartee opened this issue May 26, 2021 · 2 comments

@Spartee
Contributor

Spartee commented May 26, 2021

Description

SmartSim was designed, in large part, so that multiple processes in the supported SmartRedis languages can share data at runtime. One specific use case is training ML models that consume data from a simulation or distributed HPC workload.

Justification

Simulation output in total can be very large; however, the data needed for training is often much smaller than what is actually saved to file. Oftentimes, only a few fields need to be pulled at each timestep (or every few timesteps) to train a model. In addition, simulations used for training often need to be run multiple times, either to produce the needed output data or to sweep across ranges of free parameters.

SmartSim can be used to execute the simulation(s), and SmartRedis can be embedded into them to stream only the data points needed for training to Redis, where they are consumed by a training process.

The entire pipeline looks like:

Simulation(s) ----> Redis server(s) ----> training process
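
As a rough sketch, such a pipeline could be launched with a SmartSim Experiment along the lines of the following. This assumes a recent Experiment API (create_database, create_run_settings, create_model); the executables, names, and arguments are placeholders, not a definitive recipe:

from smartsim import Experiment

exp = Experiment("online-training", launcher="local")

# the Orchestrator (Redis) shared by the simulation and the training process
db = exp.create_database(port=6379)

# placeholder executables for the data producer and the training process
sim_settings = exp.create_run_settings(exe="./simulation")
train_settings = exp.create_run_settings(exe="python", exe_args="train.py")

simulation = exp.create_model("simulation", sim_settings)
trainer = exp.create_model("trainer", train_settings)

exp.start(db)
exp.start(simulation, trainer, block=True)
exp.stop(db)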

Optionally, once trained for a sufficient number of epochs/rounds/timesteps, the trained models can be serialized and sent to Redis to be consumed either by the original data producer or by a third application.

Simulation(s) <====> Redis Server(s) <====> training process
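
As a rough illustration of the return path, a trained model could be serialized to TorchScript and pushed to the Orchestrator with the SmartRedis set_model call, after which a consumer could execute it with run_model. This is only a sketch: client is assumed to be a connected SmartRedis Client, and the model key, example input, and device are placeholders.

import io

import torch

def publish_model(model, example_input, client):
    # serialize the trained model to TorchScript and ship the bytes to Redis
    traced = torch.jit.trace(model, example_input)
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)
    client.set_model("trained_model", buffer.getvalue(), "TORCH", device="CPU")

# a consumer (the simulation or a third application) could then run inference with:
#   client.run_model("trained_model", inputs=["input_key"], outputs=["output_key"])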

The second workflow opens the door to dynamic model updates based on new simulation data or conditions. This needs more research, but the point of this issue is largely to gather feedback and determine functionality needed to support the first use case.

Implementation Strategy

The only pieces that need to be created in order to leverage SmartSim for online/stream training are a dataset and a dataloader for the training process that utilize the SmartRedis client to pull samples/batches for training.

A sample implementation with PyTorch is provided as the basis for discussion.

import torch
import torch.nn as nn
import torch.optim as opt
import torchvision.transforms as transforms
import math

from smartredis import Client


class SmartRedisStream(torch.utils.data.IterableDataset):

    def __init__(self,
                 start,
                 end,
                 host='127.0.0.1',
                 port=6379,
                 cluster=False,
                 sample_prefix="batch_",
                 target_prefix="targets_",
                 transform=None):
        
        super().__init__()
        # client is created in the DataLoader's worker_init_fn because it cannot be pickled
        self.client = None
        self.start = start
        self.end = end
        self.address = ":".join((host, str(port)))
        self.cluster = cluster
        self.sample_prefix = sample_prefix
        self.target_prefix = target_prefix
        self.transform = transform

    def __iter__(self):
        for index in range(self.start, self.end):

            batch_key = "".join((self.sample_prefix, str(index)))
            target_key = "".join((self.target_prefix, str(index)))

            # block until the producer has written this batch (poll every 5000 ms, up to 60 tries)
            key_exists = self.client.poll_key(batch_key, 5000, 60)
            if not key_exists:
                raise Exception("Timeout waiting for new data to train on")
 
            data = self.client.get_tensor(batch_key)
            target = self.client.get_tensor(target_key)
            data = torch.as_tensor(data, dtype=torch.double)
            target = torch.as_tensor(target, dtype=torch.double)
            
            yield data, target

The dataset above follows the PyTorch IterableDataset convention. Every timestep, the producer application (not shown here; a sketch follows below) will use SmartRedis to stream a set of batches and targets (assuming supervised learning) to the Orchestrator (Redis), where they can be consumed by a training process. The dataset will wait for those indices to be populated before consuming them for training.
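
For reference, a minimal, hypothetical producer-side counterpart could look like the following; run_one_timestep is a placeholder for the simulation step, and the key names mirror the defaults of SmartRedisStream above:

import numpy as np
from smartredis import Client

client = Client(address="127.0.0.1:6379", cluster=False)

for timestep in range(1000):
    # placeholder for one simulation step that produces a batch of samples and targets
    samples, targets = run_one_timestep()
    client.put_tensor(f"batch_{timestep}", np.asarray(samples))
    client.put_tensor(f"targets_{timestep}", np.asarray(targets))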

The dataset is fairly flexible. In the case where one has decomposed the simulation domain into multiple keys and needs the entire domain for training, domain reconstruction strategies can be inserted into the __iter__ function.
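
As an illustration, a hypothetical variant of __iter__ that reassembles a domain decomposed across ranks might look like this (it assumes the producer writes per-rank keys such as batch_<rank>_<index> and that an n_ranks argument is added to the constructor):

    def __iter__(self):
        for index in range(self.start, self.end):
            # gather the piece written by each rank and stitch the domain back together
            rank_tensors = []
            for rank in range(self.n_ranks):
                key = f"{self.sample_prefix}{rank}_{index}"
                if not self.client.poll_key(key, 5000, 60):
                    raise Exception(f"Timeout waiting for {key}")
                rank_tensors.append(torch.as_tensor(self.client.get_tensor(key), dtype=torch.double))
            data = torch.cat(rank_tensors)

            target = self.client.get_tensor("".join((self.target_prefix, str(index))))
            target = torch.as_tensor(target, dtype=torch.double)
            yield data, target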

The next piece is the dataloader that consumes the SmartRedisStream dataset.

class SmartRedisStreamer(torch.utils.data.DataLoader):
    
    def __init__(self, dataset, **kwargs):
        super().__init__(dataset,
                         worker_init_fn=self.worker_init_fn,
                         persistent_workers=True,
                         **kwargs)
    
    @staticmethod
    def worker_init_fn(worker_id):
        print("initing worker", flush=True)
        worker_info = torch.utils.data.get_worker_info()
        dataset = worker_info.dataset  # the dataset copy in this worker process
        overall_start = dataset.start
        overall_end = dataset.end

        # configure the dataset to only process the split workload
        per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        worker_id = worker_info.id
        dataset.start = overall_start + worker_id * per_worker
        dataset.end = min(dataset.start + per_worker, overall_end)

        # init client and pass to dataset object
        dataset.client = Client(address=dataset.address, cluster=dataset.cluster)

The dataloader can use multiple workers (multiple SmartRedis clients) to consume data in parallel. We initialize the SmartRedis client within the worker_init_fn because the SmartRedis client cannot be pickled, and PyTorch uses Python multiprocessing, which requires all objects passed to the worker processes to be "pickle-able".

We use persistent workers because we want the clients we initialized to not be destroyed at the end of each epoch, mostly because client initialization is expensive. Note that persistent workers are only important if there is a need to perform multiple epochs over each set of batches being iterated on. (more on this later)

Given the above dataloader and dataset, the following is a sample training process:

def fit(model, epochs, optim, loss, rounds):
    
    start = 0
    batches_per_round = 100
    end = batches_per_round

    for _round in range(rounds):
        print(f"Round {str(_round)}", flush=True)

        dataset = SmartRedisStream(start, end)
        loader = SmartRedisStreamer(
            dataset,
            batch_size=None,   # handling batching ourselves
            shuffle=False,         # has to be false so indices don't get messed up
            num_workers=3
        )

        for epoch in range(epochs):
            for data, target in loader:
                _loss = loss(model(data), target)
                _loss.backward()
                optim.step()
                optim.zero_grad()
            if epoch % 10 == 0:
                print(f"Epoch {epoch}, loss: {_loss.item()}", flush=True)

        del loader
        del dataset
        
        start = end
        end += batches_per_round

This training process performs multiple epochs on multiple batches for multiple "rounds". The concept of training rounds allows the training process to perform multiple epochs on each set of batches streamed from the data producer; the rounds are used to set and increment the indices of the batches consumed across those epochs. One could just as easily remove the concept of rounds and use epochs in a similar fashion to perform only one epoch per set of batches, which would be closer to a scikit-learn partial_fit approach.

The example doesn't use shuffling because we want the indices to be consumed in order, and we don't use batching because the producer, in this case, sends the samples and targets already batched. Batching could easily be turned on.

Lastly, we delete the dataset explicitly at the end of each round so that the worker processes are killed and the clients are disconnected.

This particular example waits for 100 batches to be populated for each round of training, so the training process looks like:

Training starts

  • Round 1 - batches 0-100
    • epoch 1 (consumes batches 0-100)
    • epoch 2 (consumes batches 0-100)
    • ...
    • epoch n (consumes batches 0-100)
  • Round 2 - batches 100-200
    • epoch 1 (consumes batches 100-200)
    • epoch 2 (consumes batches 100-200)
    • ...
    • epoch n (consumes batches 100-200)
  • Round 3...
  • ...
  • ...
  • Round n - batches m to m+100

Database behaviors for training

One other piece of note is that by using this (or a similar) approach with SmartSim for online training, users can train on datasets of theoretically infinite size being produced by a simulation. This is possible because of the ways in which users can change the behavior of Redis:

  1. Overwrite every key (potential dataset size: inf)
    • Users can create keys (to store tensor values at) at each timestep and overwrite existing memory locations (see the key-scheme sketch after this list).
    • The training process would grab whatever latest set of keys is available when the next "round" (or epoch) begins.
    • In this way, the database also gains back time that would otherwise be lost to memory allocations.
  2. Overwrite sets of keys (potential dataset size: inf)
    • Same approach as above, but overwriting only every 5 timesteps, i.e. key = "name_" + str(i % 5).
  3. No overwrite, no eviction (potential dataset size: compute node memory x database nodes)
    • Increment indices at each timestep and store all streamed fields in the database.
  4. No overwrite, LRU eviction (potential dataset size: inf)
    • Use the Orchestrator (Redis) the way it was originally designed to be used: as an LRU cache.
    • Could pull multiple timesteps or just the latest timestep.
    • Eviction adds a little overhead.
    • The least recently used keys (e.g. those already pulled by the training process) get evicted when the database fills up.
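
For concreteness, the producer-side key schemes behind these behaviors might look roughly like the following (the key names and redis.conf values are illustrative only):

timestep = 42

key_overwrite = "batch_0"                  # (1) always overwrite the same key
key_rolling   = f"batch_{timestep % 5}"    # (2) overwrite a small rotating set of keys
key_append    = f"batch_{timestep}"        # (3) a unique key per timestep, no overwrite

# (4) keeps scheme (3) but relies on Redis eviction instead of key reuse,
# e.g. with settings along these lines in redis.conf:
#   maxmemory 16gb
#   maxmemory-policy allkeys-lru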

Overall this is a decent sketch of how we are doing online training right now, but I think it could be improved. I would appreciate any feedback on how users think these dataloaders/datasets should be implemented.

Open Questions

  1. Should we cache batches on the training server that are pulled for each round?

    • Right now the clients pull the batches before each epoch. In most cases this will be fine, as it's faster than loading from file anyway; however, for some larger batches it might be better to cache them in memory local to the training process.
  2. How should the dataloaders/datasets be included in the project?

    • We could use smartsim.data to hold dataloaders for each ML framework we support. i.e. smartsim.data.torch and smartsim.data.tf.
    • Solely include them as examples in the documentation or examples folder
    • Separate repos with full examples?

TODO

  • Look into distributed dataloading
  • Investigate POC with SmartSim Ray for online RL
  • Look into making the SmartRedis client "pickleable"

(more will be added to this issue soon)

@Spartee added the type: design and type: feature labels on May 26, 2021
@al-rigazzi
Collaborator

The PyTorch design looks good to me. It is a little difficult to tell whether we will incur performance penalties when running at scale; I guess we will have to run some tests on real-world examples, or at least on some SOTA topology and dataset (e.g. we could populate the ImageNet dataset while some large ResNet instance is being trained).

My thoughts about the open questions:

  1. This might be tricky to tune. Only fetching the next batch will generate a lot of requests to the db nodes, but sometimes that will be necessary, because it might not be possible to cache an epoch's worth of data on the training nodes. I would not worry about this before we can get some performance results and some feedback from users.
  2. I am in favor of smartsim.data.<toolkit>. Having them only in examples would make them harder to maintain, while it might be a bit too soon to have separate repos (in the future, if we have more DL-related features, it will make sense to revisit this).

@al-rigazzi
Collaborator

al-rigazzi commented Oct 22, 2021

Here is what the streaming approach looks like in Keras/TF:

from os import environ

import numpy as np
from tensorflow import keras
from smartredis import Client


class SmartSimDataGenerator(keras.utils.Sequence):
    def __init__(self, batch_size=32,
                 n_classes=1000, shuffle=True,
                 producer_prefix="", sample_id="batch_",
                 label_id="labels_"):
        self.sample_id = sample_id
        self.label_id = label_id
        # address is taken from the SSDB environment variable set by SmartSim
        self.client = Client(None, False)
        self.next_index = {}
        for entity_name in environ["SSKEYIN"].split(','):
            print(entity_name)
            if entity_name.startswith(producer_prefix):
                self.next_index[entity_name] = 0
        self.samples = None
        self.labels = None
        self.indices = None
        self.n_classes = n_classes
        self.shuffle = shuffle
        # block until at least one batch has been streamed by a producer
        while self.samples is None:
            self.on_epoch_end()

        self.batch_size = batch_size

    def __len__(self):
        return int(np.floor(len(self.samples) / self.batch_size))

    def __getitem__(self, index):
        # Generate indexes of the batch
        indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]

        # Generate data
        x, y = self.__data_generation(indices)

        return x, y

    def on_epoch_end(self):

        for entity in self.next_index:
            self.client.set_data_source(entity)
            batch_name = self.sample_id + str(self.next_index[entity])
            label_name = self.label_id + str(self.next_index[entity])
            print(f"Retrieving {batch_name} and {label_name} from {entity}...")
            while self.client.tensor_exists(batch_name) and self.client.tensor_exists(label_name):
                if self.samples is None:
                    self.samples = self.client.get_tensor(batch_name)
                    self.labels = self.client.get_tensor(label_name)
                    self.dim = self.samples.shape[1:]
                else:
                    self.samples = np.concatenate((self.samples,self.client.get_tensor(batch_name)))
                    self.labels = np.concatenate((self.labels, self.client.get_tensor(label_name)))
                print("Success!")
                self.next_index[entity] += 1
                batch_name = self.sample_id + str(self.next_index[entity])
                label_name = self.label_id + str(self.next_index[entity])
                print(f"Retrieving {batch_name} and {label_name}...")
                
        self.indices = np.arange(start=0, stop=len(self.samples))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __data_generation(self, indices):
        
        # Initialization
        x = self.samples[indices]
        y = self.labels[indices]

        return x, keras.utils.to_categorical(y, num_classes=self.n_classes)

At the end of each epoch, all available batches are added to the samples.
Clearly, this only works in the case where there are samples and labels. Other cases would require different implementations. For distributed training, I have some suggestions, if we want to train each of the N workers on a separate slice of the data:

  1. each one of the M producers could put N batches (basically, the producer would chunk the data and upload it in different batches)
  2. or each producer could upload one single batch and then run a script which splits it into smaller batches. This could be called by Worker #0, in a Horovod setup, for example
  3. or we could build some mapping between producers and workers (still, we should be sure that all producers are statistically similar...)

Now: (1) is the simplest, but it requires N calls per producer and is pretty static (N and M are known at the beginning and cannot change); (2) requires an intermediate step; (3) requires some initial setup. (2) and (3) could be changed at runtime if M or N changes, with some work to do for case (2) (regroup, split) and less for (3) (change the mapping and keep training, though this means each worker needs to know the mapping every time it downloads a batch).
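
A minimal sketch of option (1), in which each of the M producers chunks its local data into N sub-batches, one per training worker; the key names and the put_chunked_batch helper are hypothetical:

import numpy as np
from smartredis import Client

NUM_WORKERS = 4                 # N, fixed at startup in this scheme
client = Client(None, False)    # address taken from the SSDB environment variable

def put_chunked_batch(samples, labels, timestep):
    sample_chunks = np.array_split(samples, NUM_WORKERS)
    label_chunks = np.array_split(labels, NUM_WORKERS)
    for worker_id, (x, y) in enumerate(zip(sample_chunks, label_chunks)):
        # each training worker only polls the keys carrying its own worker_id
        client.put_tensor(f"batch_w{worker_id}_{timestep}", x)
        client.put_tensor(f"labels_w{worker_id}_{timestep}", y)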

(2) would be something a separate thread could facilitate: it could run on any node, or as a Redis module, possibly.
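
One possible realization of option (2) would use the TorchScript script support already available through SmartRedis/RedisAI (set_function/run_script). This is only a sketch; the key names, the fixed split into two workers, and having worker 0 trigger the split are all assumptions:

from smartredis import Client

def split_batch(batch):
    # split a full producer batch into two halves, one per training worker
    half = batch.shape[0] // 2
    return batch[:half], batch[half:]

client = Client(None, False)
client.set_function("splitter", split_batch)

# e.g. worker 0 (in a Horovod setup) would trigger the split once per incoming batch
client.run_script("splitter", "split_batch",
                  ["batch_0"], ["batch_0_w0", "batch_0_w1"])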

I think that we could start by including into the experimental features the data loaders for Torch and TF as we have designed them, and then see how to proceed from there.
