Running into resource constraints in my environment, I tried to implement a MixedImageDataGenerator that wraps several underlying ImageDataGenerator instances, in order to generate images and labels at training time.
The idea is to draw batches from all the underlying generators, with a per-generator batch limit for each epoch, refreshing the limits between epochs.
I'm running the following versions: Keras 2.3.1, TensorFlow 2.1.0.
Here is the code of the MixedImageDataGenerator:

```
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import threading
class MixedImageDataGenerator(object):
"""
Wrapper class that mixes data from multiple ImageDataGenerator instances.
Args:
generators (list): List of ImageDataGenerator objects.
limits (list): List of integers representing the number of batches
to be generated from each generator before switching.
        generator_type (int): Type of generator; 0 = train, 1 = validate, 2 = test.
Attributes:
generators (list): List of ImageDataGenerator objects.
limits (list): List of integers representing batch limits.
current_generator_index (int): Index of the currently active generator.
remaining_batches (list): List of integers representing remaining
batches per generator.
        data_queue (deque): Optional queue to store generated batches (currently unused).
"""
TYPE = ['train','validate','test']
TRAIN = 0
VALIDATE = 1
TEST = 2
    def __init__(self, generators, limits, generator_type, debug=False):
        self.lock = threading.Lock()
        self.generators = generators
        self.limits = limits
        self.current_generator_index = 0
        self.remaining_batches = limits.copy()  # copy so decrements do not mutate limits
        self.debug = debug
        if self.debug:
            print(f'Create Generator :{MixedImageDataGenerator.TYPE[generator_type]}')
        self.generator_type = generator_type
def get_step_number(self):
return sum(self.limits)
def get_remaining_batches(self):
return sum(self.remaining_batches)
def getType(self):
return self.generator_type
    def getGeneratorNumber(self):
        return len(self.generators)
def __iter__(self):
"""
Makes the MixedImageDataGenerator object iterable,
allowing it to be used in a for loop.
Returns:
self: The MixedImageDataGenerator object itself.
"""
if self.get_remaining_batches() == 0:
return self.repeat()
else:
return self
def __next__(self):
"""
Returns the next batch of data from the mixed generators.
Raises:
StopIteration: If all limits have been exhausted.
"""
if self.debug:
print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}]: next called [pre lock]')
        with self.lock:
            if self.debug:
                print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}]: next call ...')
            while True:
                if sum(self.remaining_batches) <= 0:
                    if self.debug:
                        print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}] all limits exhausted')
                    raise StopIteration
data = None
while data is None:
# Check if current generator has remaining batches
if self.remaining_batches[self.current_generator_index] > 0:
                    # Fetch the next batch from the current generator
data = self.generators[self.current_generator_index].next()
# Decrease the remaining batches of the current generator
self.remaining_batches[self.current_generator_index] -= 1
if self.debug:
print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}]Remaining batch: {self.remaining_batches} limits: {self.limits}')
# Move to the next generator
self.current_generator_index = (self.current_generator_index + 1) % len(self.generators)
# If all generators have exhausted their limits, raise StopIteration
if all(remaining_batch <= 0 for remaining_batch in self.remaining_batches):
if self.debug:
print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}] all step data generated')
raise StopIteration
if self.debug:
print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}]: RETURN {type(data)}')
print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}]: next called')
if not data:
print('EMPTY================')
return data
def refresh(self, argument=None):
with self.lock:
if self.debug:
print(f'[{MixedImageDataGenerator.TYPE[self.generator_type]}] refresh')
self.current_generator_index = 0
            self.remaining_batches = self.limits.copy()  # reset remaining batch budgets from the limits
# self.data_queue = deque(maxlen=10) # Optional queue for efficiency
def repeat(self):
print('repeat called')
while True:
try:
yield next(self)
except StopIteration:
                self.refresh()
```
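For context, here is roughly how the mixed generator is built from standard flow_from_directory iterators; the directory paths, image size, and batch size below are placeholders, not my actual values:

```
# Sketch of how the wrapper is instantiated; paths and parameters are placeholders.
from tensorflow.keras.preprocessing.image import ImageDataGenerator

batch_size = 16  # placeholder
sources = ['data/source_a', 'data/source_b', 'data/source_c']  # placeholder paths

# One directory iterator per data source; each yields (images, labels) batches.
flows = [
    ImageDataGenerator(rescale=1. / 255).flow_from_directory(
        src, target_size=(224, 224), batch_size=batch_size)
    for src in sources
]

# 80 batches per source and per epoch, i.e. 240 steps per epoch in total.
train_generator = MixedImageDataGenerator(
    flows, limits=[80, 80, 80],
    generator_type=MixedImageDataGenerator.TRAIN, debug=True)
```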
The following callback should take care of refreshing the limits:

```
import tensorflow as tf

class EpochStepLogger(tf.keras.callbacks.Callback):
    """
    Custom callback to log epoch and step information, and refresh the generator.
    """
    def __init__(self, generator):
        """
        Args:
            generator: The data generator object to be refreshed.
        """
        self.generator = generator
        self.current_epoch = 0
        self.current_step = 0
def on_epoch_begin(self, epoch, logs=None):
"""
Called at the beginning of each epoch.
Args:
epoch: The current epoch number.
logs: Dictionary of logs accumulated during the previous epoch.
"""
self.current_epoch = epoch
self.current_step = 0
self.log(f"Epoch {epoch} started!")
self.refresh_generator()
def on_batch_end(self, batch, logs=None):
"""
Called at the end of each batch.
Args:
batch: The index of the current batch.
logs: Dictionary of logs accumulated at the end of the current batch.
"""
self.current_step += 1
self.log(f"Epoch {self.current_epoch}, Step {self.current_step} completed.")
# in case where there is more than one call to next() per step
# if self.generator:
# if self.generator.get_remaining_batches() < 1:
# self.refresh_generator()
def on_train_end(self, logs=None):
"""
Called at the end of training.
Args:
logs: Dictionary of logs accumulated at the end of training.
"""
self.log("Training completed!")
    def on_train_begin(self, logs=None):
        """
        Called at the beginning of training.
        Args:
            logs: Dictionary of logs accumulated at the beginning of training.
        """
self.refresh_generator()
self.log("Training started!")
def refresh_generator(self):
"""
Refreshes the data generator.
"""
if self.generator:
self.generator.refresh() # Assuming your generator has a refresh method
def log(self, message):
"""
Logs a message with epoch and step information.
"""
print(f"[Epoch {self.current_epoch}, Step {self.current_step}] {message}")
    def on_test_begin(self, logs=None):
        print('[ON_TEST_BEGIN]')

    def on_test_end(self, logs=None):
        print('[ON_TEST_END]')
```
The model is then trained with `fit`, using this callback and the mixed generators.
"""
Custom callback to log epoch and step information, and refresh the generator.
"""
class EpochStepLogger(tf.keras.callbacks.Callback):
def __init__(self, generator):
"""
Args:
generator: The data generator object to be refreshed.
"""
self.generator = generator
self.current_epoch = 0
self.current_step = 0
def on_epoch_begin(self, epoch, logs=None):
"""
Called at the beginning of each epoch.
Args:
epoch: The current epoch number.
logs: Dictionary of logs accumulated during the previous epoch.
"""
self.current_epoch = epoch
self.current_step = 0
self.log(f"Epoch {epoch} started!")
self.refresh_generator()
def on_batch_end(self, batch, logs=None):
"""
Called at the end of each batch.
Args:
batch: The index of the current batch.
logs: Dictionary of logs accumulated at the end of the current batch.
"""
self.current_step += 1
self.log(f"Epoch {self.current_epoch}, Step {self.current_step} completed.")
# in case where there is more than one call to next() per step
# if self.generator:
# if self.generator.get_remaining_batches() < 1:
# self.refresh_generator()
def on_train_end(self, logs=None):
"""
Called at the end of training.
Args:
logs: Dictionary of logs accumulated at the end of training.
"""
self.log("Training completed!")
def on_train_begin(self, logs=None):
"""
Called at the begining of training.
Args:
logs: Dictionary of logs accumulated at the end of training.
"""
self.refresh_generator()
self.log("Training started!")
def refresh_generator(self):
"""
Refreshes the data generator.
"""
if self.generator:
self.generator.refresh() # Assuming your generator has a refresh method
def log(self, message):
"""
Logs a message with epoch and step information.
"""
print(f"[Epoch {self.current_epoch}, Step {self.current_step}] {message}")
def on_test_begin(self,logs=None):
print(f'[ON_TEST_BEGIN]')
def on_test_end(self,logs=None):
print(f'[ON_TEST_END]') ```
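The exact `fit` call is not reproduced above, so the following is only a minimal sketch reconstructed from the trace: `model`, `train_generator`, and `val_generator` are assumed names; `epochs=7` and 240 steps per epoch follow from the warning about 1680 (= 240 * 7) expected batches; the `ModelCheckpoint` and `EarlyStopping` callbacks are inferred from the log messages.

```
# Minimal sketch of the training call; model, train_generator and val_generator
# are assumed to already exist. epochs=7 and steps_per_epoch=240 match the
# warning about 1680 (= 240 * 7) expected batches.
import tensorflow as tf

logger = EpochStepLogger(train_generator)
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'E__chest_ray_VGG_F_2XDO_512_best.h5',  # path taken from the log output
    monitor='val_loss', save_best_only=True, verbose=1)
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss')

history = model.fit(
    iter(train_generator),
    steps_per_epoch=train_generator.get_step_number(),  # sum(limits) = 240
    epochs=7,
    validation_data=iter(val_generator),
    validation_steps=val_generator.get_step_number(),
    callbacks=[logger, checkpoint, early_stop],
    verbose=2)
```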
With 3 underlying generators inside my MixedImageDataGenerator and limits of [80, 80, 80], I got the following trace:
```
[Epoch 0, Step 223] Epoch 0, Step 223 completed.
[train]: next called [pre lock]
[train]: next call ...
[train]Remaining batch: [0, 0, 0] limits: [80, 80, 80]
[train] all step data generated
[Epoch 0, Step 224] Epoch 0, Step 224 completed.
[Epoch 0, Step 225] Epoch 0, Step 225 completed.
[Epoch 0, Step 226] Epoch 0, Step 226 completed.
[Epoch 0, Step 227] Epoch 0, Step 227 completed.
[Epoch 0, Step 228] Epoch 0, Step 228 completed.
[Epoch 0, Step 229] Epoch 0, Step 229 completed.
[Epoch 0, Step 230] Epoch 0, Step 230 completed.
[Epoch 0, Step 231] Epoch 0, Step 231 completed.
[Epoch 0, Step 232] Epoch 0, Step 232 completed.
[Epoch 0, Step 233] Epoch 0, Step 233 completed.
[Epoch 0, Step 234] Epoch 0, Step 234 completed.
[Epoch 0, Step 235] Epoch 0, Step 235 completed.
[Epoch 0, Step 236] Epoch 0, Step 236 completed.
[Epoch 0, Step 237] Epoch 0, Step 237 completed.
[Epoch 0, Step 238] Epoch 0, Step 238 completed.
[Epoch 0, Step 239] Epoch 0, Step 239 completed.
[Epoch 0, Step 240] Epoch 0, Step 240 completed.
[ON_TEST_BEGIN]
[ON_TEST_END]
Epoch 00001: val_loss improved from inf to 1.20102, saving model to E__chest_ray_VGG_F_2XDO_512_best.h5
240/240 - 209s - loss: 404.7601 - accuracy: 0.3000 - val_loss: 1.2010 - val_accuracy: 0.4000
[Epoch 1, Step 0] Epoch 1 started!
[train] refresh
Epoch 2/7
WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 1680 batches). You may need to use the repeat() function when building your dataset.
WARNING:tensorflow:Can save best model only with val_loss available, skipping.
WARNING:tensorflow:Early stopping conditioned on metric `val_loss` which is not available. Available metrics are:
[Epoch 1, Step 0] Training completed!
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-77-3cc9dfbc0a95> in <module>
78 #class_weight=class_weight,
79 batch_size=model_batch_size,
---> 80 verbose=2)
81 # End timing
82 end_time = time.time()
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
817 max_queue_size=max_queue_size,
818 workers=workers,
--> 819 use_multiprocessing=use_multiprocessing)
820
821 def evaluate(self,
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
340 mode=ModeKeys.TRAIN,
341 training_context=training_context,
--> 342 total_epochs=epochs)
343 cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
344
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in run_one_epoch(model, iterator, execution_function, dataset_size, batch_size, strategy, steps_per_epoch, num_samples, mode, training_context, total_epochs)
185
186 # End of an epoch.
--> 187 aggregator.finalize()
188 return aggregator.results
189
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py in finalize(self)
142 def finalize(self):
143 if not self.results:
--> 144 raise ValueError('Empty training data.')
145 self.results[0] /= (self.num_samples or self.steps)
146
ValueError: Empty training data.
```
After training for some time, look at the performance.
Specifying steps_per_epoch is also difficult, because I can't predict what the fit method does under the hood (the warning suggests it expects steps_per_epoch * epochs = 240 * 7 = 1680 batches).
I'm stuck, does someone have an idea?
david