Сравнение двух значений в форме (a + sqrt (b)) как можно быстрее?

Как часть программы, которую я пишу, мне нужно сравнить два значения в форме a + sqrt(b) где a и b - целые числа без знака. Поскольку это часть узкого цикла, я бы хотел, чтобы это сравнение выполнялось как можно быстрее. (Если это имеет значение, я запускаю код на компьютерах с архитектурой x86-64, а целые числа без знака не превышают 10 ^ 6. Кроме того, я точно знаю, что a1<a2.)

Это отдельная функция, которую я пытаюсь оптимизировать. Мои числа являются достаточно маленькими целыми числами, чтобы double (или даже с float) могли точно представлять их, но ошибка округления в результатах sqrt не должна изменить результат.

// known pre-condition: a1 < a2  in case that helps
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);  // computed mathematically exactly
}

Тестовый пример: is_smaller(900000, 1000000, 900001, 998002) должен возвращать true, но, как показано в комментариях @wim, вычисление его с помощью sqrtf() вернет false. Так бы (int)sqrt() усечь обратно до целого числа.

a1+sqrt(b1) = 90100 и a2+sqrt(b2) = 901000.00050050037512481206. Ближайший к этому поплавок - ровно 90100.


Поскольку функция sqrt() как правило, довольно дорогая даже на современном x86-64, когда она полностью встроена как инструкция sqrtsd, я стараюсь по возможности избегать вызова sqrt().

Удаление sqrt путем возведения в квадрат потенциально также позволяет избежать опасности ошибок округления, делая все вычисления точными.

Если бы вместо этого функция была что-то вроде этого...

bool is_smaller(unsigned a1, unsigned b1, unsigned x) {
    return a1+sqrt(b1) < x;
}

... тогда я мог бы просто сделать return x-a1>=0 && static_cast<uint64_t>(x-a1)*(x-a1)>b1;

Но теперь, так как есть два sqrt(...) члена, я не могу сделать одно и то же алгебраическое манипулирование.

Я мог бы возвести в квадрат значения дважды, используя эту формулу:

      a1 + sqrt(b1) = a2 + sqrt(b2)
<==>  a1 - a2 = sqrt(b2) - sqrt(b1)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1) * sqrt(b2)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1 * b2)
<==>  (a1 - a2) * (a1 - a2) - (b1 + b2) = - 2 * sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 2 = sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) * ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 4 = b1 * b2

Беззнаковое деление на 4 дешево, потому что это просто битовое смещение, но так как я возводю числа в квадрат дважды, мне нужно будет использовать 128-битные целые числа, и мне нужно будет ввести несколько проверок >=0 (потому что я сравниваю неравенство вместо равенство).

Такое ощущение, что может быть способ сделать это быстрее, применив лучшую алгебру к этой проблеме. Есть ли способ сделать это быстрее?

Ответ 1

Здесь версия без sqrt, хотя я не уверен, что она быстрее, чем версия, которая имеет только один sqrt (это может зависеть от распределения значений).

Вот математика (как убрать оба sqrts):

ad = a2-a1
bd = b2-b1

a1+sqrt(b1) < a2+sqrt(b2)              // subtract a1
   sqrt(b1) < ad+sqrt(b2)              // square it
        b1  < ad^2+2*ad*sqrt(b2)+b2    // arrange
   ad^2+bd  > -2*ad*sqrt(b2)

Здесь правая сторона всегда отрицательна. Если левая сторона положительна, то мы должны вернуть true.

Если левая сторона отрицательна, то мы можем выровнять неравенство:

ad^4+bd^2+2*bd*ad^2 < 4*ad^2*b2

Здесь следует обратить внимание на то, что если a2>=a1+1000, то is_smaller всегда возвращает true (поскольку максимальное значение sqrt(b1) равно 1000). Если a2<=a1+1000, то ad - это небольшое число, поэтому ad^4 всегда будет вписываться в 64-битную (нет необходимости в 128-битной арифметике). Вот код:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    if (ad>1000) {
        return true;
    }

    int bd = b2 - b1;
    if (ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;

    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}

РЕДАКТИРОВАТЬ: Как заметил Питер Кордес, первый, if не нужно, как второй, если обрабатывает его, поэтому код становится меньше и быстрее:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    int bd = b2 - b1;
    if ((long long int)ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;
    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}

Ответ 2

Я устал и, вероятно, ошибся; но я уверен, что если я это сделаю, кто-то укажет на это..

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    a_diff = a1-a2;   // May be negative

    if(a_diff < 0) {
        if(b1 < b2) {
            return true;
        }
        temp = a_diff+sqrt(b1);
        if(temp < 0) {
            return true;
        }
        return temp*temp < b2;
    } else {
        if(b1 >= b2) {
            return false;
        }
    }
//  return a_diff+sqrt(b1) < sqrt(b2);

    temp = a_diff+sqrt(b1);
    return temp*temp < b2;
}

Если вы знаете a1 < a2 тогда оно может стать:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    a_diff = a2-a1;    // Will be positive

    if(b1 > b2) {
        return false;
    }
    if(b1 >= a_diff*a_diff) {
        return false;
    }
    temp = a_diff+sqrt(b2);
    return b1 < temp*temp;
}

Ответ 3

Существует также метод Ньютона для вычисления целочисленных квадратов, как описано здесь. Другой подход заключается в том, чтобы не вычислять квадратный корень, а искать пол (sqrt (n)) с помощью бинарного поиска... "всего" 1000 полных квадратных чисел меньше 10 ^ 6. Это, вероятно, имеет плохую производительность, но будет интересным подходом. Я не измерял ни одного из них, но вот примеры:

#include <iostream>
#include <array>
#include <algorithm>        // std::lower_bound
#include <cassert>          


bool is_smaller_sqrt(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    return a1 + sqrt(b1) < a2 + sqrt(b2);
}

static std::array<int, 1001> squares;

template <typename C>
void squares_init(C& c)
{
    for (int i = 0; i < c.size(); ++i)
        c[i] = i*i;
}

inline bool greater(const int& l, const int& r)
{
    return r < l;
}

inline bool is_smaller_bsearch(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    // return a1 + sqrt(b1) < a2 + sqrt(b2)

    // find floor(sqrt(b1)) - binary search withing 1000 elems
    auto it_b1 = std::lower_bound(crbegin(squares), crend(squares), b1, greater).base();

    // find floor(sqrt(b2)) - binary search withing 1000 elems
    auto it_b2 = std::lower_bound(crbegin(squares), crend(squares), b2, greater).base();

    return (a2 - a1) > (it_b1 - it_b2);
}

unsigned int sqrt32(unsigned long n)
{
    unsigned int c = 0x8000;
    unsigned int g = 0x8000;

    for (;;) {
        if (g*g > n) {
            g ^= c;
        }

        c >>= 1;

        if (c == 0) {
            return g;
        }

        g |= c;
    }
}

bool is_smaller_sqrt32(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    return a1 + sqrt32(b1) < a2 + sqrt32(b2);
}

int main()
{
    squares_init(squares);

    // now can use is_smaller
    assert(is_smaller_sqrt(1, 4, 3, 1) == is_smaller_sqrt32(1, 4, 3, 1));
    assert(is_smaller_sqrt(1, 2, 3, 3) == is_smaller_sqrt32(1, 2, 3, 3));
    assert(is_smaller_sqrt(1000, 4, 1001, 1) == is_smaller_sqrt32(1000, 4, 1001, 1));
    assert(is_smaller_sqrt(1, 300, 3, 200) == is_smaller_sqrt32(1, 300, 3, 200));
}

Ответ 4

Я не уверен, что алгебраические манипуляции в сочетании с целочисленной арифметикой обязательно приведут к быстрейшему решению. В этом случае вам понадобится много скалярных умножений (что не очень быстро), и/или предсказание ветвления может потерпеть неудачу, что может снизить производительность. Очевидно, что вам придется тестировать, чтобы увидеть, какое решение является самым быстрым в вашем конкретном случае.

Один из способов сделать sqrt немного быстрее - добавить -fno-math-errno в gcc или clang. В этом случае компилятору не нужно проверять наличие отрицательных входных данных. С icc это настройка по умолчанию.

Еще большее улучшение производительности возможно при использовании векторизованной инструкции sqrt sqrtpd вместо скалярной инструкции sqrt sqrtsd. Питер Кордес показал, что clang может автоматически векторизовать этот код, так что он генерирует этот sqrtpd.

Однако успешность автоматической векторизации в значительной степени зависит от правильных настроек компилятора и используемого компилятора (clang, gcc, icc и т.д.). С -march=nehalem или старше, clang не векторизируется.

Более надежные результаты векторизации возможны с помощью следующего встроенного кода, см. Ниже. Для переносимости мы предполагаем только поддержку SSE2, что является базовым уровнем x86-64.

/* gcc -m64 -O3 -fno-math-errno smaller.c                      */
/* Adding e.g. -march=nehalem or -march=skylake might further  */
/* improve the generated code                                  */
/* Note that SSE2 in guaranteed to exist with x86-64           */
#include<immintrin.h>
#include<math.h>
#include<stdio.h>
#include<stdint.h>

int is_smaller_v5(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    uint64_t a64    =  (((uint64_t)a2)<<32) | ((uint64_t)a1); /* Avoid too much port 5 pressure by combining 2 32 bit integers in one 64 bit integer */
    uint64_t b64    =  (((uint64_t)b2)<<32) | ((uint64_t)b1); 
    __m128i ax      = _mm_cvtsi64_si128(a64);         /* Move integer from gpr to xmm register                  */
    __m128i bx      = _mm_cvtsi64_si128(b64);         
    __m128d a       = _mm_cvtepi32_pd(ax);            /* Convert 2 integers to double                           */
    __m128d b       = _mm_cvtepi32_pd(bx);            /* We don't need _mm_cvtepu32_pd since a,b < 1e6          */
    __m128d sqrt_b  = _mm_sqrt_pd(b);                 /* Vectorized sqrt: compute 2 sqrt-s with 1 instruction   */
    __m128d sum     = _mm_add_pd(a, sqrt_b);
    __m128d sum_lo  = sum;                            /* a1 + sqrt(b1) in the lower 64 bits                     */
    __m128d sum_hi  =  _mm_unpackhi_pd(sum, sum);     /* a2 + sqrt(b2) in the lower 64 bits                     */
    return _mm_comilt_sd(sum_lo, sum_hi);
}


int is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);
}


int main(){
    unsigned a1; unsigned b1; unsigned a2; unsigned b2;
    a1 = 11; b1 = 10; a2 = 10; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 11; a2 = 10; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 10; a2 = 11; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 10; a2 = 10; b2 = 11;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));

    return 0;
}


Смотрите эту ссылку Godbolt для созданной сборки.

В простом тесте пропускной способности на Intel Skylake с параметрами компилятора gcc -m64 -O3 -fno-math-errno -march=nehalem я обнаружил пропускную способность is_smaller_v5() которая была в 2,6 раза лучше, чем исходная is_smaller(): 6,8 тактов процессора против 18 циклов процессора, включая накладные расходы на цикл. Однако в (слишком?) Простом тесте задержки, где входы a1, a2, b1, b2 зависели от результата предыдущего is_smaller(_v5), я не увидел никаких улучшений. (39,7 цикла против 39 циклов).

Ответ 5

Вы могли бы сделать что-то вроде

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    // a2 - a1 > sqrt(b1) - sqrt(b2)
    // this is true if b1 < b2
    if (b1 < b2) return true;

    unsigned t = a2 - a1;
    unsigned tt = t*t;    
    // if sqrt(b1) < t, the statement is true
    // thus b1 < t*t would do the trick
    // for 32bit, this might overflow
    // this is equivlanet to testing  b1/t < t (integer wise)
    // or might be equilvalent to testing (b1/t)/t == 0
    if (b1 < tt) return true;

    // the final test is sqrt(b1) < t + sqrt(b2)
    // we know that t + sqrt(b2) is positive
    // b1 < t*t + 2 t sqrt(b2) + b2
    // (b1 - b2 - t*t) < sqrt(4 * t*t * b2)
    t = b1 - b2 - tt;
    if (t <= 0 ) return true;

    // (b1 - b2 - t*t)^2 < 4 * t*t * b2
    // might overflow
    return t < sqrt(4*b2*tt);
}

Ответ 6

Возможно, не лучше, чем другие ответы, но использует другую идею (и массу предварительного анализа).

// Compute approximate integer square root of input in the range [0,10^6].
// Uses a piecewise linear approximation to sqrt() with bounded error in each piece:
//   0 <= x <= 784 : x/28
//   784 < x <= 7056 : 21 + x/112
//   7056 < x <= 28224 : 56 + x/252
//   28224 < x <= 78400 : 105 + x/448
//   78400 < x <= 176400 : 168 + x/700
//   176400 < x <= 345744 : 245 + x/1008
//   345744 < x <= 614656 : 336 + x/1372
//   614656 < x <= 1000000 : (784000+x)/1784
// It is the case that sqrt(x) - 7.9992711366390365897... <= pseudosqrt(x) <= sqrt(x).
unsigned pseudosqrt(unsigned x) {
    return 
        x <= 78400 ? 
            x <= 7056 ?
                x <= 764 ? x/28 : 21 + x/112
              : x <= 28224 ? 56 + x/252 : 105 + x/448
          : x <= 345744 ?
                x <= 176400 ? 168 + x/700 : 245 + x/1008
              : x <= 614656 ? 336 + x/1372 : (x+784000)/1784 ;
}

// known pre-conditions: a1 < a2, 
//                  0 <= b1 <= 1000000
//                  0 <= b2 <= 1000000
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
// Try three refinements:
// 1: a1 + sqrt(b1) <= a1 + 1000, 
//    so is a1 + 1000 < a2 ?  
//    Convert to a2 - a1 > 1000 .
// 2: a1 + sqrt(b1) <= a1 + pseudosqrt(b1) + 8 and
//    a2 + pseudosqrt(b2) <= a2 + sqrt(b2), 
//    so is  a1 + pseudosqrt(b1) + 8 < a2 + pseudosqrt(b2) ?
//    Convert to a2 - a1 > pseudosqrt(b1) - pseudosqrt(b2) + 8 .
// 3: Actually do the work.
//    Convert to a2 - a1 > sqrt(b1) - sqrt(b2)
// Use short circuit evaluation to stop when resolved.
    unsigned ad = a2 - a1;
    return (ad > 1000)
           || (ad > pseudosqrt(b1) - pseudosqrt(b2) + 8)
           || ((int) ad > (int)(sqrt(b1) - sqrt(b2)));
}

(У меня нет удобного компилятора, поэтому он, вероятно, содержит опечатку или два.)

Ответ 7

У меня нет всей информации. Поэтому мой ответ может быть неуместным.

Но если у вас есть FPU в системе X86 и компилятор, который понимает директиву asm, тогда вы можете использовать инструкцию ассемблера fsqrt. Это обычно довольно быстро. По крайней мере, намного быстрее, чем стандартная библиотека.

что-то вроде:

inline float sqrt_asm(float f)
{
    __asm {
        fld f
            fsqrt
    }
}

Может быть, это решение уже достаточно быстро?

Может быть, вы могли бы взглянуть на

Самый быстрый квадратный корень

Вавилонский метод также стоит попробовать. Но это зависит от остальной части вашего кода и требований. , ,