Question
I have some code in TensorFlow which takes a base model, fine-tunes (trains) it with some data, and then uses the model to predict() using some other data. All this is encapsulated in a main() method of a module and works fine.
When I run this code in a loop over different base models, however, I end up with an OOM after, e.g., 7 base models. Is this expected? I would expect that Python cleans up after each main() call. Does TensorFlow not do that? How can I force it to?
Edit: here's an MWE. It does not reproduce the OOM crash itself, but it does show memory consumption growing with every iteration:
import gc
import os

import numpy as np
import psutil
import tensorflow as tf

tf.get_logger().setLevel("ERROR")  # Suppress "tf.function retracing" warnings
process = psutil.Process(os.getpid())

for i in range(100):
    (model := tf.keras.applications.mobilenet.MobileNet()).compile(loss="mse")
    history = model.fit(
        x=(x := tf.zeros((1, *model.input.shape[1:]))),
        y=(y := tf.zeros((1, *model.output.shape[1:]))),
        verbose=0,
    )
    prediction = model.predict(x)
    _ = gc.collect()
    # tf.keras.backend.clear_session()
    print(f"rss {i}: {process.memory_info().rss >> 20} MB")
On my computer (CPU), it prints
rss 0: 374 MB
rss 1: 438 MB
rss 2: 478 MB
rss 3: 517 MB
rss 4: 554 MB
rss 5: 588 MB
rss 6: 634 MB
rss 7: 669 MB
rss 8: 686 MB
rss 9: 726 MB
...
rss 30: 1386 MB
rss 31: 1413 MB
rss 32: 1445 MB
rss 33: 1476 MB
rss 34: 1506 MB
rss 35: 1536 MB
rss 36: 1568 MB
rss 37: 1597 MB
rss 38: 1630 MB
rss 39: 1662 MB
...
With tf.keras.backend.clear_session() uncommented, it's better, but not perfect yet:
rss 0: 374 MB
rss 1: 420 MB
rss 2: 418 MB
rss 3: 450 MB
rss 4: 447 MB
rss 5: 469 MB
rss 6: 469 MB
rss 7: 475 MB
rss 8: 487 MB
rss 9: 494 MB
...
rss 40: 519 MB
rss 41: 516 MB
rss 42: 517 MB
rss 43: 520 MB
rss 44: 519 MB
rss 45: 519 MB
rss 46: 521 MB
rss 47: 517 MB
rss 48: 521 MB
rss 49: 521 MB
...
rss 90: 531 MB
rss 91: 531 MB
rss 92: 531 MB
rss 93: 531 MB
rss 94: 532 MB
rss 95: 532 MB
rss 96: 533 MB
rss 97: 534 MB
rss 98: 533 MB
rss 99: 533 MB
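To put rough numbers on the two runs, here is a small sketch that computes the average growth per iteration from values printed in the logs above. The helper name growth_per_iteration is hypothetical, and the figures are simple averages between two log lines, not a precise measurement.

```python
def growth_per_iteration(first_mb, last_mb, first_i, last_i):
    """Average RSS growth in MB per loop iteration between two log lines."""
    return (last_mb - first_mb) / (last_i - first_i)


# Values copied from the logs above.
without_clear = growth_per_iteration(374, 1662, 0, 39)    # clear_session() commented out
with_clear = growth_per_iteration(374, 533, 0, 99)        # clear_session() enabled
with_clear_tail = growth_per_iteration(519, 533, 40, 99)  # enabled, after warm-up

print(f"without clear_session(): {without_clear:.1f} MB/iteration")
print(f"with clear_session(): {with_clear:.1f} MB/iteration")
print(f"with clear_session(), iterations 40-99: {with_clear_tail:.2f} MB/iteration")
```

So the growth drops from roughly 33 MB per iteration to well under 1 MB per iteration once the first few dozen iterations have passed, but it does not reach zero.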
Switching the order of gc.collect() and tf.keras.backend.clear_session() did not help, either.
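One workaround that might be worth trying is to run each iteration in a fresh interpreter process, since the operating system reclaims all of a process's memory when it exits. The sketch below only illustrates that idea and is not a confirmed fix for this MWE: run_isolated and CHILD_CODE are hypothetical names, and the child allocates a plain bytearray as a stand-in for the real main() so that the example runs without TensorFlow. In real use, the child would import the module and call main(), at the cost of re-importing TensorFlow for every base model.

```python
import json
import subprocess
import sys

# Hypothetical stand-in for the real work. In practice the child would call
# the module's main() (build the model, fit(), predict()) and print a small
# JSON result as its last line of output.
CHILD_CODE = """
import json, sys
base_model_name = sys.argv[1]
buffer = bytearray(50 * 1024 * 1024)  # stands in for memory held during training
print(json.dumps({"model": base_model_name, "bytes_used": len(buffer)}))
"""


def run_isolated(base_model_name):
    """Run one job in a fresh interpreter; its memory is released on exit."""
    completed = subprocess.run(
        [sys.executable, "-c", CHILD_CODE, base_model_name],
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the child fails
    )
    # Parse only the last line, in case the child logs other text to stdout.
    return json.loads(completed.stdout.splitlines()[-1])


for name in ["model_a", "model_b", "model_c"]:
    print(run_isolated(name))
```

Because each child exits before the next one starts, the parent's memory use should stay flat regardless of how much the child allocated. Results have to be passed back explicitly (here as JSON on stdout), which is the main design cost of this approach.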
Source: https://stackoverflow.com/questions/63411142/how-to-avoid-oom-errors-in-repeated-training-and-prediction-in-tensorflow