Time for some Funky SQL: Prefix Sum Calculation


This Stack Overflow question has yet again nerd-sniped me:

[finding the] maximum element in the array that would result from performing all M operations

Here’s the question by John that was looking for a Java solution:

With an array of N elements which are initialized to 0, we are given a sequence of M operations of the sort (p; q; r). The operation (p; q; r) signifies that the integer r should be added to all array elements A[p], A[p + 1], ..., A[q]. You are to output the maximum element in the array that would result from performing all M operations. There is a naive solution that simply performs all operations and then returns the maximum value, that takes O(MN) time. We are looking for a more efficient algorithm.

Interesting. Indeed, a naive solution would just perform all the operations as requested. A less naive solution would transform each operation (p; q; r) into two signals of the form (x; y): one signal (p; +r) where the range starts, and one signal (q + 1; -r) just after it ends. The running total of all signals, ordered by x, then yields the value of each array element. In other words, we could implement the solution I had presented trivially as such:

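import java.util.IntSummaryStatistics;
import java.util.stream.IntStream;
import java.util.stream.Stream;

// This is just a utility class to model the ops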
class Operation {
    final int p;
    final int q;
    final int r;

    Operation(int p, int q, int r) {
        this.p = p;
        this.q = q;
        this.r = r;
    }
}

// These are some example ops
Operation[] ops = {
    new Operation(4, 12, 2),
    new Operation(2,  8, 3),
    new Operation(6,  7, 1),
    new Operation(3,  7, 2)
};

// Here, we're calculating the min and max
// values for the combined values of p and q
IntSummaryStatistics stats = Stream
    .of(ops)
    .flatMapToInt(op -> IntStream.of(op.p, op.q))
    .summaryStatistics();

// Create an array for all the required elements using
// the min value as "offset". The range [min, max] is
// inclusive, hence the + 1
int[] array = new int[stats.getMax() - stats.getMin() + 1];

// Put +r and -r "signals" into the array for each op
for (Operation op : ops) {
    int lo = op.p     - stats.getMin();
    int hi = op.q + 1 - stats.getMin();

    if (lo >= 0)
        array[lo] = array[lo] + op.r;

    if (hi < array.length)
        array[hi] = array[hi] - op.r;
}

// Now, calculate the prefix sum sequentially in a
// trivial loop
int maxIndex = Integer.MIN_VALUE;
int maxR = Integer.MIN_VALUE;
int r = 0;

for (int i = 0; i < array.length; i++) {
    r = r + array[i];
    System.out.println((i + stats.getMin()) + ":" + r);

    if (r > maxR) {
        maxIndex = i + stats.getMin();
        maxR = r;
    }
}

System.out.println("---");
System.out.println(maxIndex + ":" + maxR);

The above program would print out:

2:3
3:5
4:7
5:7
6:8
7:8
8:5
9:2
10:2
11:2
12:2
---
6:8

So, the maximum value is generated at position 6, and the value is 8.
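As a sanity check, the naive O(MN) approach mentioned in the question can be sketched in a few lines and compared against the signal-based result. This is only an illustrative sketch: the class name NaiveRangeMax, the helper maxAfterOps, and the representation of each operation as an int[] {p, q, r} are assumptions made for this example, as is the choice of N = 12 for the example ops.

```java
import java.util.Arrays;

public class NaiveRangeMax {

    // Hypothetical helper: applies every (p, q, r) operation literally,
    // which takes O(M * N) time. Each row of ops is {p, q, r}.
    // Assumes 1-based positions with 1 <= p <= q <= n.
    static int maxAfterOps(int n, int[][] ops) {
        int[] a = new int[n + 1]; // index 0 is unused

        for (int[] op : ops)
            for (int i = op[0]; i <= op[1]; i++)
                a[i] += op[2];

        // Maximum over positions 1..n (0 if n is 0)
        return Arrays.stream(a, 1, n + 1).max().orElse(0);
    }

    public static void main(String[] args) {
        int[][] ops = { {4, 12, 2}, {2, 8, 3}, {6, 7, 1}, {3, 7, 2} };
        System.out.println(maxAfterOps(12, ops)); // prints 8
    }
}
```

For the example ops, this agrees with the maximum of 8 found above.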

Faster calculation in Java 8

This can be calculated faster using Java 8’s new Arrays.parallelPrefix() operation. Instead of the loop at the end, just write:

Arrays.parallelPrefix(array, Integer::sum);
System.out.println(
    Arrays.stream(array).parallel().max());

Which is awesome, as on multi-core hardware and for large arrays, it can run faster than the sequential O(M+N) solution. Note that max() returns an OptionalInt here. Read up about prefix sums here.
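To try this in isolation, here is a minimal, self-contained sketch. The class name ParallelPrefixMax and the helper maxOfPrefixSums are made up for this example, and the signal array is written out by hand for the example ops, assuming one entry per position from 2 to 12.

```java
import java.util.Arrays;

public class ParallelPrefixMax {

    // Hypothetical helper: replaces the +r / -r signals by their running
    // totals in place, then returns the largest running total.
    // Assumes a non-empty array.
    static int maxOfPrefixSums(int[] signals) {
        Arrays.parallelPrefix(signals, Integer::sum);
        return Arrays.stream(signals).parallel().max().getAsInt();
    }

    public static void main(String[] args) {
        // Signals of the example ops for positions 2, 3, ..., 12
        int[] signals = { 3, 2, 2, 0, 1, 0, -3, -3, 0, 0, 0 };
        System.out.println(maxOfPrefixSums(signals)); // prints 8
    }
}
```

Whether the parallel version actually beats the plain loop depends on the array size and the number of cores, so it is worth measuring before relying on it.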

Now show me the promised SQL code

In SQL, the naive sequential and linear complexity solution can easily be re-implemented, and I’m showing a solution for PostgreSQL.

How can we do it? We’re using a couple of features here. First off, we’re using common table expressions (also known as the WITH clause). We’re using these to declare table variables. The first variable is the op table, which contains our operation instructions, like in Java:

WITH 
  op (p, q, r) AS (
    VALUES
      (4, 12, 2),
      (2,  8, 3),
      (6,  7, 1),
      (3,  7, 2)
  ),
  ...

This is trivial. We’re essentially just generating a couple of example values.

The second table variable is the signal table, where we use the previously described optimisation of putting a +r signal at all p positions, and a -r signal at all q + 1 positions:

WITH 
  ...,
  signal(x, r) AS (
    SELECT p, r
    FROM op
    UNION ALL
    SELECT q + 1, -r
    FROM op
  )
...

When you run

SELECT * FROM signal ORDER BY x

you would simply get:

x   r
------
2   3
3   2
4   2
6   1
8  -2
8  -1
9  -3
13 -2

All we need to do now is calculate a running total (which is essentially the same as a prefix sum) as follows:

SELECT x, SUM(r) OVER (ORDER BY x)
FROM signal 
ORDER BY x

This yields:

x   sum
-------
2   3
3   5
4   7
6   8
8   5
8   5
9   2
13  0
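The same sparse idea can be sketched in Java for comparison: instead of one array slot per position, keep only the positions that carry a signal, sorted by x, and walk them with a running total. This is a rough sketch, not a translation of what the database does internally. The class name SparseSignals, the helper maxRunningTotal, and the int[] {p, q, r} representation are assumptions for illustration. Signals sharing the same x are merged into one entry, which is why x = 8 appears only once here.

```java
import java.util.Map;
import java.util.TreeMap;

public class SparseSignals {

    // Hypothetical helper mirroring the signal table. Each row of ops is
    // {p, q, r}. Returns {x, s}: the first position with the largest
    // running total, and that total. Assumes at least one operation.
    static int[] maxRunningTotal(int[][] ops) {
        // Sorted map from position x to the summed signal at x
        TreeMap<Integer, Integer> signal = new TreeMap<>();

        for (int[] op : ops) {
            signal.merge(op[0], op[2], Integer::sum);      // +r at p
            signal.merge(op[1] + 1, -op[2], Integer::sum); // -r at q + 1
        }

        int bestX = 0;
        int bestS = Integer.MIN_VALUE;
        int s = 0;

        // Running total in ascending order of x
        for (Map.Entry<Integer, Integer> e : signal.entrySet()) {
            s += e.getValue();

            if (s > bestS) {
                bestS = s;
                bestX = e.getKey();
            }
        }

        return new int[] { bestX, bestS };
    }

    public static void main(String[] args) {
        int[][] ops = { {4, 12, 2}, {2, 8, 3}, {6, 7, 1}, {3, 7, 2} };
        int[] best = maxRunningTotal(ops);
        System.out.println(best[0] + ":" + best[1]); // prints 6:8
    }
}
```

Because of the sorted map, this runs in O(M log M) time and is independent of N.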

Now just find the max value of this running total, and we’re all set. We’ll take a shortcut by using ORDER BY and LIMIT:

SELECT x, SUM(r) OVER (ORDER BY x) AS s
FROM signal 
ORDER BY s DESC
LIMIT 1

And we’re back with:

x   s
------
6   8

Perfect! Here’s the full query:

WITH 
  op (p, q, r) AS (
    VALUES
      (4, 12, 2),
      (2,  8, 3),
      (6,  7, 1),
      (3,  7, 2)
  ),
  signal(x, r) AS (
    SELECT p, r
    FROM op
    UNION ALL
    SELECT q + 1, -r
    FROM op
  )
SELECT x, SUM(r) OVER (ORDER BY x) AS s
FROM signal 
ORDER BY s DESC
LIMIT 1

Can you beat the conciseness of this SQL solution? I bet you can’t. Challengers shall write alternatives in the comment section.

Thrilled about the SQL here? Read about how to calculate a subset sum in Oracle SQL.
