I am trying to compute the sum of a large array in parallel with Metal in Swift.
Is there a good way to do it?
My plan was to divide my array into sub-arrays, compute the sum of each sub-array in parallel, and then add up those partial sums.
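The accepted answer is missing the kernel that was written for it. The original source was linked from that answer; below is the full program and shader, which can be run as a Swift command-line application (the shader must be compiled into the default Metal library so that `makeDefaultLibrary()` can find `parsum`).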
/*
* Command line Metal Compute Shader for data processing
*/
import Metal
import Foundation
//------------------------------------------------------------------------------
// Problem size: `count` elements are summed in chunks of `elementsPerSum`
// consecutive elements, one chunk per GPU thread.
let count = 10_000_000
let elementsPerSum = 10_000
//------------------------------------------------------------------------------
typealias DataType = CInt // Data type, has to be the same as in the shader
//------------------------------------------------------------------------------
// Metal setup. Each step can fail (no GPU, no default library built into the
// executable, kernel missing or failing to compile), so fail with a readable
// message instead of a bare force-unwrap trap.
guard let device = MTLCreateSystemDefaultDevice() else {
    fatalError("No Metal device is available on this machine")
}
guard let library = device.makeDefaultLibrary() else {
    fatalError("No default Metal library found; make sure the .metal file is compiled into the target")
}
guard let parsum = library.makeFunction(name: "parsum") else {
    fatalError("Kernel function `parsum` not found in the default Metal library")
}
let pipeline: MTLComputePipelineState = {
    do {
        return try device.makeComputePipelineState(function: parsum)
    } catch {
        fatalError("Failed to create the compute pipeline: \(error)")
    }
}()
//------------------------------------------------------------------------------
precondition(elementsPerSum > 0, "elementsPerSum must be positive")
// Our data, randomly generated (values in 0..<100):
var data = (0..<count).map { _ in DataType.random(in: 0..<100) }
// Scalar kernel arguments. `CUnsignedInt` matches the shader's 32-bit `uint`;
// they are `var` because `setBytes` receives them by pointer (`&`).
var dataCount = CUnsignedInt(count)
var elementsPerSumC = CUnsignedInt(elementsPerSum)
// Number of individual results = count / elementsPerSum (rounded up):
let resultsCount = (count + elementsPerSum - 1) / elementsPerSum
// Our data in a buffer (copied):
guard let dataBuffer = device.makeBuffer(bytes: &data,
                                         length: MemoryLayout<DataType>.stride * count,
                                         options: []) else {
    fatalError("Failed to allocate the input buffer")
}
// A buffer for individual results (zero initialized)
guard let resultsBuffer = device.makeBuffer(length: MemoryLayout<DataType>.stride * resultsCount,
                                            options: []) else {
    fatalError("Failed to allocate the results buffer")
}
// Our results in convenient form to compute the actual result later:
let pointer = resultsBuffer.contents().bindMemory(to: DataType.self, capacity: resultsCount)
let results = UnsafeBufferPointer(start: pointer, count: resultsCount)
//------------------------------------------------------------------------------
guard let queue = device.makeCommandQueue(),
      let cmds = queue.makeCommandBuffer(),
      let encoder = cmds.makeComputeCommandEncoder() else {
    fatalError("Failed to create the Metal command queue, command buffer or compute encoder")
}
//------------------------------------------------------------------------------
// Argument indices must match the `[[ buffer(n) ]]` attributes of the `parsum`
// kernel. Byte counts are taken from the values themselves and computed before
// the calls that pass those values by `&`.
let dataCountSize = MemoryLayout.size(ofValue: dataCount)
let elementsPerSumSize = MemoryLayout.size(ofValue: elementsPerSumC)
encoder.setComputePipelineState(pipeline)
encoder.setBuffer(dataBuffer, offset: 0, index: 0)
encoder.setBytes(&dataCount, length: dataCountSize, index: 1)
encoder.setBuffer(resultsBuffer, offset: 0, index: 2)
encoder.setBytes(&elementsPerSumC, length: elementsPerSumSize, index: 3)
//------------------------------------------------------------------------------
// We have to calculate the sum `resultsCount` times => amount of threadgroups is `resultsCount` / `threadExecutionWidth` (rounded up) because each threadgroup will process `threadExecutionWidth` threads
let threadgroupsPerGrid = MTLSize(width: (resultsCount + pipeline.threadExecutionWidth - 1) / pipeline.threadExecutionWidth, height: 1, depth: 1)
// Here we set that each threadgroup should process `threadExecutionWidth` threads, the only important thing for performance is that this number is a multiple of `threadExecutionWidth` (here 1 times)
let threadsPerThreadgroup = MTLSize(width: pipeline.threadExecutionWidth, height: 1, depth: 1)
//------------------------------------------------------------------------------
encoder.dispatchThreadgroups(threadgroupsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
encoder.endEncoding()
//------------------------------------------------------------------------------
// Timing helper. `mach_absolute_time()` counts timebase ticks, which equal
// nanoseconds only on Intel Macs, so dividing them by NSEC_PER_SEC misreports
// durations on Apple silicon. `DispatchTime.uptimeNanoseconds` is always in ns.
func secondsElapsed(since startNanoseconds: UInt64) -> Double {
    return Double(DispatchTime.now().uptimeNanoseconds - startNanoseconds) / Double(NSEC_PER_SEC)
}
// Totals are accumulated as `Int` so the host-side addition cannot trap on
// `DataType` (CInt) overflow when the input is made larger.
//------------------------------------------------------------------------------
var start = DispatchTime.now().uptimeNanoseconds
cmds.commit()
cmds.waitUntilCompleted()
if let error = cmds.error {
    fatalError("GPU command buffer failed: \(error)")
}
// Add up the partial sums the kernel wrote, one per chunk.
let metalResult = results.reduce(0) { $0 + Int($1) }
let metalSeconds = secondsElapsed(since: start)
//------------------------------------------------------------------------------
print("Metal result: \(metalResult), time: \(metalSeconds)")
//------------------------------------------------------------------------------
start = DispatchTime.now().uptimeNanoseconds
let cpuResult = data.withUnsafeBufferPointer { buffer in
    buffer.reduce(0) { sum, elem in sum + Int(elem) }
}
let cpuSeconds = secondsElapsed(since: start)
print("CPU result: \(cpuResult), time: \(cpuSeconds)")
//------------------------------------------------------------------------------
#include <metal_stdlib>
using namespace metal;
typedef unsigned int uint;
typedef int DataType; // Has to be the same type as `DataType` in the Swift host code

/// Partial-sum kernel: thread `resultIndex` adds up the chunk
/// data[resultIndex * elementsPerSum ..< min((resultIndex + 1) * elementsPerSum, dataLength)]
/// and stores the total in sums[resultIndex]. The host adds up the partial sums.
kernel void parsum(const device DataType* data [[ buffer(0) ]],
                   const device uint& dataLength [[ buffer(1) ]],
                   device DataType* sums [[ buffer(2) ]],
                   const device uint& elementsPerSum [[ buffer(3) ]],
                   const uint tgPos [[ threadgroup_position_in_grid ]],
                   const uint tPerTg [[ threads_per_threadgroup ]],
                   const uint tPos [[ thread_position_in_threadgroup ]]) {
    if (elementsPerSum == 0) return; // avoid division by zero below

    const uint resultIndex = tgPos * tPerTg + tPos;
    // Number of chunks (rounded up), computed without overflowing `uint`.
    const uint chunkCount = dataLength / elementsPerSum + (dataLength % elementsPerSum != 0 ? 1u : 0u);
    // The grid is rounded up to whole threadgroups, so trailing threads have
    // no chunk and must not write past the end of `sums`.
    if (resultIndex >= chunkCount) return;

    const uint begin = resultIndex * elementsPerSum;                 // < dataLength here, so no overflow
    const uint end = begin + min(elementsPerSum, dataLength - begin); // clamp the last chunk

    // Accumulate in a local variable and store once. This avoids a device-memory
    // read-modify-write per element and no longer depends on `sums` being zeroed.
    DataType sum = 0;
    for (uint i = begin; i < end; ++i)
        sum += data[i];
    sums[resultIndex] = sum;
}