Question
Kedro recommends storing parameters in conf/base/parameters.yml. Let's assume it looks like this:
step_size: 1
model_params:
    learning_rate: 0.01
    test_data_ratio: 0.2
    num_train_steps: 10000
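The question is really about reaching a value two levels down by its dotted path. The following is a minimal sketch of what such a lookup means, assuming the YAML above is loaded into plain nested dictionaries. PARAMS and get_nested_param are hypothetical names for illustration, not part of Kedro's API.

```python
from typing import Any, Mapping

# Hypothetical in-memory equivalent of the parameters.yml above,
# assuming it is loaded into plain nested dictionaries.
PARAMS = {
    "step_size": 1,
    "model_params": {
        "learning_rate": 0.01,
        "test_data_ratio": 0.2,
        "num_train_steps": 10000,
    },
}


def get_nested_param(params: Mapping[str, Any], dotted_path: str) -> Any:
    """Resolve a dotted path such as "model_params.num_train_steps"
    by descending one dictionary level per path segment."""
    value: Any = params
    for segment in dotted_path.split("."):
        value = value[segment]  # raises KeyError if a segment is missing
    return value


if __name__ == "__main__":
    print(get_nested_param(PARAMS, "model_params.num_train_steps"))  # 10000
```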
Now imagine I have a data_engineering pipeline whose nodes.py has a function that looks something like this:
def some_pipeline_step(num_train_steps):
    """
    Takes the parameter `num_train_steps` as argument.
    """
    pass
How would I go about passing that nested parameter straight to this function in data_engineering/pipeline.py? I unsuccessfully tried:
from kedro.pipeline import Pipeline, node

# Import the node function that is actually used below.
from .nodes import some_pipeline_step


def create_pipeline(**kwargs):
    return Pipeline(
        [
            node(
                some_pipeline_step,
                ["params:model_params.num_train_steps"],
                dict(
                    train_x="train_x",
                    train_y="train_y",
                ),
            )
        ]
    )
I know that I could pass all parameters into the function by using ['parameters'], or pass all model_params parameters with ['params:model_params'], but that seems inelegant and I feel like there must be a way. I would appreciate any input!
Answer 1:
(Disclaimer: I'm part of the Kedro team)
Thank you for your question. Unfortunately, the current version of Kedro does not support nested parameters. The interim solution is to pass the top-level key and pick out the nested value inside the node (as you already pointed out), or to decorate your node function with some sort of parameter filter, which is not elegant either.
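A parameter filter of that kind might look like the following sketch. pick_param is a hypothetical decorator, not something Kedro provides: the node would still be wired with the whole group (for example node(some_pipeline_step, "params:model_params", ...)), and the wrapper passes only the named entry on to the original function. The function body here is a placeholder that just returns the value so the behaviour is easy to check.

```python
import functools
from typing import Any, Callable, Dict


def pick_param(key: str) -> Callable:
    """Hypothetical decorator: the wrapped function receives a whole
    parameter group (a dict) but only sees the single entry `key`."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)  # keep the original name and docstring
        def wrapper(param_group: Dict[str, Any]) -> Any:
            return func(param_group[key])

        return wrapper

    return decorator


@pick_param("num_train_steps")
def some_pipeline_step(num_train_steps: int) -> int:
    """Placeholder body: returns the selected parameter unchanged."""
    return num_train_steps


if __name__ == "__main__":
    model_params = {"learning_rate": 0.01, "num_train_steps": 10000}
    print(some_pipeline_step(model_params))  # 10000
```

The downside the answer alludes to is visible here: the node's declared input is still the whole model_params group, so the pipeline definition no longer shows which nested value the function depends on.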
Probably the most viable solution is to customise your ProjectContext class (in src/<package_name>/run.py) by overriding its _get_feed_dict method as follows:
from typing import Any, Dict

from kedro.context import KedroContext  # import path used by Kedro 0.15.x


class ProjectContext(KedroContext):
    # ...

    def _get_feed_dict(self) -> Dict[str, Any]:
        """Get parameters and return the feed dictionary."""
        params = self.params
        feed_dict = {"parameters": params}

        def _add_param_to_feed_dict(param_name, param_value):
            """This recursively adds parameter paths to the `feed_dict`,
            whenever `param_value` is a dictionary itself, so that users can
            specify specific nested parameters in their node inputs.

            Example:

                >>> param_name = "a"
                >>> param_value = {"b": 1}
                >>> _add_param_to_feed_dict(param_name, param_value)
                >>> assert feed_dict["params:a"] == {"b": 1}
                >>> assert feed_dict["params:a.b"] == 1
            """
            key = "params:{}".format(param_name)
            feed_dict[key] = param_value

            # Recurse into nested dicts, extending the dotted path.
            if isinstance(param_value, dict):
                for sub_key, val in param_value.items():
                    _add_param_to_feed_dict("{}.{}".format(param_name, sub_key), val)

        for param_name, param_value in params.items():
            _add_param_to_feed_dict(param_name, param_value)

        return feed_dict
Please also note that this issue has already been addressed on the develop branch and will be available in the next release. The fix uses the approach from the snippet above.
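To see which keys the override produces without setting up a Kedro project, the recursion in _get_feed_dict can be pulled out into a plain function. This is a sketch that mirrors the snippet above, not Kedro's released code; build_feed_dict is a hypothetical name.

```python
from typing import Any, Dict


def build_feed_dict(params: Dict[str, Any]) -> Dict[str, Any]:
    """Standalone version of the _get_feed_dict logic: every parameter,
    at every nesting level, gets a "params:<dotted.path>" key."""
    feed_dict: Dict[str, Any] = {"parameters": params}

    def add(param_name: str, param_value: Any) -> None:
        feed_dict["params:{}".format(param_name)] = param_value
        # Nested dicts are also exposed entry by entry.
        if isinstance(param_value, dict):
            for sub_key, sub_value in param_value.items():
                add("{}.{}".format(param_name, sub_key), sub_value)

    for name, value in params.items():
        add(name, value)
    return feed_dict


if __name__ == "__main__":
    params = {
        "step_size": 1,
        "model_params": {"learning_rate": 0.01, "num_train_steps": 10000},
    }
    feed_dict = build_feed_dict(params)
    print(feed_dict["params:model_params.num_train_steps"])  # 10000
```

With the parameters from the question, this yields the keys parameters, params:step_size, params:model_params, params:model_params.learning_rate and params:model_params.num_train_steps, which is why "params:model_params.num_train_steps" becomes a valid node input.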
Source: https://stackoverflow.com/questions/61452211/kedro-how-to-pass-nested-parameters-directly-to-node