In the model I want to launch, I have some variables which have to be initialized with specific values.
I currently store these variables in numpy arrays, but I don
I tried the accepted answer but ran into some problems. Eventually this worked for me (Python 3):
from io import BytesIO
import numpy as np
from tensorflow.python.lib.io import file_io
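To save:
dest = 'gs://[BUCKET-NAME]/[FILE-NAME].npy'  # Destination object in GCS; must include a file name
with file_io.FileIO(dest, 'wb') as f:  # the with-block closes (and flushes) the file
    np.save(f, np.ones((100, )))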
To load:
src = 'gs://[BUCKET-NAME]/[FILE-NAME].npy'  # Source object in GCS
f = BytesIO(file_io.read_file_to_string(src, binary_mode=True))
arr = np.load(f)
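The numpy side of this recipe does not depend on GCS at all: np.save writes to any binary file-like object, and np.load reads from any seekable one. Here is a minimal sketch that keeps serialization separate from transport. The helper names are hypothetical, and the assumption is that the resulting bytes are then written or read with file_io or any other GCS client:

```python
from io import BytesIO

import numpy as np


def array_to_npy_bytes(arr):
    """Serialize an array to the .npy format entirely in memory (hypothetical helper)."""
    buf = BytesIO()
    np.save(buf, arr)
    return buf.getvalue()


def npy_bytes_to_array(data):
    """Deserialize .npy bytes, e.g. as returned by read_file_to_string(..., binary_mode=True)."""
    # BytesIO gives np.load the seekable binary stream it expects.
    return np.load(BytesIO(data))


original = np.arange(6, dtype=np.float32).reshape(2, 3)
restored = npy_bytes_to_array(array_to_npy_bytes(original))
print(restored.dtype, restored.shape)  # float32 (2, 3)
```

Because the helpers only deal in bytes, you can test them locally and swap the storage layer without touching them.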
First, you'll need to copy/store the data on GCS (using, e.g., gsutil) and ensure your training script has access to that bucket. The easiest way to do so is to copy the array to the same bucket as your data, since you'll likely already have configured that bucket for read access. If the bucket is in the same project as your training job and you followed these instructions (particularly, gcloud beta ml init-project), you should be set. If the data will be in another bucket, see these instructions.
Then you'll need to use a library capable of loading data from GCS. TensorFlow includes a module that can do this, although you're free to use any client library that can read from GCS. Here's an example of using TensorFlow's file_io module:
from StringIO import StringIO
import tensorflow as tf
import numpy as np
from tensorflow.python.lib.io import file_io
# Create a variable initialized to the value of a serialized numpy array
f = StringIO(file_io.read_file_to_string('gs://my-bucket/123.npy'))
my_variable = tf.Variable(initial_value=np.load(f), name='my_variable')
Note that we have to read the file into a string and use StringIO, since file_io.FileIO does not fully implement the seek function required by numpy.load.
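You can reproduce the seek requirement locally. In the sketch below, ReadOnlyStream is a hypothetical stand-in for a file object that supports read() but not seek(). np.load fails on it, while the same bytes wrapped in an in-memory buffer load fine:

```python
from io import BytesIO

import numpy as np


class ReadOnlyStream:
    """Hypothetical file-like object exposing read() but no seek()."""

    def __init__(self, data):
        self._buf = BytesIO(data)

    def read(self, size=-1):
        return self._buf.read(size)


buf = BytesIO()
np.save(buf, np.array([[1, 2, 3], [4, 5, 6]]))
data = buf.getvalue()

try:
    np.load(ReadOnlyStream(data))
    loaded_without_seek = True
except AttributeError:
    # np.load reads the magic prefix and then seeks back, which fails here.
    loaded_without_seek = False

arr = np.load(BytesIO(data))  # read everything first, then wrap in a seekable buffer
print(loaded_without_seek, arr.shape)  # False (2, 3)
```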
Bonus: in case it's useful, you can store a numpy array directly to GCS using the file_io module, e.g.:
with file_io.FileIO('gs://my-bucket/123.npy', 'wb') as f:  # np.save adds no .npy suffix to file objects
    np.save(f, np.array([[1,2,3], [4,5,6]]))
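When you save through a file object, watch the file name. np.save appends the .npy extension only when it is given a path string. A file object, such as one from file_io.FileIO, is written under exactly the name it was opened with, so the name the loader expects (123.npy here) has to be spelled out when writing. Here is a local sketch using a temporary directory in place of a bucket:

```python
import os
import tempfile

import numpy as np

arr = np.array([[1, 2, 3], [4, 5, 6]])

with tempfile.TemporaryDirectory() as tmp:
    # Path string without extension: np.save appends ".npy".
    np.save(os.path.join(tmp, "by_path"), arr)
    # File object: the name is used verbatim, no extension added.
    with open(os.path.join(tmp, "by_handle"), "wb") as f:
        np.save(f, arr)
    written = sorted(os.listdir(tmp))

print(written)  # ['by_handle', 'by_path.npy']
```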
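For Python 3, use from io import BytesIO instead of from StringIO import StringIO, and pass binary_mode=True to read_file_to_string, since np.load needs the raw bytes of the .npy file rather than decoded text.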
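In Python 3 the difference between text and bytes matters here. A .npy file starts with the byte 0x93, which is not valid UTF-8 on its own, so the file cannot be read as decoded text and then handed to np.load. Here is a local sketch of both the failure and the BytesIO approach. The in-memory bytes stand in for what read_file_to_string(..., binary_mode=True) would return for such a file:

```python
from io import BytesIO

import numpy as np

buf = BytesIO()
np.save(buf, np.ones((100,)))
raw = buf.getvalue()  # stands in for the bytes read from GCS in binary mode

try:
    raw.decode("utf-8")  # roughly what a text-mode read would have to do
    decodes_as_text = True
except UnicodeDecodeError:
    decodes_as_text = False

arr = np.load(BytesIO(raw))  # Python 3: wrap the bytes in BytesIO, not StringIO
print(decodes_as_text, arr.shape)  # False (100,)
```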