Kedro - how to pass nested parameters directly to node

Submitted by 旧巷老猫 on 2021-02-08 03:41:21

Question


Kedro recommends storing parameters in conf/base/parameters.yml. Let's assume the file looks like this:

step_size: 1
model_params:
    learning_rate: 0.01
    test_data_ratio: 0.2
    num_train_steps: 10000

And now imagine I have some data_engineering pipeline whose nodes.py has a function that looks something like this:

def some_pipeline_step(num_train_steps):
    """
    Takes the parameter `num_train_steps` as argument.
    """
    pass

How would I go about passing that nested parameter straight to this function in data_engineering/pipeline.py? I unsuccessfully tried:

from kedro.pipeline import Pipeline, node

from .nodes import some_pipeline_step


def create_pipeline(**kwargs):
    return Pipeline(
        [
            node(
                some_pipeline_step,
                ["params:model_params.num_train_steps"],
                dict(
                    train_x="train_x",
                    train_y="train_y",
                ),
            )
        ]
    )

I know that I could pass all parameters into the function with ['parameters'], or pass the whole model_params dictionary with ['params:model_params'], but that seems inelegant and I feel like there must be a better way. Would appreciate any input!


Answer 1:


(Disclaimer: I'm part of the Kedro team)

Thank you for your question. Unfortunately, the current version of Kedro does not support nested parameters as node inputs. The interim solution is to pass a top-level key to the node and read the nested value inside it (as you already pointed out), or to decorate your node function with some sort of parameter filter, which is not elegant either.
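
As a rough illustration of the "parameter filter" idea mentioned above, here is a minimal sketch in plain Python. The decorator name pick_params is hypothetical and is not part of Kedro; it simply wraps a function so that it accepts a whole parameter dictionary and forwards only the named entries as keyword arguments.

```python
from functools import wraps
from typing import Any, Callable, Dict


def pick_params(*keys: str) -> Callable:
    """Hypothetical helper (not a Kedro API): build a decorator that makes a
    function accept one parameter dictionary and receive only `keys` from it."""

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(params: Dict[str, Any]) -> Any:
            # Forward only the requested entries as keyword arguments.
            return func(**{key: params[key] for key in keys})

        return wrapper

    return decorator


@pick_params("num_train_steps")
def some_pipeline_step(num_train_steps):
    """Takes the parameter `num_train_steps` as argument."""
    return num_train_steps


# Stand-in for the `model_params` block of parameters.yml.
model_params = {
    "learning_rate": 0.01,
    "test_data_ratio": 0.2,
    "num_train_steps": 10000,
}
result = some_pipeline_step(model_params)
```

With a wrapper like this, the node would presumably list "params:model_params" as its single input, and the function body would still only see num_train_steps.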

Probably the most viable solution is to customise your ProjectContext class (in src/<package_name>/run.py) by overriding the _get_feed_dict method as follows:

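from typing import Any, Dict

from kedro.context import KedroContext  # import path for Kedro releases of that period; newer versions may differ


class ProjectContext(KedroContext):
    # ...
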
    def _get_feed_dict(self) -> Dict[str, Any]:
        """Get parameters and return the feed dictionary."""
        params = self.params
        feed_dict = {"parameters": params}

        def _add_param_to_feed_dict(param_name, param_value):
            """This recursively adds parameter paths to the `feed_dict`,
            whenever `param_value` is a dictionary itself, so that users can
            specify specific nested parameters in their node inputs.

            Example:

                >>> param_name = "a"
                >>> param_value = {"b": 1}
                >>> _add_param_to_feed_dict(param_name, param_value)
                >>> assert feed_dict["params:a"] == {"b": 1}
                >>> assert feed_dict["params:a.b"] == 1
            """
            key = "params:{}".format(param_name)
            feed_dict[key] = param_value

            if isinstance(param_value, dict):
                # Use distinct names so the outer `key` is not shadowed.
                for child_name, child_value in param_value.items():
                    _add_param_to_feed_dict(
                        "{}.{}".format(param_name, child_name), child_value
                    )

        for param_name, param_value in params.items():
            _add_param_to_feed_dict(param_name, param_value)

        return feed_dict
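
To see what the override above produces, the same recursion can be run outside Kedro on the parameters from the question. This is a standalone sketch: flatten_params is a hypothetical name, and the dictionary stands in for the loaded parameters.yml.

```python
from typing import Any, Dict


def flatten_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical standalone version of the `_get_feed_dict` logic:
    register every nested parameter under a dotted `params:` key."""
    feed_dict: Dict[str, Any] = {"parameters": params}

    def _add(name: str, value: Any) -> None:
        feed_dict["params:{}".format(name)] = value
        if isinstance(value, dict):
            # Recurse so that each nested level gets its own dotted key.
            for child_name, child_value in value.items():
                _add("{}.{}".format(name, child_name), child_value)

    for name, value in params.items():
        _add(name, value)
    return feed_dict


# Stand-in for the contents of conf/base/parameters.yml.
params = {
    "step_size": 1,
    "model_params": {
        "learning_rate": 0.01,
        "test_data_ratio": 0.2,
        "num_train_steps": 10000,
    },
}
feed_dict = flatten_params(params)
for key in sorted(feed_dict):
    print(key)
```

The resulting dictionary contains "params:model_params.num_train_steps" alongside the top-level "params:model_params" and "params:step_size" entries, which is the key the node in the question asks for.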

Please also note that this issue has already been addressed on the develop branch, and the fix will be available in the next release. It uses the approach from the snippet above.



Source: https://stackoverflow.com/questions/61452211/kedro-how-to-pass-nested-parameters-directly-to-node
