Keras: using a generator sequence in a limited-resources environment


Due to resource constraints in my environment, I tried to implement a MixedImageDataGenerator relying on underlying ImageDataGenerator instances, in order to generate images and labels at training time.

The idea was to generate all the images with a per-epoch limit on each underlying generator, refreshing the limits between epochs (a usage sketch follows the class code below).

I'm running the following versions: Keras 2.3.1, TensorFlow 2.1.0.

Here is the code of the MixedImageDataGenerator:

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import threading



class MixedImageDataGenerator(object):
    """
    Wrapper class that mixes data from multiple ImageDataGenerator instances.

    Args:
        generators (list): List of ImageDataGenerator objects.
        limits (list): List of integers representing the number of batches
                        to be generated from each generator before switching.
        generator_type (int): type of generator: 0 = train, 1 = validate, 2 = test

    Attributes:
        generators (list): List of ImageDataGenerator objects.
        limits (list): List of integers representing batch limits.
        current_generator_index (int): Index of the currently active generator.
        remaining_batches (list): List of integers representing remaining
                                  batches per generator.
        data_queue (deque): Optional queue for generated batches (currently disabled).
    """
    TYPE = ['train', 'validate', 'test']
    TRAIN = 0
    VALIDATE = 1
    TEST = 2
    
    def __init__(self, generators, limits, generator_type, debug=False):
        self.lock = threading.Lock()
        self.generators = generators
        self.limits = limits
        self.current_generator_index = 0
        self.remaining_batches = limits.copy()  # Per-epoch batch budget for each generator
        self.debug = debug
        if self.debug:
            print(f'Create Generator: {MixedImageDataGenerator.TYPE[generator_type]}')
        self.generator_type = generator_type

    def get_step_number(self):
        return sum(self.limits)

    def get_remaining_batches(self):
        return sum(self.remaining_batches)

    def getType(self):
        return self.generator_type

    def getGeneratorNumber(self):
        return len(self.generators)
    
    def __iter__(self):
        """
        Makes the MixedImageDataGenerator object iterable,
        allowing it to be used in a for loop.

        Returns:
            self: The MixedImageDataGenerator object itself.
        """ 
        # If the budget is already exhausted, hand Keras the self-refreshing
        # repeat() generator so it can keep pulling batches across epochs.
        if self.get_remaining_batches() == 0:
            return self.repeat()
        else:
            return self


    def __next__(self):
        """
        Returns the next batch of data from the mixed generators.

        Raises:
            StopIteration: If all limits have been exhausted.
        """
        if self.debug:
            print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}]: next called [pre lock]')
        with self.lock:
            print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}]: next call ...')
            while True:
                if sum(self.remaining_batches) <= 0:
                    print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}] all limits exhausted')
                    raise StopIteration
                
                data = None
                while data is None:
                    # Check if current generator has remaining batches
                    if self.remaining_batches[self.current_generator_index] > 0:
                        # Pull the next batch from the currently active generator
                        data = next(self.generators[self.current_generator_index])
                        # Decrease the remaining batches of the current generator
                        self.remaining_batches[self.current_generator_index] -= 1
                        if self.debug:
                            print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}]Remaining batch:  {self.remaining_batches} limits: {self.limits}')

                        
                    # Move to the next generator
                    self.current_generator_index = (self.current_generator_index + 1) % len(self.generators)
                    # If all generators have exhausted their limits, raise StopIteration
                    if all(remaining_batch <= 0 for remaining_batch in self.remaining_batches):
                        if self.debug:
                            print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}] all step data generated')
                        raise StopIteration
                if self.debug:
                    print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}]: RETURN {type(data)}')
                print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}]: next called')
                if data is None:  # Defensive check; the loop above should never let this happen
                    print('EMPTY================')

                return data

                
    def refresh(self, argument=None):
        with self.lock:
            if self.debug:
                print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}] refresh')
            self.current_generator_index = 0
            self.remaining_batches = self.limits.copy()  # Reset the per-epoch budgets

    def repeat(self):
        print('repeat called')
        while True:
            try:
                yield next(self)
            except StopIteration:
                self.refresh()
```
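
For context, here is roughly how the wrapper is meant to be driven. This is a minimal sketch: the directory paths, image size and batch size are placeholders, not my real setup; the underlying iterators come from `flow_from_directory`:

```python
# Minimal usage sketch: paths and parameters below are placeholders.
from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(rescale=1. / 255)

# Three underlying iterators, one per data source (hypothetical paths).
iterators = [
    datagen.flow_from_directory(path,
                                target_size=(224, 224),
                                batch_size=16,
                                class_mode='categorical')
    for path in ['data/part_a', 'data/part_b', 'data/part_c']
]

# Cap each iterator at 80 batches per epoch, as in the trace below.
mixed_train = MixedImageDataGenerator(iterators,
                                      limits=[80, 80, 80],
                                      generator_type=MixedImageDataGenerator.TRAIN)
```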

The following callback should take care of refreshing the limits between epochs:


""" Custom callback to log epoch and step information, and refresh the generator. """ class EpochStepLogger(tf.keras.callbacks.Callback): def init(self, generator): """ Args: generator: The data generator object to be refreshed. """ self.generator = generator self.current_epoch = 0 self.current_step = 0

def on_epoch_begin(self, epoch, logs=None):
    """
    Called at the beginning of each epoch.

    Args:
      epoch: The current epoch number.
      logs: Dictionary of logs accumulated during the previous epoch.
    """
    self.current_epoch = epoch
    self.current_step = 0
    self.log(f"Epoch {epoch} started!")
    self.refresh_generator()

    

def on_batch_end(self, batch, logs=None):
    """
    Called at the end of each batch.

    Args:
      batch: The index of the current batch.
      logs: Dictionary of logs accumulated at the end of the current batch.
    """
    self.current_step += 1
    self.log(f"Epoch {self.current_epoch}, Step {self.current_step} completed.")
    # in case where there is more than one call to next() per step
    # if self.generator:
    #   if self.generator.get_remaining_batches() < 1:
    #     self.refresh_generator()
    

def on_train_end(self, logs=None):
    """
    Called at the end of training.

    Args:
      logs: Dictionary of logs accumulated at the end of training.
    """
    self.log("Training completed!")

def on_train_begin(self, logs=None):
    """
    Called at the begining of training.

    Args:
      logs: Dictionary of logs accumulated at the end of training.
    """
    self.refresh_generator()
    self.log("Training started!")    

def refresh_generator(self):
    """
    Refreshes the data generator.
    """
    if self.generator:
        self.generator.refresh()  # Assuming your generator has a refresh method

def log(self, message):
    """
    Logs a message with epoch and step information.
    """
    print(f"[Epoch {self.current_epoch}, Step {self.current_step}] {message}")
    
def on_test_begin(self,logs=None):
    print(f'[ON_TEST_BEGIN]')

def on_test_end(self,logs=None):
    print(f'[ON_TEST_END]')    ```

This is the code to train the model:


"""
Custom callback to log epoch and step information, and refresh the generator.
"""
class EpochStepLogger(tf.keras.callbacks.Callback):
    def __init__(self, generator):
        """
        Args:
          generator: The data generator object to be refreshed.
        """
        self.generator = generator
        self.current_epoch = 0
        self.current_step = 0

    def on_epoch_begin(self, epoch, logs=None):
        """
        Called at the beginning of each epoch.

        Args:
          epoch: The current epoch number.
          logs: Dictionary of logs accumulated during the previous epoch.
        """
        self.current_epoch = epoch
        self.current_step = 0
        self.log(f"Epoch {epoch} started!")
        self.refresh_generator()

        

    def on_batch_end(self, batch, logs=None):
        """
        Called at the end of each batch.

        Args:
          batch: The index of the current batch.
          logs: Dictionary of logs accumulated at the end of the current batch.
        """
        self.current_step += 1
        self.log(f"Epoch {self.current_epoch}, Step {self.current_step} completed.")
        # in case where there is more than one call to next() per step
        # if self.generator:
        #   if self.generator.get_remaining_batches() < 1:
        #     self.refresh_generator()
        

    def on_train_end(self, logs=None):
        """
        Called at the end of training.

        Args:
          logs: Dictionary of logs accumulated at the end of training.
        """
        self.log("Training completed!")
    
    def on_train_begin(self, logs=None):
        """
        Called at the beginning of training.

        Args:
          logs: Dictionary of logs accumulated at the start of training.
        """
        self.refresh_generator()
        self.log("Training started!")    

    def refresh_generator(self):
        """
        Refreshes the data generator.
        """
        if self.generator:
            self.generator.refresh()  # Assuming your generator has a refresh method

    def log(self, message):
        """
        Logs a message with epoch and step information.
        """
        print(f"[Epoch {self.current_epoch}, Step {self.current_step}] {message}")
        
    def on_test_begin(self, logs=None):
        print('[ON_TEST_BEGIN]')

    def on_test_end(self, logs=None):
        print('[ON_TEST_END]')
```
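
This is how the model is trained. Only `batch_size=model_batch_size` and `verbose=2` are visible verbatim in the traceback below, and `epochs=7` plus the 240 steps per epoch match the trace, so treat the rest of this call as a hedged sketch: `model`, `mixed_train`, `mixed_val` and `model_batch_size` are placeholder names, and the checkpoint and early-stopping callbacks that appear in the trace are omitted.

```python
# Hedged reconstruction of the training call; only batch_size and verbose
# are visible in the traceback, everything else is a placeholder.
epoch_logger = EpochStepLogger(mixed_train)

history = model.fit(mixed_train,
                    steps_per_epoch=mixed_train.get_step_number(),  # 80 + 80 + 80 = 240
                    validation_data=mixed_val,
                    validation_steps=mixed_val.get_step_number(),
                    epochs=7,                    # matches "Epoch 2/7" in the trace
                    callbacks=[epoch_logger],    # plus checkpoint / early stopping
                    batch_size=model_batch_size,
                    verbose=2)
```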

With 3 underlying generators inside my MixedImageDataGenerator and limits of [80, 80, 80], I got the following trace:

```
[Epoch 0, Step 223] Epoch 0, Step 223 completed.
[train]: next called [pre lock]
[train]: next call ...
[train]Remaining batch:  [0, 0, 0] limits: [80, 80, 80]
[train] all step data generated
[Epoch 0, Step 224] Epoch 0, Step 224 completed.
[Epoch 0, Step 225] Epoch 0, Step 225 completed.
[Epoch 0, Step 226] Epoch 0, Step 226 completed.
[Epoch 0, Step 227] Epoch 0, Step 227 completed.
[Epoch 0, Step 228] Epoch 0, Step 228 completed.
[Epoch 0, Step 229] Epoch 0, Step 229 completed.
[Epoch 0, Step 230] Epoch 0, Step 230 completed.
[Epoch 0, Step 231] Epoch 0, Step 231 completed.
[Epoch 0, Step 232] Epoch 0, Step 232 completed.
[Epoch 0, Step 233] Epoch 0, Step 233 completed.
[Epoch 0, Step 234] Epoch 0, Step 234 completed.
[Epoch 0, Step 235] Epoch 0, Step 235 completed.
[Epoch 0, Step 236] Epoch 0, Step 236 completed.
[Epoch 0, Step 237] Epoch 0, Step 237 completed.
[Epoch 0, Step 238] Epoch 0, Step 238 completed.
[Epoch 0, Step 239] Epoch 0, Step 239 completed.
[Epoch 0, Step 240] Epoch 0, Step 240 completed.
[ON_TEST_BEGIN]
[ON_TEST_END]

Epoch 00001: val_loss improved from inf to 1.20102, saving model to E__chest_ray_VGG_F_2XDO_512_best.h5
240/240 - 209s - loss: 404.7601 - accuracy: 0.3000 - val_loss: 1.2010 - val_accuracy: 0.4000
[Epoch 1, Step 0] Epoch 1 started!
[train] refresh
Epoch 2/7
WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 1680 batches). You may need to use the repeat() function when building your dataset.
WARNING:tensorflow:Can save best model only with val_loss available, skipping.
WARNING:tensorflow:Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: 
[Epoch 1, Step 0] Training completed!
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-77-3cc9dfbc0a95> in <module>
     78                        #class_weight=class_weight,
     79                        batch_size=model_batch_size,
---> 80                        verbose=2)  
     81 # End timing
     82 end_time = time.time()

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    817         max_queue_size=max_queue_size,
    818         workers=workers,
--> 819         use_multiprocessing=use_multiprocessing)
    820 
    821   def evaluate(self,

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    340                 mode=ModeKeys.TRAIN,
    341                 training_context=training_context,
--> 342                 total_epochs=epochs)
    343             cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
    344 

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in run_one_epoch(model, iterator, execution_function, dataset_size, batch_size, strategy, steps_per_epoch, num_samples, mode, training_context, total_epochs)
    185 
    186   # End of an epoch.
--> 187   aggregator.finalize()
    188   return aggregator.results
    189 

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py in finalize(self)
    142   def finalize(self):
    143     if not self.results:
--> 144       raise ValueError('Empty training data.')
    145     self.results[0] /= (self.num_samples or self.steps)
    146 

ValueError: Empty training data.

```

Specifying steps_per_epoch is also difficult, because I can't predict what the fit method does under the hood.
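
For reference, the arithmetic behind the out-of-data warning (a sanity check of the numbers in the trace, not a fix):

```python
limits = [80, 80, 80]
steps_per_epoch = sum(limits)              # 240, matches the "240/240" line in the trace
epochs = 7
total_batches = steps_per_epoch * epochs   # 1680, matches the out-of-data warning
```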
I'm stuck, does someone have an idea?

david 