I get an error while training a federated model that uses hub.KerasLayer. The error details and stack trace are given below. The complete code is available in this gist: https://gist.github.com/aksingh2411/60796ee58c88e0c3f074c8909b17b5a1. Any help or suggestions would be appreciated. Thanks.
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_hub as hub
from tensorflow import keras

def create_keras_model():
    # Pretrained text embedding that maps raw strings to 20-dim vectors.
    encoder = hub.load("https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1")
    return tf.keras.models.Sequential([
        hub.KerasLayer(encoder, input_shape=[], dtype=tf.string, trainable=True),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(16, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid'),
    ])

def model_fn():
    # We _must_ create a new model here, and _not_ capture it from an external
    # scope. TFF will call this within different graph contexts.
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        # The remaining arguments are omitted here; see the gist for the full call.
    )

# Building the Federated Averaging Process
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))

UnimplementedError                        Traceback (most recent call last)
<ipython-input-80-39d62fa827ea> in <module>()
----> 1 state, metrics = iterative_process.next(state, federated_train_data)
      2 print('round 1, metrics={}'.format(metrics))

119 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

UnimplementedError: Cast string to float is not supported
     [[{{node StatefulPartitionedCall_1/StatefulPartitionedCall/Cast_1}}]]
     [[import/StatefulPartitionedCall_3/ReduceDataset]] [Op:__inference_wrapped_function_65986]

Function call stack:
wrapped_function -> wrapped_function -> wrapped_function

The issue has now been resolved. The error was thrown because 'label' was being passed as tf.string instead of tf.int32. Converting it explicitly resolved this.
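
For anyone hitting the same error, here is a minimal sketch of one way to convert the label inside a tf.data pipeline. It is not the code from the gist: the toy data, the 'sentence' feature name and the parse_label helper are assumptions made for illustration, and it assumes TensorFlow 2.x. Because tf.cast does not support converting tf.string to a numeric dtype, the sketch parses the text with tf.strings.to_number instead.

```python
import tensorflow as tf

# Hypothetical toy client data: the label arrives as text, as in the failing case.
raw = tf.data.Dataset.from_tensor_slices({
    "sentence": ["good movie", "bad movie"],
    "label": ["1", "0"],
})

def parse_label(example):
    # tf.cast cannot turn tf.string into a number, so parse the text instead.
    label = tf.strings.to_number(example["label"], out_type=tf.int32)
    return example["sentence"], label

fixed = raw.map(parse_label)

# Check the dtypes before handing the dataset to the training loop.
label_dtype = fixed.element_spec[1].dtype
labels = [int(y) for _, y in fixed.as_numpy_iterator()]
print(label_dtype, labels)
```

Inspecting element_spec on each client dataset before calling iterative_process.next is a quick way to catch a string-typed label early, since the error otherwise only surfaces deep inside the training step.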