If you have been using data generators in Keras, such as ImageDataGenerator, to augment and load your input data, you are probably familiar with using the *_generator() methods (fit_generator(), evaluate_generator(), etc.) to pass the generators when training the model.
If you have recently switched to TensorFlow 2.1 or later (and tf.keras), you may have seen a warning message such as:
Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: Please use Model.fit, which supports generators.
Or,
Model.evaluate_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: Please use Model.evaluate, which supports generators.
Figure: fit_generator() Deprecation Warning
This is because in tf.keras, as well as in the latest version of multi-backend Keras, the model.fit() function also accepts generators.
Therefore, all *_generator() function calls can now be replaced with their respective non-generator function calls: fit() instead of fit_generator(), evaluate() instead of evaluate_generator(), and predict() instead of predict_generator().
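A minimal sketch of the three replacements, assuming TensorFlow 2.x with tf.keras. The toy model, the random data and the batch_generator() helper are hypothetical stand-ins for a real ImageDataGenerator pipeline; only fit(), evaluate() and predict() come from the text above.

```python
import numpy as np
import tensorflow as tf


def batch_generator(batch_size=8):
    """Hypothetical generator: yields (inputs, targets) batches forever."""
    rng = np.random.default_rng(0)
    while True:
        x = rng.random((batch_size, 4), dtype=np.float32)
        y = x.sum(axis=1, keepdims=True)
        yield x, y


model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# fit() instead of fit_generator()
model.fit(batch_generator(), steps_per_epoch=5, epochs=2, verbose=0)

# evaluate() instead of evaluate_generator()
loss = model.evaluate(batch_generator(), steps=3, verbose=0)

# predict() instead of predict_generator(); only the inputs of each batch are used
preds = model.predict(batch_generator(), steps=3, verbose=0)
print(preds.shape)  # (24, 1)
```

Because the generator is infinite, steps_per_epoch and steps tell Keras how many batches to draw, exactly as they did with the *_generator() methods.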
For example, the model.fit() function can take the following inputs (per the tf.keras Model.fit() documentation):
- A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors, if the model has named inputs.
- A tf.data dataset. Should return a tuple of either (inputs, targets) or (inputs, targets, sample_weights).
- A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights).
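The last input type can be sketched with a keras.utils.Sequence subclass. ArraySequence below is hypothetical and simply serves batches from in-memory NumPy arrays; the assumption is TensorFlow 2.x with tf.keras. Because a Sequence reports its own length, fit() can infer the number of steps per epoch, so steps_per_epoch is omitted.

```python
import math

import numpy as np
import tensorflow as tf


class ArraySequence(tf.keras.utils.Sequence):
    """Hypothetical Sequence serving (inputs, targets) batches from arrays."""

    def __init__(self, x, y, batch_size):
        super().__init__()
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        start = idx * self.batch_size
        end = start + self.batch_size
        return self.x[start:end], self.y[start:end]


x = np.random.rand(100, 4).astype("float32")
y = x.sum(axis=1, keepdims=True)
seq = ArraySequence(x, y, batch_size=32)

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Passed straight to fit(), just like a generator.
history = model.fit(seq, epochs=3, verbose=0)
print(len(seq), len(history.history["loss"]))  # 4 3
```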
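When you update the *_generator() calls in your code, the arguments stay exactly the same as before; only the method name changes. (The generator, passed as the first positional argument, simply becomes the x argument of fit(), evaluate() or predict().)
For example, if you have the following code: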
history = model.fit_generator(
    train_generator,
    steps_per_epoch=train_steps,
    epochs=train_epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps,
    class_weight=class_weights,
    initial_epoch=init_epoch_train,
    max_queue_size=15,
    workers=8,
    callbacks=callbacks_list
)
You can update it to the following:
history = model.fit(
    train_generator,
    steps_per_epoch=train_steps,
    epochs=train_epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps,
    class_weight=class_weights,
    initial_epoch=init_epoch_train,
    max_queue_size=15,
    workers=8,
    callbacks=callbacks_list
)
The behavior is the same as before, including the History object that fit() returns.
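A runnable sketch of the migrated fit() call, assuming TensorFlow 2.x with tf.keras. Every value assigned to the names from the fit() call above (train_generator, train_steps, callbacks_list, etc.) is hypothetical, and the class_weight, max_queue_size and workers arguments are left out to keep the sketch minimal. It checks that the returned History object records one entry per epoch actually run, starting from initial_epoch.

```python
import numpy as np
import tensorflow as tf


def batch_generator(batch_size=16, seed=0):
    """Hypothetical stand-in for an ImageDataGenerator flow: yields (inputs, targets) forever."""
    rng = np.random.default_rng(seed)
    while True:
        x = rng.random((batch_size, 4), dtype=np.float32)
        y = (x.sum(axis=1, keepdims=True) > 2.0).astype("float32")
        yield x, y


# Hypothetical values for the names used in the fit() call above.
train_generator = batch_generator(seed=0)
validation_generator = batch_generator(seed=1)
train_steps, validation_steps = 10, 3
train_epochs, init_epoch_train = 3, 1
seen_epochs = []
callbacks_list = [
    tf.keras.callbacks.LambdaCallback(
        on_epoch_end=lambda epoch, logs: seen_epochs.append(epoch)
    )
]

model = tf.keras.Sequential(
    [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1, activation="sigmoid")]
)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

history = model.fit(
    train_generator,
    steps_per_epoch=train_steps,
    epochs=train_epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps,
    initial_epoch=init_epoch_train,
    callbacks=callbacks_list,
    verbose=0,
)

# Training runs epochs 1 and 2 only, because it resumes at initial_epoch=1.
print(history.epoch)  # [1, 2]
```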