How to sweep many hyperparameter sets in parallel in Python?

Submitted by 落花浮王杯 on 2021-01-28 11:24:48

Question


Note that I have to sweep through more argument sets than there are available CPUs, so I'm not sure whether Python will automatically schedule the work across the CPUs as they become available.

Here is what I tried, but I get an error about the arguments:

import random
import multiprocessing
from train_nodes import run
import itertools

envs = ["AntBulletEnv-v0", "HalfCheetahBulletEnv-v0", "HopperBulletEnv-v0", "ReacherBulletEnv-v0",
        "Walker2DBulletEnv-v0", "InvertedDoublePendulumBulletEnv-v0"]
algs = ["PPO", "A2C"]
seeds = [random.randint(0, 200), random.randint(200, 400), random.randint(400, 600), random.randint(600, 800), random.randint(800, 1000)]

args = list(itertools.product(*[envs, algs, seeds]))

num_cpus = multiprocessing.cpu_count()

with multiprocessing.Pool(num_cpus) as processing_pool:
    processing_pool.map(run, args)

run takes three arguments: env, alg, and seed. For some reason, it doesn't receive all three here.


Answer 1:


multiprocessing.Pool.map calls the function with exactly one argument per item, so each (env, alg, seed) tuple reaches run as a single argument. One way to adapt your code is to write a small wrapper function that takes the tuple as its one argument, unpacks it into env, alg, and seed, and passes them to run.

Another option is to use multiprocessing.Pool.starmap, which unpacks each tuple in the iterable into separate arguments for the function.

import random
import multiprocessing
import itertools

envs = [
    "AntBulletEnv-v0",
    "HalfCheetahBulletEnv-v0",
    "HopperBulletEnv-v0",
    "ReacherBulletEnv-v0",
    "Walker2DBulletEnv-v0",
    "InvertedDoublePendulumBulletEnv-v0",
]
algs = ["PPO", "A2C"]
seeds = [
    random.randint(0, 200),
    random.randint(200, 400),
    random.randint(400, 600),
    random.randint(600, 800),
    random.randint(800, 1000),
]

args = list(itertools.product(*[envs, algs, seeds]))

num_cpus = multiprocessing.cpu_count()

# sample implementation of `run`
def run(env, alg, seed):
    # do stuff
    return random.randint(0, 200)

def wrapper(env_alg_seed):
    env, alg, seed = env_alg_seed
    return run(env=env, alg=alg, seed=seed)

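# the guard is needed on platforms that start workers with spawn (e.g. Windows, macOS)
if __name__ == "__main__":
    # option 1: use a wrapper that unpacks each (env, alg, seed) tuple
    with multiprocessing.Pool(num_cpus) as processing_pool:
        # map returns a list of results in the same order as args
        results = processing_pool.map(wrapper, args)

    # option 2: use starmap and call `run` directly
    with multiprocessing.Pool(num_cpus) as processing_pool:
        results = processing_pool.starmap(run, args)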
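
As for having more argument sets than CPUs: a Pool keeps a fixed number of worker processes and hands queued tasks to whichever worker is free, so every combination is eventually run without any extra scheduling code. The sketch below illustrates this with 12 argument sets and only 2 workers, and also pairs each result with the arguments that produced it. The helper names make_arg_sets and sweep are hypothetical, and run is only a stand-in that returns the worker's process ID instead of training anything.

```python
import itertools
import multiprocessing
import os


def run(env, alg, seed):
    # Hypothetical stand-in for the real training function:
    # it just reports which worker process handled the task.
    return os.getpid()


def make_arg_sets(envs, algs, seeds):
    # Every (env, alg, seed) combination, as a list of tuples.
    return list(itertools.product(envs, algs, seeds))


def sweep(arg_sets, processes):
    # starmap queues all argument sets; the fixed pool of workers
    # picks them up as they become free, and results keep input order.
    with multiprocessing.Pool(processes) as pool:
        results = pool.starmap(run, arg_sets)
    # Key each result by the arguments that produced it.
    return dict(zip(arg_sets, results))


if __name__ == "__main__":
    arg_sets = make_arg_sets(
        ["AntBulletEnv-v0", "HopperBulletEnv-v0"], ["PPO", "A2C"], [0, 1, 2]
    )
    results = sweep(arg_sets, processes=2)
    # 12 tasks are completed by at most 2 distinct worker processes.
    print(len(results), "runs handled by", len(set(results.values())), "worker(s)")
```

Because the results come back in the same order as the argument sets, zipping the two lists is enough to build the dictionary.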


Source: https://stackoverflow.com/questions/65694724/how-to-sweep-many-hyperparameter-sets-in-parallel-in-python
