Question
I tried to customize the model in the "Image classification" tutorial for TensorFlow Federated, which originally used a Sequential model. I am using the Keras ResNet50 instead, but as soon as training begins I always get an "Incompatible shapes" error.
Here is my code:
import tensorflow as tf
import tensorflow_federated as tff

NUM_CLIENTS = 4
NUM_EPOCHS = 10
BATCH_SIZE = 2
SHUFFLE_BUFFER = 5

def create_compiled_keras_model():
    model = tf.keras.applications.resnet.ResNet50(
        include_top=False,
        weights='imagenet',
        input_tensor=tf.keras.layers.Input(shape=(100, 300, 3)),
        pooling=None)
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    return model

def model_fn():
    keras_model = create_compiled_keras_model()
    # sample_batch is defined elsewhere (not shown here).
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(model_fn)
Error information: (screenshot not preserved)
I suspect the shapes are incompatible because the epoch and client information is somehow missing. I would be very thankful if someone could give me a hint.
Update:
The AssertionError is raised during the call to tff.learning.build_federated_averaging_process:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-164-dac26193d9d8> in <module>()
----> 1 iterative_process = tff.learning.build_federated_averaging_process(model_fn)
2
3 # iterative_process = build_federated_averaging_process(model_fn)
13 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/federated_averaging.py in build_federated_averaging_process(model_fn, server_optimizer_fn, client_weight_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
165 return optimizer_utils.build_model_delta_optimizer_process(
166 model_fn, client_fed_avg, server_optimizer_fn,
--> 167 stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py in build_model_delta_optimizer_process(model_fn, model_to_client_delta_fn, server_optimizer_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
349 # still need this.
350 with tf.Graph().as_default():
--> 351 dummy_model_for_metadata = model_utils.enhance(model_fn())
352
353 # ===========================================================================
<ipython-input-159-b2763ace8e5b> in model_fn()
1 def model_fn():
2 keras_model = model
----> 3 return tff.learning.from_compiled_keras_model(keras_model, sample_batch)
/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/keras_utils.py in from_compiled_keras_model(keras_model, dummy_batch)
211 # Model.test_on_batch() once before asking for metrics.
212 if isinstance(dummy_tensors, collections.Mapping):
--> 213 keras_model.test_on_batch(**dummy_tensors)
214 else:
215 keras_model.test_on_batch(*dummy_tensors)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py in test_on_batch(self, x, y, sample_weight, reset_metrics)
1007 sample_weight=sample_weight,
1008 reset_metrics=reset_metrics,
-> 1009 standalone=True)
1010 outputs = (
1011 outputs['total_loss'] + outputs['output_losses'] + outputs['metrics'])
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in test_on_batch(model, x, y, sample_weight, reset_metrics, standalone)
503 y,
504 sample_weights=sample_weights,
--> 505 output_loss_metrics=model._output_loss_metrics)
506
507 if reset_metrics:
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
568 xla_context.Exit()
569 else:
--> 570 result = self._call(*args, **kwds)
571
572 if tracing_count == self._get_tracing_count():
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
606 # In this case we have not created variables on the first call. So we can
607 # run the first trace but we should fail if variables are created.
--> 608 results = self._stateful_fn(*args, **kwds)
609 if self._created_variables:
610 raise ValueError("Creating variables on a non-first call to a function"
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
2407 """Calls a graph function specialized to the inputs."""
2408 with self._lock:
-> 2409 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
2410 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2411
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2765
2766 self._function_cache.missed.add(call_context_key)
-> 2767 graph_function = self._create_graph_function(args, kwargs)
2768 self._function_cache.primary[cache_key] = graph_function
2769 return graph_function, args, kwargs
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2655 arg_names=arg_names,
2656 override_flat_arg_shapes=override_flat_arg_shapes,
-> 2657 capture_by_value=self._capture_by_value),
2658 self._function_attributes,
2659 # Tell the ConcreteFunction to clean up its graph once it goes out of
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
979 _, original_func = tf_decorator.unwrap(python_func)
980
--> 981 func_outputs = python_func(*func_args, **func_kwargs)
982
983 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
437 # __wrapped__ allows AutoGraph to swap in a converted function. We give
438 # the function a weak reference to itself to avoid a reference cycle.
--> 439 return weak_wrapped_fn().__wrapped__(*args, **kwds)
440 weak_wrapped_fn = weakref.ref(wrapped_fn)
441
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise
AssertionError: in user code:
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_eager.py:345 test_on_batch *
with backend.eager_learning_phase_scope(0):
/usr/lib/python3.6/contextlib.py:81 __enter__
return next(self.gen)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py:425 eager_learning_phase_scope
assert ops.executing_eagerly_outside_functions()
AssertionError:
Answer 1:
Ah, I believe this issue comes from mismatched expectations about sample_batch. TFF passes sample_batch to Keras, which runs a forward pass on it to initialize various attributes of the Keras model. sample_batch should be either a sample of the actual data you are going to feed the model, or a batch of fake data that matches the shape and type of the data you will be passing in.
An example of the former can be found here (this uses tf.data.Dataset), and there are several examples of the latter in test code, like here.
From what I see of the model definition, the x element of your sample_batch should likely be an ndarray of shape [2, 100, 300, 3] (where 2 is the batch size, but technically this can be any nonzero dimension), and the y element should also match the expected y structure in the data you are using.
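As a minimal sketch of the "fake data" option, a dummy batch with the shape described above could be built with NumPy. The helper name make_dummy_batch, the float32/int32 dtypes, and the [batch, 1] label shape are assumptions made for illustration; they should be adjusted to match whatever your real preprocessing produces.

```python
import collections

import numpy as np

BATCH_SIZE = 2  # any nonzero batch size works for the dummy batch


def make_dummy_batch(batch_size=BATCH_SIZE, image_shape=(100, 300, 3)):
    """Builds a fake batch shaped like the real training data.

    Hypothetical helper: the dtypes and the label shape are assumptions
    and must mirror the structure of the actual dataset elements.
    """
    # Images: [batch, height, width, channels], matching Input(shape=(100, 300, 3)).
    x = np.zeros((batch_size,) + tuple(image_shape), dtype=np.float32)
    # Labels: one integer class id per example (assumed shape [batch, 1]).
    y = np.zeros((batch_size, 1), dtype=np.int32)
    return collections.OrderedDict(x=x, y=y)


sample_batch = make_dummy_batch()
print(sample_batch["x"].shape, sample_batch["y"].shape)
```

The resulting sample_batch can then be passed as the second argument in model_fn, provided the real client datasets yield elements with the same keys, shapes, and dtypes.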
I hope this helps, just ping back if there are any problems!
One thing to note that may be helpful in thinking about TFF: TFF builds a syntax tree representing the distributed computation you define via build_federated_averaging_process. This error actually occurs during construction of this object. TFF must trace the computation you pass it in order to know what structure to generate, and that tracing is what raises here. Actual training of the model happens only when you call next on the returned IterativeProcess.
Answer 2:
I have the same problem. When I execute these lines:
state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))
I get this error:
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: Default MaxPoolingOp only supports NHWC on device type CPU
[[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential/vgg16/block1_pool/MaxPool}}]]
[[subcomputation/StatefulPartitionedCall_1/ReduceDataset]]
[[subcomputation/StatefulPartitionedCall_1/ReduceDataset/_140]]
(1) Invalid argument: Default MaxPoolingOp only supports NHWC on device type CPU
[[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential/vgg16/block1_pool/MaxPool}}]]
[[subcomputation/StatefulPartitionedCall_1/ReduceDataset]]
0 successful operations. 0 derived errors ignored.
Note that I am using VGG16. Do you have any idea what causes this type of error?
Source: https://stackoverflow.com/questions/59622300/resnet-model-in-tensorflow-federated