Kedro - how to pass nested parameters directly to node

Submitted by 旧巷老猫 on 2021-02-08 03:41:21


Kedro recommends storing parameters in conf/base/parameters.yml. Let's assume it looks like this:

step_size: 1
model_params:
    learning_rate: 0.01
    test_data_ratio: 0.2
    num_train_steps: 10000

Now imagine I have a data_engineering pipeline with a node function that looks something like this:

def some_pipeline_step(num_train_steps):
    """
    Takes the parameter `num_train_steps` as argument.
    """

How would I go about passing that nested parameter straight to this function in data_engineering/? I unsuccessfully tried:

from kedro.pipeline import Pipeline, node

from .nodes import split_data

def create_pipeline(**kwargs):
    return Pipeline(
        [
            # ... node definition that was tried ...
        ]
    )

I know that I could just pass all parameters into the function by using ['parameters'], or pass all model_params parameters with ['params:model_params'], but that seems inelegant and I feel like there must be a better way. Would appreciate any input!
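For comparison, the ['params:model_params'] workaround looks roughly like this: the node receives the whole `model_params` dictionary and does the nested lookup itself. This is a minimal sketch; `some_pipeline_step_from_dict` is a hypothetical name, and the dictionary literal simply mirrors the YAML above instead of being loaded by Kedro.

```python
from typing import Any, Dict


def some_pipeline_step_from_dict(model_params: Dict[str, Any]) -> int:
    """Hypothetical node wired to the input "params:model_params"."""
    # The nested lookup happens inside the node rather than in the pipeline.
    num_train_steps = model_params["num_train_steps"]
    return num_train_steps


# Mirrors the `model_params` block from parameters.yml above.
model_params = {"learning_rate": 0.01, "test_data_ratio": 0.2, "num_train_steps": 10000}
print(some_pipeline_step_from_dict(model_params))  # → 10000
```

The downside is that the node's signature no longer says which parameter it actually uses.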


(Disclaimer: I'm part of the Kedro team)

Thank you for your question. Unfortunately, the current version of Kedro does not support referencing nested parameters directly as node inputs. The interim solution would be to pass the top-level key (e.g. params:model_params) and read the nested value inside the node, as you already pointed out, or to decorate your node function with some sort of parameter filter, which is not elegant either.
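A parameter-filter decorator of the kind mentioned above might look like the following minimal sketch. `select_param` is a hypothetical helper, not part of Kedro: the node would be wired to the whole "params:model_params" input, and the decorator hands only the nested `num_train_steps` value to the original function.

```python
import functools
from typing import Any, Callable, Dict


def select_param(key: str) -> Callable:
    """Hypothetical decorator: the wrapped node receives a whole parameter
    dictionary, but the underlying function only sees the value under `key`."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)  # keep the original name/docstring for the node
        def wrapper(params: Dict[str, Any]) -> Any:
            # Pick out the nested value and forward it as the only argument.
            return func(params[key])

        return wrapper

    return decorator


@select_param("num_train_steps")
def some_pipeline_step(num_train_steps: int) -> int:
    """Takes the parameter `num_train_steps` as argument."""
    # Placeholder body for illustration only.
    return num_train_steps


# Mirrors the `model_params` block from parameters.yml in the question.
model_params = {"learning_rate": 0.01, "test_data_ratio": 0.2, "num_train_steps": 10000}
print(some_pipeline_step(model_params))  # → 10000
```

The function body stays clean, but the pipeline still has to pass the whole dictionary, which is why this is only an interim workaround.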

Probably the most viable solution would be to customise your ProjectContext class (in src/<package_name>/) by overriding the _get_feed_dict method as follows:

from typing import Any, Dict

# KedroContext comes from the existing imports in the project's run.py.


class ProjectContext(KedroContext):
    # ...

    def _get_feed_dict(self) -> Dict[str, Any]:
        """Get parameters and return the feed dictionary."""
        params = self.params
        feed_dict = {"parameters": params}

        def _add_param_to_feed_dict(param_name, param_value):
            """This recursively adds parameter paths to the `feed_dict`,
            whenever `param_value` is a dictionary itself, so that users can
            specify specific nested parameters in their node inputs.

            Example:

                >>> param_name = "a"
                >>> param_value = {"b": 1}
                >>> _add_param_to_feed_dict(param_name, param_value)
                >>> assert feed_dict["params:a"] == {"b": 1}
                >>> assert feed_dict["params:a.b"] == 1
            """
            key = "params:{}".format(param_name)
            feed_dict[key] = param_value

            if isinstance(param_value, dict):
                # Recurse into nested dicts, building dotted names such as
                # "model_params.num_train_steps".
                for sub_key, sub_value in param_value.items():
                    _add_param_to_feed_dict("{}.{}".format(param_name, sub_key), sub_value)

        for param_name, param_value in params.items():
            _add_param_to_feed_dict(param_name, param_value)

        return feed_dict
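Because the override only runs inside a Kedro project, the same recursion can be pulled into a plain function to see which keys it produces for the parameters.yml from the question. This is a standalone sketch; `build_feed_dict` is a hypothetical name, and the dictionary literal stands in for `self.params`.

```python
from typing import Any, Dict


def build_feed_dict(params: Dict[str, Any]) -> Dict[str, Any]:
    """Standalone copy of the flattening logic in `_get_feed_dict`."""
    feed_dict: Dict[str, Any] = {"parameters": params}

    def _add_param_to_feed_dict(param_name: str, param_value: Any) -> None:
        # Register the parameter under its (possibly dotted) name.
        feed_dict["params:{}".format(param_name)] = param_value
        if isinstance(param_value, dict):
            # Recurse so nested values get keys like "params:a.b".
            for sub_key, sub_value in param_value.items():
                _add_param_to_feed_dict("{}.{}".format(param_name, sub_key), sub_value)

    for param_name, param_value in params.items():
        _add_param_to_feed_dict(param_name, param_value)
    return feed_dict


# Stand-in for self.params, mirroring parameters.yml from the question.
params = {
    "step_size": 1,
    "model_params": {"learning_rate": 0.01, "test_data_ratio": 0.2, "num_train_steps": 10000},
}
for key in sorted(build_feed_dict(params)):
    print(key)
# parameters
# params:model_params
# params:model_params.learning_rate
# params:model_params.num_train_steps
# params:model_params.test_data_ratio
# params:step_size
```

With the customised context, a node input of "params:model_params.num_train_steps" should therefore resolve to 10000, while "params:model_params" still yields the whole dictionary.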

Please also note that this issue has already been addressed on the develop branch and will become available in the next release. The fix uses the approach from the snippet above.

