ResNet model in TensorFlow Federated

Submitted by 孤街浪徒 on 2020-02-05 14:38:09

Question


I tried to customize the model in the "Image classification" tutorial for TensorFlow Federated, which originally uses a sequential model. I am using Keras ResNet50 instead, but whenever training begins I get an "Incompatible shapes" error.

Here is my code:

import tensorflow as tf
import tensorflow_federated as tff

NUM_CLIENTS = 4
NUM_EPOCHS = 10
BATCH_SIZE = 2
SHUFFLE_BUFFER = 5

def create_compiled_keras_model():
  # ResNet50 backbone without the classification head, for 100x300 RGB inputs.
  model = tf.keras.applications.resnet.ResNet50(
      include_top=False,
      weights='imagenet',
      input_tensor=tf.keras.layers.Input(shape=(100, 300, 3)),
      pooling=None)

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model


def model_fn():
  keras_model = create_compiled_keras_model()
  # sample_batch is not defined in this snippet; presumably it is created
  # earlier in the notebook.
  return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(model_fn)

Error information: (screenshot in the original post; image not available)

I suspect the shapes are incompatible because the epoch and client information is somehow missing. I would be very thankful if someone could give me a hint.

Update:

The AssertionError is raised during the call to tff.learning.build_federated_averaging_process:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-164-dac26193d9d8> in <module>()
----> 1 iterative_process = tff.learning.build_federated_averaging_process(model_fn)
      2 
      3 # iterative_process = build_federated_averaging_process(model_fn)

13 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/federated_averaging.py in build_federated_averaging_process(model_fn, server_optimizer_fn, client_weight_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
    165   return optimizer_utils.build_model_delta_optimizer_process(
    166       model_fn, client_fed_avg, server_optimizer_fn,
--> 167       stateful_delta_aggregate_fn, stateful_model_broadcast_fn)

/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py in build_model_delta_optimizer_process(model_fn, model_to_client_delta_fn, server_optimizer_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
    349   # still need this.
    350   with tf.Graph().as_default():
--> 351     dummy_model_for_metadata = model_utils.enhance(model_fn())
    352 
    353   # ===========================================================================

<ipython-input-159-b2763ace8e5b> in model_fn()
      1 def model_fn():
      2   keras_model = model
----> 3   return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/keras_utils.py in from_compiled_keras_model(keras_model, dummy_batch)
    211   # Model.test_on_batch() once before asking for metrics.
    212   if isinstance(dummy_tensors, collections.Mapping):
--> 213     keras_model.test_on_batch(**dummy_tensors)
    214   else:
    215     keras_model.test_on_batch(*dummy_tensors)

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py in test_on_batch(self, x, y, sample_weight, reset_metrics)
   1007         sample_weight=sample_weight,
   1008         reset_metrics=reset_metrics,
-> 1009         standalone=True)
   1010     outputs = (
   1011         outputs['total_loss'] + outputs['output_losses'] + outputs['metrics'])

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in test_on_batch(model, x, y, sample_weight, reset_metrics, standalone)
    503       y,
    504       sample_weights=sample_weights,
--> 505       output_loss_metrics=model._output_loss_metrics)
    506 
    507   if reset_metrics:

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    568         xla_context.Exit()
    569     else:
--> 570       result = self._call(*args, **kwds)
    571 
    572     if tracing_count == self._get_tracing_count():

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
    606       # In this case we have not created variables on the first call. So we can
    607       # run the first trace but we should fail if variables are created.
--> 608       results = self._stateful_fn(*args, **kwds)
    609       if self._created_variables:
    610         raise ValueError("Creating variables on a non-first call to a function"

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
   2407     """Calls a graph function specialized to the inputs."""
   2408     with self._lock:
-> 2409       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
   2410     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2411 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2765 
   2766       self._function_cache.missed.add(call_context_key)
-> 2767       graph_function = self._create_graph_function(args, kwargs)
   2768       self._function_cache.primary[cache_key] = graph_function
   2769       return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2655             arg_names=arg_names,
   2656             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2657             capture_by_value=self._capture_by_value),
   2658         self._function_attributes,
   2659         # Tell the ConcreteFunction to clean up its graph once it goes out of

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    979         _, original_func = tf_decorator.unwrap(python_func)
    980 
--> 981       func_outputs = python_func(*func_args, **func_kwargs)
    982 
    983       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    438         # the function a weak reference to itself to avoid a reference cycle.
--> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    440     weak_wrapped_fn = weakref.ref(wrapped_fn)
    441 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

AssertionError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_eager.py:345 test_on_batch  *
        with backend.eager_learning_phase_scope(0):
    /usr/lib/python3.6/contextlib.py:81 __enter__
        return next(self.gen)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py:425 eager_learning_phase_scope
        assert ops.executing_eagerly_outside_functions()

    AssertionError: 


Answer 1:


Ah, I believe this issue comes from mismatched expectations about sample_batch. TFF passes sample_batch to Keras, which runs a forward pass on it to initialize various attributes of the Keras model. sample_batch should be either a sample of the actual data you are going to feed the model, or a batch of fake data that matches the shape and type of the data you will be passing in.

An example of the former can be found here (this uses tf.data.Dataset), and there are several examples of the latter in test code, like here.

From what I see of the definition of the model, likely the x element of your sample_batch should be an ndarray of shape [2, 100, 300, 3] (where 2 is for the batch size, but technically this can be any nonzero dimension), and the y element should also match the expected y structure in the data you are using.
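As a minimal sketch of the "fake data" option described above, a dummy batch can be built with NumPy. The key names x and y follow the wording of this answer; the label shape and dtype used here ([2, 1], int32) are assumptions and should be changed to mirror whatever your real dataset yields.

```python
import collections

import numpy as np

BATCH_SIZE = 2  # same value as in the question; any nonzero size works

# Hypothetical dummy batch. The "x" shape matches the model's
# Input(shape=(100, 300, 3)); the "y" shape and dtype are assumptions
# and must match the labels in your actual data.
sample_batch = collections.OrderedDict(
    x=np.zeros([BATCH_SIZE, 100, 300, 3], dtype=np.float32),
    y=np.zeros([BATCH_SIZE, 1], dtype=np.int32),
)

# Print the structure so it can be compared against a real batch.
for name, value in sample_batch.items():
    print(name, value.shape, value.dtype)
```

Comparing this printed structure with one batch drawn from your preprocessed client dataset is a quick way to spot a mismatch before handing sample_batch to TFF.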

I hope this helps, just ping back if there are any problems!

One thing that may be helpful in thinking about TFF: it builds a syntax tree representing the distributed computation you define via build_federated_averaging_process. This error actually occurs during construction of that object. TFF must trace the computation you pass it in order to know what structure to generate, and it is this tracing that raises the error. Actual training of the model happens when you call next on the returned IterativeProcess.
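The build-time versus run-time distinction can be illustrated with a toy in plain Python. This is only an analogy: ToyIterativeProcess and both model functions are hypothetical names invented for this sketch, and nothing here is TFF's real implementation. It shows that when a builder calls model_fn while it is being constructed, a failure inside model_fn surfaces before any call to next.

```python
class ToyIterativeProcess:
    """Hypothetical stand-in for illustration only; not a TFF class."""

    def __init__(self, model_fn):
        # Build time: model_fn is invoked immediately so the builder can
        # inspect what it returns. Any error it raises surfaces here.
        self._model = model_fn()

    def next(self, state, data):
        # Run time: the actual work happens only when next() is called.
        return state + sum(self._model(x) for x in data)


def broken_model_fn():
    # Simulates a model_fn that fails while it is being traced.
    raise AssertionError("raised while building, before any training")


def working_model_fn():
    # Returns a trivial "model" that doubles its input.
    return lambda x: 2 * x


try:
    ToyIterativeProcess(broken_model_fn)
    build_failed = False
except AssertionError:
    build_failed = True

process = ToyIterativeProcess(working_model_fn)
new_state = process.next(0, [1, 2, 3])
print(build_failed, new_state)
```

In the question's traceback the same pattern appears: the frame for model_fn sits underneath build_federated_averaging_process, so the failure happens before the iterative process exists.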




Answer 2:


I have the same problem. If I execute these lines:

state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))

I get this error:

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: Default MaxPoolingOp only supports NHWC on device type CPU
     [[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential/vgg16/block1_pool/MaxPool}}]]
     [[subcomputation/StatefulPartitionedCall_1/ReduceDataset]]
     [[subcomputation/StatefulPartitionedCall_1/ReduceDataset/_140]]
  (1) Invalid argument: Default MaxPoolingOp only supports NHWC on device type CPU
     [[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential/vgg16/block1_pool/MaxPool}}]]
     [[subcomputation/StatefulPartitionedCall_1/ReduceDataset]]
0 successful operations.
0 derived errors ignored.

Note that I am using VGG16. Do you have any idea what causes this type of error?



Source: https://stackoverflow.com/questions/59622300/resnet-model-in-tensorflow-federated
