Алгоритм поиска двумерных пиков в O (n) худшем случае?

Я делал этот курс по алгоритмам MIT. В первой лекции профессор представляет следующую проблему: -

Пик в 2D-массиве - это значение, такое, что все его 4 соседства меньше или равны ему, т.е. для

a[i][j] - локальный максимум,

a[i+1][j] <= a[i][j] 
&& a[i-1][j] <= a[i][j]
&& a[i][j+1] <= a[i][j]
&& a[i+1][j-1] <= a[i][j]

Теперь, учитывая массив NxN 2D, найдите пик в массиве.

Этот вопрос можно легко решить в O(N^2) раз, итерируя все элементы и вернув пик.

Однако он может быть оптимизирован для решения в O(NlogN) времени с помощью решения divide и conquer, как описано здесь.

Но они сказали, что существует алгоритм времени O(N), который решает эту проблему. Пожалуйста, предложите, как решить эту проблему в O(N).

PS (для тех, кто знает python). Сотрудники курса объяснили подход здесь (проблема 1-5. Пиковое определение доказательств) и также предоставил некоторый код python в своих наборах проблем. Но описанный подход совершенно неочевиден и очень трудно расшифровать. Код python в равной степени запутан. Поэтому я скопировал основную часть кода ниже для тех, кто знает python и может определить, какой алгоритм используется из кода.

def algorithm4(problem, bestSeen = None, rowSplit = True, trace = None):
    # if it empty, we're done 
    if problem.numRow <= 0 or problem.numCol <= 0:
        return None

    subproblems = []
    divider = []

    if rowSplit:
        # the recursive subproblem will involve half the number of rows
        mid = problem.numRow // 2

        # information about the two subproblems
        (subStartR1, subNumR1) = (0, mid)
        (subStartR2, subNumR2) = (mid + 1, problem.numRow - (mid + 1))
        (subStartC, subNumC) = (0, problem.numCol)

        subproblems.append((subStartR1, subStartC, subNumR1, subNumC))
        subproblems.append((subStartR2, subStartC, subNumR2, subNumC))

        # get a list of all locations in the dividing column
        divider = crossProduct([mid], range(problem.numCol))
    else:
        # the recursive subproblem will involve half the number of columns
        mid = problem.numCol // 2

        # information about the two subproblems
        (subStartR, subNumR) = (0, problem.numRow)
        (subStartC1, subNumC1) = (0, mid)
        (subStartC2, subNumC2) = (mid + 1, problem.numCol - (mid + 1))

        subproblems.append((subStartR, subStartC1, subNumR, subNumC1))
        subproblems.append((subStartR, subStartC2, subNumR, subNumC2))

        # get a list of all locations in the dividing column
        divider = crossProduct(range(problem.numRow), [mid])

    # find the maximum in the dividing row or column
    bestLoc = problem.getMaximum(divider, trace)
    neighbor = problem.getBetterNeighbor(bestLoc, trace)

    # update the best we've seen so far based on this new maximum
    if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
        bestSeen = neighbor
        if not trace is None: trace.setBestSeen(bestSeen)

    # return when we know we've found a peak
    if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
        if not trace is None: trace.foundPeak(bestLoc)
        return bestLoc

    # figure out which subproblem contains the largest number we've seen so
    # far, and recurse, alternating between splitting on rows and splitting
    # on columns
    sub = problem.getSubproblemContaining(subproblems, bestSeen)
    newBest = sub.getLocationInSelf(problem, bestSeen)
    if not trace is None: trace.setProblemDimensions(sub)
    result = algorithm4(sub, newBest, not rowSplit, trace)
    return problem.getLocationInSelf(sub, result)

#Helper Method
def crossProduct(list1, list2):
    """
    Returns all pairs with one item from the first list and one item from 
    the second list.  (Cartesian product of the two lists.)

    The code is equivalent to the following list comprehension:
        return [(a, b) for a in list1 for b in list2]
    but for easier reading and analysis, we have included more explicit code.
    """

    answer = []
    for a in list1:
        for b in list2:
            answer.append ((a, b))
    return answer

Ответ 1

  • Предположим, что ширина массива больше высоты, иначе мы разделимся в другом направлении.
  • Разделите массив на три части: центральный столбец, слева и справа.
  • Пройдите через центральный столбец и два соседних столбца и ищите максимум.
    • Если это в центральном столбце - это наш пик
    • Если это в левой части, запустите этот алгоритм на подмассиве left_side + central_column
    • Если это в правой части, запустите этот алгоритм на подмассиве right_side + central_column

Почему это работает:

В случаях, когда максимальный элемент находится в центральном столбце, очевидно. Если это не так, мы можем перейти от этого максимума к возрастающим элементам и определенно не пересечь центральную строку, поэтому пик определенно будет существовать в соответствующей половине.

Почему это O (n):

шаг # 3 принимает меньше или равно max_dimension итераций и max_dimension по крайней мере половин на каждом из двух шагов алгоритма. Это дает n+n/2+n/4+..., который равен O(n). Важная деталь: мы разделяем максимальное направление. Для квадратных массивов это означает, что разделенные направления будут чередоваться. Это отличие от последней попытки PDF, с которой вы связаны.

Примечание. Я не уверен, что он точно соответствует алгоритму в коде, который вы дали, это может быть или не быть другим.

Ответ 2

Вот рабочий код Java, который реализует алгоритм @maxim1000. Следующий код находит пик в 2D-массиве в линейном времени.

import java.util.*;

class Ideone{
    public static void main (String[] args) throws java.lang.Exception{
        new Ideone().run();
    }
    int N , M ;

    void run(){
        N = 1000;
        M = 100;

        // arr is a random NxM array
        int[][] arr = randomArray();
        long start = System.currentTimeMillis();
//      for(int i=0; i<N; i++){   // TO print the array. 
//          System. out.println(Arrays.toString(arr[i]));
//      }
        System.out.println(findPeakLinearTime(arr));
        long end = System.currentTimeMillis();
        System.out.println("time taken : " + (end-start));
    }

    int findPeakLinearTime(int[][] arr){
        int rows = arr.length;
        int cols = arr[0].length;
        return kthLinearColumn(arr, 0, cols-1, 0, rows-1);
    }

    // helper function that splits on the middle Column
    int kthLinearColumn(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loCol==hiCol){
            int max = arr[loRow][loCol];
            int foundRow = loRow;
            for(int row = loRow; row<=hiRow; row++){
                if(max < arr[row][loCol]){
                    max = arr[row][loCol];
                    foundRow = row;
                }
            }
            if(!correctPeak(arr, foundRow, loCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return max;
        }
        int midCol = (loCol+hiCol)/2;
        int max = arr[loRow][loCol];
        for(int row=loRow; row<=hiRow; row++){
            max = Math.max(max, arr[row][midCol]);
        }
        boolean centralMax = true;
        boolean rightMax = false;
        boolean leftMax  = false;

        if(midCol-1 >= 0){
            for(int row = loRow; row<=hiRow; row++){
                if(arr[row][midCol-1] > max){
                    max = arr[row][midCol-1];
                    centralMax = false;
                    leftMax = true;
                }
            }
        }

        if(midCol+1 < M){
            for(int row=loRow; row<=hiRow; row++){
                if(arr[row][midCol+1] > max){
                    max = arr[row][midCol+1];
                    centralMax = false;
                    leftMax = false;
                    rightMax = true;
                }
            }
        }

        if(centralMax) return max;
        if(rightMax)  return kthLinearRow(arr, midCol+1, hiCol, loRow, hiRow);
        if(leftMax)   return kthLinearRow(arr, loCol, midCol-1, loRow, hiRow);
        throw new RuntimeException("INCORRECT CODE");
    }

    // helper function that splits on the middle 
    int kthLinearRow(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loRow==hiRow){
            int ans = arr[loCol][loRow];
            int foundCol = loCol;
            for(int col=loCol; col<=hiCol; col++){
                if(arr[loRow][col] > ans){
                    ans = arr[loRow][col];
                    foundCol = col;
                }
            }
            if(!correctPeak(arr, loRow, foundCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return ans;
        }
        boolean centralMax = true;
        boolean upperMax = false;
        boolean lowerMax = false;

        int midRow = (loRow+hiRow)/2;
        int max = arr[midRow][loCol];

        for(int col=loCol; col<=hiCol; col++){
            max = Math.max(max, arr[midRow][col]);
        }

        if(midRow-1>=0){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow-1][col] > max){
                    max = arr[midRow-1][col];
                    upperMax = true;
                    centralMax = false;
                }
            }
        }

        if(midRow+1<N){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow+1][col] > max){
                    max = arr[midRow+1][col];
                    lowerMax = true;
                    centralMax = false;
                    upperMax   = false;
                }
            }
        }

        if(centralMax) return max;
        if(lowerMax)   return kthLinearColumn(arr, loCol, hiCol, midRow+1, hiRow);
        if(upperMax)   return kthLinearColumn(arr, loCol, hiCol, loRow, midRow-1);
        throw new RuntimeException("Incorrect code");
    }

    int[][] randomArray(){
        int[][] arr = new int[N][M];
        for(int i=0; i<N; i++)
            for(int j=0; j<M; j++)
                arr[i][j] = (int)(Math.random()*1000000000);
        return arr;
    }

    boolean correctPeak(int[][] arr, int row, int col){//Function that checks if arr[row][col] is a peak or not
        if(row-1>=0 && arr[row-1][col]>arr[row][col])  return false;
        if(row+1<N && arr[row+1][col]>arr[row][col])   return false;
        if(col-1>=0 && arr[row][col-1]>arr[row][col])  return false;
        if(col+1<M && arr[row][col+1]>arr[row][col])   return false;
        return true;
    }
}

Ответ 3

Чтобы увидеть это (n):

Шаг расчета на картинке

Чтобы увидеть реализацию алгоритма:

1) начните с 1a) или 1b)

1a) установить левую половину, разделитель, правую половину.

1b) установить верхнюю половину, разделитель, нижнюю половину.

2) Найти глобальный максимум на делителе. [theta n]

3) Найти значения своего соседа. И запишите самый большой узел, который когда-либо посещался, как лучший узел. [тета 1]

# update the best we've seen so far based on this new maximum
if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
    bestSeen = neighbor
    if not trace is None: trace.setBestSeen(bestSeen)

4) проверьте, больше ли глобальный максимум, чем bestSeen и его сосед. [тета 1]

//Шаг 4 - главный ключ того, почему этот алгоритм работает

# return when we know we've found a peak
if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
    if not trace is None: trace.foundPeak(bestLoc)
    return bestLoc

5) Если 4) Истина, вернуть глобальный максимум в виде двумерного пика.

Иначе, если на этот раз 1а), выберите половину BestSeen, вернитесь к шагу 1b)

Иначе, выберите половину BestSeen, вернитесь к шагу 1a)


Чтобы наглядно увидеть, почему этот алгоритм работает, это все равно что захватить сторону наибольшего значения, продолжать сокращать границы и в конечном итоге получить значение BestSeen.

# Визуализированное моделирование

Round1

round2

round3

round4

round5

round6

наконец

Для этой матрицы 10 * 10 мы использовали только 6 шагов для поиска двумерного пика, и это вполне убедительно, что это действительно тета н


Сокол