I've got a problem that I want to split across multiple CUDA devices, but I suspect my current system architecture is holding me back.
What I've set up is a GPU class
What you need is a multi-threaded implementation of the `map` built-in function. Here is one implementation which, with a little modification to suit your particular needs, gives you:
```python
import threading

def cuda_map(args_list, gpu_instances):
    result = [None] * len(args_list)

    def task_wrapper(gpu_instance, task_indices):
        # Each thread runs its GPU's share of the calls one after another
        for i in task_indices:
            result[i] = gpu_instance.gpufunction(args_list[i])

    # Round-robin split: GPU i gets indices i, i + n, i + 2n, ...
    threads = [threading.Thread(
        target=task_wrapper,
        args=(gpu_i, list(range(len(args_list)))[i::len(gpu_instances)])
    ) for i, gpu_i in enumerate(gpu_instances)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return result
```
It is more or less the same as what you have above, with the big difference being that you don't sit idle waiting for each individual `gpufunction` call to complete before launching the next one: all GPUs work at the same time.
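One possible refinement, sketched under the same assumption that each object in `gpu_instances` exposes a `gpufunction` method: instead of the fixed round-robin split, use a shared work queue (a swapped-in technique, not part of the answer above) so that a GPU which finishes early simply picks up the next item. `cuda_map_queued` and the `FakeGpu` stand-in class are hypothetical names for illustration only; with real PyCUDA objects the context and host-thread rules discussed in the next answer still apply.

```python
import queue
import threading

def cuda_map_queued(args_list, gpu_instances):
    """Like cuda_map, but each worker pulls the next index from a shared queue."""
    result = [None] * len(args_list)
    tasks = queue.Queue()
    for i in range(len(args_list)):
        tasks.put(i)

    def worker(gpu_instance):
        while True:
            try:
                i = tasks.get_nowait()
            except queue.Empty:
                return  # no work left for this GPU
            result[i] = gpu_instance.gpufunction(args_list[i])

    threads = [threading.Thread(target=worker, args=(g,)) for g in gpu_instances]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return result


# Hypothetical stand-in for the asker's GPU class, for illustration only
class FakeGpu:
    def __init__(self, name):
        self.name = name

    def gpufunction(self, x):
        return (self.name, x * x)


if __name__ == "__main__":
    gpus = [FakeGpu("gpu0"), FakeGpu("gpu1")]
    # Results stay in input order no matter which GPU computed them
    print([value for _, value in cuda_map_queued([1, 2, 3, 4], gpus)])  # prints [1, 4, 9, 16]
```

The trade-off is a little locking overhead per item in exchange for automatic load balancing when individual calls take uneven amounts of time.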
You need to get all your bananas lined up on the CUDA side of things first, then think about the best way to get this done in Python.
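The CUDA multi-GPU model is pretty straightforward before CUDA 4.0: each GPU has its own context, and each context must be established by a different host thread. So the idea is to spawn one host thread per GPU and have each thread create (and later release) the context for its own device.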
In Python, this might look something like this:
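```python
import threading
from pycuda import driver

class gpuThread(threading.Thread):
    def __init__(self, gpuid):
        threading.Thread.__init__(self)
        self.gpuid = gpuid

    def run(self):
        # Create the context here, in run(), so it is owned by this host
        # thread (creating it in __init__ would tie it to the main thread)
        ctx = driver.Device(self.gpuid).make_context()
        try:
            device = ctx.get_device()
            print("%s has device %s, api version %s"
                  % (self.name, device.name(), ctx.get_api_version()))
            # Profit! (do this GPU's share of the work here)
        finally:
            # Release the context from the same thread that created it
            ctx.detach()

driver.init()
ngpus = driver.Device.count()
# Start every thread before joining any, so the GPUs run concurrently
threads = [gpuThread(i) for i in range(ngpus)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```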
This assumes it is safe to just establish a context without any checking of the device beforehand. Ideally you would check the compute mode to make sure it is safe to try, then use an exception handler in case a device is busy. But hopefully this gives the basic idea.
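A minimal sketch of that check, assuming PyCUDA's `Device.get_attribute`, `device_attribute.COMPUTE_MODE`, `compute_mode.PROHIBITED` and the `driver.Error` exception base class. The helper name `try_make_context` is hypothetical, and it takes the `pycuda.driver` module as an argument only so the logic can be exercised without a GPU; it must be called from the host thread that will own the context.

```python
def try_make_context(driver, gpuid):
    """Return a new context on device `gpuid`, or None if it can't be used.

    `driver` is expected to be the pycuda.driver module (an assumption of
    this sketch); call this from the thread that will own the context.
    """
    dev = driver.Device(gpuid)
    mode = dev.get_attribute(driver.device_attribute.COMPUTE_MODE)
    if mode == driver.compute_mode.PROHIBITED:
        # No host thread is allowed to create a context on this device
        return None
    try:
        return dev.make_context()
    except driver.Error:
        # e.g. an exclusive-mode device that is already in use elsewhere
        return None
```

In a per-GPU thread you would call `ctx = try_make_context(driver, gpuid)` at the start of `run()` and return early if it is `None`, so busy or prohibited devices are skipped instead of crashing the whole program.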