Question
My question is a follow-up to How to make this code faster (learning best practices)?, which has been put on hold (bummer). The problem is to optimize a loop over an array with floats which are tested for whether they lie within a given interval. Indices of matching elements in the array are to be stored in a provided result array.
The test includes two conditions (at most the upper threshold and at least the lower one). The obvious code for the test is if( elem <= upper && elem >= lower ) .... I observed that branching (including the implicit branch involved in the short-circuiting operator &&) is much more expensive than the second comparison. What I came up with is below. It is about 20%-40% faster than a naive implementation, more than I expected. It uses the fact that bool is an integer type. The condition test result is used as an index into two result arrays. Only one of them will contain the desired data, the other one can be discarded. This replaces program structure with data structure and computation.
I am interested in more ideas for optimization. "Technical hacks" (of the kind provided here) are welcome. I'm also interested in whether modern C++ could provide means to be faster, e.g. by enabling the compiler to create parallel running code. Think visitor pattern/functor. Computations on the single srcArr elements are almost independent, except that the order of indices in the result array depends on the order of testing the source array elements. I would loosen the requirements a little so that the order of the matching indices reported in the result array is irrelevant. Can anybody come up with a fast way?
Here is the source code of the function, followed by a supporting main. gcc needs -std=c++11 because of chrono. VS 2013 Express was able to compile this too (and created 40% faster code than gcc -O3).
#include <cstdlib>
#include <iostream>
#include <chrono>
using namespace std;
using namespace std::chrono;
/// Check all elements in srcArr whether they lie in
/// the interval [lower, upper]. Store the indices of
/// such elements in the array pointed to by destArr[1]
/// and return the number of matching elements found.
/// This has been highly optimized, mainly to avoid branches.
int findElemsInInterval( const float srcArr[],  // contains candidates
                         int **const destArr,   // two arrays to be filled with indices
                         const int arrLen,      // length of each array
                         const float lower, const float upper // interval
                       )
{
    // Instead of branching, use the condition
    // as an index into two distinct arrays. We need to keep
    // separate indices for both those arrays.
    int destIndices[2];
    destIndices[0] = destIndices[1] = 0;
    for( int srcInd=0; srcInd<arrLen; ++srcInd )
    {
        // If the element is inside the interval, both conditions
        // are true and therefore equal. In all other cases
        // exactly one condition is true so that they are not equal.
        // (This assumes lower <= upper and that the element is not NaN:
        // for a NaN both comparisons are false and therefore also equal.)
        // Matching elements' indices are therefore stored in destArr[1].
        // destArr[0] is a kind of a dummy (it will incidentally contain
        // indices of non-matching elements).
        // This used to be (with a simple int *destArr)
        // if( srcArr[srcInd] <= upper && srcArr[srcInd] >= lower) destArr[destIndex++] = srcInd;
        int isInInterval = (srcArr[srcInd] <= upper) == (srcArr[srcInd] >= lower);
        destArr[isInInterval][destIndices[isInInterval]++] = srcInd;
    }
    return destIndices[1]; // the number of elements in the results array
}
int main(int argc, char *argv[])
{
    int arrLen = 1000*1000*100;
    if( argc > 1 ) arrLen = atol(argv[1]);
    // destArr[1] will hold the indices of elements which
    // are within the interval.
    int *destArr[2];
    // we don't check destination boundaries, so make them
    // the same length as the source.
    destArr[0] = new int[arrLen];
    destArr[1] = new int[arrLen];
    float *srcArr = new float[arrLen];
    // Create always the same numbers for comparison (don't srand).
    for( int srcInd=0; srcInd<arrLen; ++srcInd ) srcArr[srcInd] = rand();
    // Create an interval in the middle of the rand() spectrum
    float lowerLimit = RAND_MAX/3;
    float upperLimit = lowerLimit*2;
    cout << "lower = " << lowerLimit << ", upper = " << upperLimit << endl;
    int numInterval;
    auto t1 = high_resolution_clock::now(); // measure clock time as an approximation
    // Call the function a few times to get a longer run time
    for( int srcInd=0; srcInd<10; ++srcInd )
        numInterval = findElemsInInterval( srcArr, destArr, arrLen, lowerLimit, upperLimit );
    auto t2 = high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>( t2 - t1 ).count();
    cout << numInterval << " elements found in " << duration << " milliseconds. " << endl;
    return 0;
}
Answer 1:
Thinking of the integer range check optimization of turning a <= x && x < b into ((unsigned)(x-a)) < b-a, a floating point variant comes to mind:
You could try something like
const float radius = (b-a)/2;
if( fabs( x-(a+radius) ) < radius )
...
to reduce the check to one conditional.
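As a minimal sketch, the same idea can be written as a small helper. The name inIntervalAbs is hypothetical, the sketch assumes lower <= upper, and it uses <= rather than < so that the closed interval from the question is kept. Because the center and radius are themselves rounded, values very close to the bounds may be classified differently than by the two-comparison test.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: tests lower <= x <= upper with one comparison.
// Assumes lower <= upper. A NaN input yields false, because any
// comparison involving NaN is false.
inline bool inIntervalAbs(float x, float lower, float upper)
{
    assert(lower <= upper);
    const float radius = (upper - lower) / 2;  // half the interval width
    const float center = lower + radius;       // midpoint of the interval
    return std::fabs(x - center) <= radius;    // distance from midpoint
}
```

Inside the loop, radius and center do not depend on the element, so they can be computed once before the loop starts.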
Answer 2:
I see about a 10% speedup from this:
int destIndex = 0; // replace destIndices
int isInInterval = (srcArr[srcInd] <= upper) == (srcArr[srcInd] >= lower);
destArr[1][destIndex] = srcInd;
destIndex += isInInterval;
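Put together as a whole function, this could look roughly like the sketch below. The name findElemsInIntervalSingle and the plain int destArr[] parameter are assumptions for illustration, and the sketch swaps in a non-short-circuiting & for the == comparison so that NaN elements are not counted as matches.

```cpp
#include <cassert>

// Sketch of the single-array variant. destArr must have room for
// arrLen entries: every index is written once, and the write
// position only advances when the element matched.
int findElemsInIntervalSingle(const float srcArr[], int destArr[],
                              const int arrLen,
                              const float lower, const float upper)
{
    assert(lower <= upper);
    int destIndex = 0;
    for (int srcInd = 0; srcInd < arrLen; ++srcInd)
    {
        // Both comparisons are always evaluated; the result is 0 or 1.
        const int isInInterval =
            (srcArr[srcInd] <= upper) & (srcArr[srcInd] >= lower);
        destArr[destIndex] = srcInd;  // unconditional store
        destIndex += isInInterval;    // keep it only on a match
    }
    return destIndex;  // number of matching elements
}
```

Entries at and beyond the returned count are leftovers from overwritten stores and should be ignored.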
Answer 3:
Eliminate the pair of output arrays. Instead, advance the 'number written' count by 1 only if you want to keep the result; otherwise just keep overwriting the 'one past the end' slot.
I.e., retval[destIndex] = curIndex; destIndex += isInInterval;
This gives better cache coherency and wastes less memory.
Write two versions: one that supports a fixed array length (say 1024) and another that takes the length as a runtime parameter. Use a template argument to remove code duplication. Assume the length is less than that constant.
Have the function return the size and an RVO'd std::array<unsigned, 1024>.
Write a wrapper function that merges results (create all results, then merge them). Then throw the Parallel Patterns Library at the problem (so the results get computed in parallel).
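One way these suggestions might fit together is sketched below. ChunkResult, findInChunk and findElemsChunked are hypothetical names, and the wrapper runs the chunks sequentially; since each chunk only reads its own slice and writes its own result, the loop in the wrapper is the part that could be handed to a parallel library.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Result of scanning one chunk: matching indices plus their count.
template <std::size_t N>
struct ChunkResult
{
    std::array<unsigned, N> indices;
    std::size_t count;
};

// Scan len (<= N) elements starting at srcArr[offset], branch-free.
template <std::size_t N>
ChunkResult<N> findInChunk(const float* srcArr, std::size_t offset,
                           std::size_t len, float lower, float upper)
{
    assert(len <= N);
    ChunkResult<N> result{};  // zero-initialized
    std::size_t dest = 0;
    for (std::size_t i = 0; i < len; ++i)
    {
        const float x = srcArr[offset + i];
        const bool hit = (x >= lower) & (x <= upper);
        result.indices[dest] = static_cast<unsigned>(offset + i);
        dest += hit;  // advance only when the element matched
    }
    result.count = dest;
    return result;
}

// Wrapper: process the array chunk by chunk and merge the results.
template <std::size_t N>
std::vector<unsigned> findElemsChunked(const float* srcArr,
                                       std::size_t arrLen,
                                       float lower, float upper)
{
    std::vector<unsigned> merged;
    for (std::size_t offset = 0; offset < arrLen; offset += N)
    {
        const std::size_t len =
            (arrLen - offset < N) ? arrLen - offset : N;
        const ChunkResult<N> r =
            findInChunk<N>(srcArr, offset, len, lower, upper);
        merged.insert(merged.end(), r.indices.begin(),
                      r.indices.begin() + r.count);
    }
    return merged;
}
```

Merging in chunk order keeps the indices sorted; if the order is irrelevant, as the question allows, the chunks could be merged in whatever order they finish.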
Answer 4:
If you allow yourself vectorization using the SSE (or better, AVX) instruction set, you can perform 4 (or 8) comparisons at once, do this twice, 'and' the results, and retrieve the 4 results (-1 or 0). At the same time, this unrolls the loop.
// Preload the bounds into all four lanes
__m128 lo = _mm_set1_ps(lower);
__m128 up = _mm_set1_ps(upper);
int srcInd = 0, dstIndex = 0;
while (srcInd + 3 < arrLen)
{
    __m128 src = _mm_loadu_ps(&srcArr[srcInd]); // Load 4 values (unaligned load)
    __m128 tst = _mm_and_ps(_mm_cmpge_ps(src, lo), _mm_cmple_ps(src, up)); // Test lower <= src <= upper
    // Copy the 4 indexes with conditional incrementation.
    // Each lane of tst is -1 (match) or 0, hence the subtraction.
    // The m128_i32 member is specific to MSVC's __m128 type.
    dstArr[dstIndex] = srcInd++; dstIndex -= tst.m128_i32[0];
    dstArr[dstIndex] = srcInd++; dstIndex -= tst.m128_i32[1];
    dstArr[dstIndex] = srcInd++; dstIndex -= tst.m128_i32[2];
    dstArr[dstIndex] = srcInd++; dstIndex -= tst.m128_i32[3];
}
// Up to 3 elements remain if arrLen is not a multiple of 4;
// they need a scalar pass.
CAUTION: unchecked code.
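For compilers other than MSVC, a sketch of the same approach could read the comparison result through _mm_movemask_ps instead of accessing the lanes directly. The function name findElemsInIntervalSSE and the plain int destArr[] parameter are assumptions for illustration; the sketch targets x86 with SSE and has not been benchmarked.

```cpp
#include <cassert>
#include <xmmintrin.h>  // SSE intrinsics (x86 only)

// Sketch: test 4 floats per iteration, then store the indices with
// a conditional increment. destArr must have room for arrLen entries.
int findElemsInIntervalSSE(const float srcArr[], int destArr[],
                           const int arrLen,
                           const float lower, const float upper)
{
    assert(lower <= upper);
    const __m128 lo = _mm_set1_ps(lower);  // lower bound in all lanes
    const __m128 up = _mm_set1_ps(upper);  // upper bound in all lanes
    int srcInd = 0, destIndex = 0;
    for (; srcInd + 3 < arrLen; srcInd += 4)
    {
        const __m128 src = _mm_loadu_ps(&srcArr[srcInd]);
        const __m128 tst = _mm_and_ps(_mm_cmpge_ps(src, lo),
                                      _mm_cmple_ps(src, up));
        // Bit k of mask is set if lane k lies in the interval.
        const int mask = _mm_movemask_ps(tst);
        destArr[destIndex] = srcInd;     destIndex += mask & 1;
        destArr[destIndex] = srcInd + 1; destIndex += (mask >> 1) & 1;
        destArr[destIndex] = srcInd + 2; destIndex += (mask >> 2) & 1;
        destArr[destIndex] = srcInd + 3; destIndex += (mask >> 3) & 1;
    }
    // Scalar pass for the last arrLen % 4 elements.
    for (; srcInd < arrLen; ++srcInd)
    {
        destArr[destIndex] = srcInd;
        destIndex += (srcArr[srcInd] >= lower) & (srcArr[srcInd] <= upper);
    }
    return destIndex;  // number of matching elements
}
```

The unaligned load is used because memory from new float[] is not guaranteed to be 16-byte aligned.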
Source: https://stackoverflow.com/questions/22832888/optimizing-a-comparison-over-array-elements-with-two-conditions-c-abstraction