
Tuesday, September 22, 2020

Using model.fit() instead of fit_generator() with Data Generators - TF.Keras

If you have been using data generators in Keras, such as ImageDataGenerator, to augment and load the input data, then you will be familiar with using the *_generator() methods (fit_generator(), evaluate_generator(), etc.) to pass the generators when training the model.

If you have recently switched to TensorFlow 2.1 or later (and tf.keras), you may have started seeing a warning message such as:

Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
Please use Model.fit, which supports generators.

Or,

Model.evaluate_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
Please use Model.evaluate, which supports generators.


fit_generator() Deprecation Warning

This is because in tf.keras, as well as in the latest version of multi-backend Keras, model.fit() also accepts generators.


Therefore, all *_generator() function calls can now be replaced with their respective non-generator function calls: fit() instead of fit_generator(), evaluate() instead of evaluate_generator(), and predict() instead of predict_generator().
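As a minimal sketch of the evaluate() and predict() replacements, the snippet below passes a plain Python generator to both. The random data, the tiny one-layer model, and the batch_generator() helper are all hypothetical stand-ins for your own generator and model, and the sketch assumes a TensorFlow 2.x install. Because the generator loops forever, steps must be given, just as it was for the old *_generator() calls.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 64 samples with 4 features and binary labels.
rng = np.random.default_rng(0)
x_all = rng.random((64, 4)).astype("float32")
y_all = (x_all.sum(axis=1) > 2.0).astype("float32").reshape(-1, 1)


def batch_generator(x, y, batch_size=16):
    """Hypothetical helper: yield (inputs, targets) batches forever."""
    while True:
        for start in range(0, len(x), batch_size):
            yield x[start:start + batch_size], y[start:start + batch_size]


model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Previously: model.evaluate_generator(generator, steps=4)
results = model.evaluate(batch_generator(x_all, y_all), steps=4, verbose=0)

# Previously: model.predict_generator(generator, steps=4)
# predict() uses only the inputs part of each (inputs, targets) batch.
preds = model.predict(batch_generator(x_all, y_all), steps=4, verbose=0)

print(results, preds.shape)
```

Four steps of 16 samples cover the 64 samples once, so predict() returns one row per sample, and evaluate() returns the loss followed by the compiled accuracy metric.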

For example, the model.fit() function can take the following inputs (source):

  • A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs).
  • A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs).
  • A dict mapping input names to the corresponding array/tensors, if the model has named inputs.
  • A tf.data dataset. Should return a tuple of either (inputs, targets) or (inputs, targets, sample_weights).
  • A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights).
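To illustrate the last item in that list, here is a minimal sketch that trains with fit() on both kinds of input: a keras.utils.Sequence subclass and a plain Python generator. The class ArraySequence, the helper infinite_batches(), and the random data are hypothetical and only for illustration, and a TensorFlow 2.x install is assumed. A Sequence reports its own length, so steps_per_epoch can be left out; an endless generator needs steps_per_epoch so that fit() knows where each epoch ends.

```python
import math

import numpy as np
import tensorflow as tf

# Hypothetical toy data: 100 samples with 4 features and binary labels.
rng = np.random.default_rng(0)
x_train = rng.random((100, 4)).astype("float32")
y_train = (x_train.sum(axis=1) > 2.0).astype("float32").reshape(-1, 1)


class ArraySequence(tf.keras.utils.Sequence):
    """Hypothetical Sequence that serves (inputs, targets) batches by index."""

    def __init__(self, x, y, batch_size=32):
        super().__init__()
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        # Batches per epoch; the last batch may be smaller.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[sl], self.y[sl]


def infinite_batches(x, y, batch_size=32):
    """Hypothetical generator: yield (inputs, targets) batches forever."""
    while True:
        for start in range(0, len(x), batch_size):
            yield x[start:start + batch_size], y[start:start + batch_size]


model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# A Sequence knows its length, so steps_per_epoch is not needed.
seq_history = model.fit(ArraySequence(x_train, y_train), epochs=2, verbose=0)

# A plain generator never ends, so steps_per_epoch tells fit() when an epoch is over.
gen_history = model.fit(infinite_batches(x_train, y_train),
                        steps_per_epoch=4, epochs=2, verbose=0)
```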


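If you want to update the *_generator() calls in your code, you can keep the arguments exactly as they are and only change the function name. The one difference is that the first argument of fit(), evaluate(), and predict() is named x rather than generator, which matters only if you passed the generator as a keyword argument.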

For example, if you have the following code:

history = model.fit_generator(
                    train_generator,
                    steps_per_epoch=train_steps,
                    epochs=train_epochs,
                    validation_data=validation_generator,
                    validation_steps=validation_steps,
                    class_weight=class_weights,
                    initial_epoch=init_epoch_train,
                    max_queue_size=15,
                    workers=8,
                    callbacks=callbacks_list
                    )

You can update it to the following,

history = model.fit(
                    train_generator,
                    steps_per_epoch=train_steps,
                    epochs=train_epochs,
                    validation_data=validation_generator,
                    validation_steps=validation_steps,
                    class_weight=class_weights,
                    initial_epoch=init_epoch_train,
                    max_queue_size=15,
                    workers=8,
                    callbacks=callbacks_list
                    )


The behavior is the same as before, and fit() returns the same History object that fit_generator() did.
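A quick way to check the returned History object is a small, self-contained version of the migrated call. The data, model, and batches() helper below are hypothetical, a TensorFlow 2.x install is assumed, and the sketch leaves out class_weight, initial_epoch, max_queue_size, workers, and callbacks to stay short. As far as I know, max_queue_size and workers are accepted by tf.keras 2.x fit() but are not part of the Keras 3 fit() signature.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data split into training and validation parts.
rng = np.random.default_rng(0)
x = rng.random((96, 4)).astype("float32")
y = (x.sum(axis=1) > 2.0).astype("float32").reshape(-1, 1)
x_train, y_train = x[:64], y[:64]
x_val, y_val = x[64:], y[64:]


def batches(x, y, batch_size=16):
    """Hypothetical generator: yield (inputs, targets) batches forever."""
    while True:
        for start in range(0, len(x), batch_size):
            yield x[start:start + batch_size], y[start:start + batch_size]


model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Same arguments a fit_generator() call would have used, passed to fit().
history = model.fit(
    batches(x_train, y_train),
    steps_per_epoch=4,
    epochs=3,
    validation_data=batches(x_val, y_val),
    validation_steps=2,
    verbose=0,
)

# history.history maps each metric name to a list with one value per epoch.
print(sorted(history.history))
```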





