Wednesday, July 6, 2011

Optimizing the product of prime powers

Remember how we generated the factorial from the product of its prime factors?
Let's try and optimize the product.
The first trick is to use Karatsuba multiplication, because multiplication is the major contributor to execution time.
The Squeak variant takes care to apply Karatsuba only to well balanced operands, and replaces the recursion with an iteration when the operands are not well balanced:

LargePositiveInteger>>karatsubaTimes: anInteger
    "Use the Karatsuba algorithm to perform the multiplication when both operands are large enough"
   
    | half xHigh xLow yHigh yLow low high mid xLen yLen |
    (anInteger isLargeEnoughForKaratsuba
        and: [self isLargeEnoughForKaratsuba]) ifFalse: [^self timesInteger: anInteger].
   
    "Check if length ratio is more than 2, and engage a loop
    to operate on integers with well balanced lengths.
    Note that we only add overhead at this level,
    but we hope to gain in lower level recursion"
    (xLen := self digitLength) >= (yLen := anInteger digitLength)
        ifTrue: [(half := xLen bitShift: -1) >= yLen
            ifTrue: [^(0 to: xLen by: yLen) detectSum: [:yShift |
                ((self copyDigitsFrom: yShift + 1 to: yShift + yLen)
                    karatsubaTimes: anInteger)
                        bitShift: 8 * yShift]]]
        ifFalse: [(half := yLen bitShift: -1) >= xLen
            ifTrue: [^(0 to: yLen by: xLen) detectSum: [:xShift |
                (self karatsubaTimes:
                    (anInteger copyDigitsFrom: xShift + 1 to: xShift + xLen))
                        bitShift: 8 * xShift]]].
   
    "At this point, lengths are well balanced, divide each integer in two halves"
    xHigh := self bitShift: -8 * half.
    xLow := self lowestNDigits: half.
    yHigh := anInteger bitShift: -8 * half.
    yLow := anInteger lowestNDigits: half.
   
    "Karatsuba trick: perform with 3 multiplications instead of 4"
    low := xLow karatsubaTimes: yLow.
    high := xHigh karatsubaTimes: yHigh.
    mid := (xHigh + xLow karatsubaTimes: yHigh + yLow) - (low + high).
   
    "Sum the parts of decomposition"
    ^low + (mid bitShift: 8*half) + (high bitShift: 16*half)
  
With a bit of tuning for the Cog VM, the threshold is set to 160 digits (bytes):

LargePositiveInteger>>isLargeEnoughForKaratsuba
        ^self digitLength >= 160


SmallInteger is not large enough for Karatsuba and falls back to the regular multiplication *; we omit the code here.
copyDigitsFrom:to: and lowestNDigits: just use a primitive to split a LargePositiveInteger; they are omitted too.
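
For readers who don't speak Smalltalk, here is a rough Python sketch of the same idea. It is only an illustration, not the Squeak code: it splits on bits rather than on 8-bit digits, the names karatsuba and threshold_bits are made up for the example, and Python's own * stands in for the omitted fallback multiplication.

```python
def karatsuba(x, y, threshold_bits=1280):
    """Multiply two non-negative ints, using Karatsuba on balanced operands.

    threshold_bits is assumed to stay well above a handful of bits.
    """
    # Small operands: the built-in product plays the role of the fallback.
    # 1280 bits echoes the 160-byte threshold above; it is not tuned for Python.
    if x.bit_length() < threshold_bits or y.bit_length() < threshold_bits:
        return x * y
    if x.bit_length() < y.bit_length():
        x, y = y, x  # keep x as the longer operand
    y_bits = y.bit_length()
    if x.bit_length() >= 2 * y_bits:
        # Unbalanced lengths: slice x into chunks as long as y and loop.
        chunk_mask = (1 << y_bits) - 1
        total, shift = 0, 0
        while x >> shift:
            chunk = (x >> shift) & chunk_mask
            total += karatsuba(chunk, y, threshold_bits) << shift
            shift += y_bits
        return total
    # Balanced lengths: split both operands at the same bit position.
    half = x.bit_length() // 2
    low_mask = (1 << half) - 1
    x_high, x_low = x >> half, x & low_mask
    y_high, y_low = y >> half, y & low_mask
    # Three recursive products instead of four.
    low = karatsuba(x_low, y_low, threshold_bits)
    high = karatsuba(x_high, y_high, threshold_bits)
    mid = karatsuba(x_high + x_low, y_high + y_low, threshold_bits) - (low + high)
    return low + (mid << half) + (high << (2 * half))
```

Passing a small threshold, for instance karatsuba(a, b, 64), forces the recursive path even on moderately sized integers, which is handy for checking the result against a * b.
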
We had better optimize squared too, which is a simpler case:

LargePositiveInteger>>squared
    "Use a divide and conquer algorithm to perform the squaring when self is large enough"
   
    | half xHigh xLow low high mid |
    self isLargeEnoughForKaratsuba ifFalse: [^self * self].
   
    "Divide digits in two halves"
    half := self digitLength bitShift: -1.
    xHigh := self bitShift: -8 * half.
    xLow := self lowestNDigits: half.
   
    "Use Karatsuba"
    low := xLow squared.
    high := xHigh squared.
    mid := xLow karatsubaTimes: xHigh.
   
    "Sum the parts of decomposition"
    ^low + (mid bitShift: 8*half+1) + (high bitShift: 16*half)
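
The extra +1 in the shift of mid is what doubles the cross term: Smalltalk evaluates 8*half+1 from left to right, so mid is shifted by one more bit than 8*half. A hedged Python sketch of the same decomposition, splitting on bits instead of bytes (karatsuba_square and threshold_bits are names invented for this example, and * replaces karatsubaTimes:):

```python
def karatsuba_square(x, threshold_bits=1280):
    """Square a non-negative int with x^2 = low + 2*mid*2^half + high*2^(2*half)."""
    if x.bit_length() < threshold_bits:
        return x * x
    half = x.bit_length() // 2
    x_high, x_low = x >> half, x & ((1 << half) - 1)
    low = karatsuba_square(x_low, threshold_bits)
    high = karatsuba_square(x_high, threshold_bits)
    mid = x_low * x_high  # the post uses karatsubaTimes: here
    # Shifting by half + 1 multiplies mid by 2 * 2^half.
    return low + (mid << (half + 1)) + (high << (2 * half))
```

Only two recursive squarings and one product are needed, because the two cross terms of the general case are equal.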

Then use squared in raisedToInteger:

Number>>raisedToInteger: anInteger
    | bitProbe result |

    anInteger negative ifTrue: [^(self raisedToInteger: anInteger negated) reciprocal].
    bitProbe := 1 bitShift: anInteger highBit - 1.
    result := self class one.
     [
        (anInteger bitAnd: bitProbe) = 0 ifFalse: [result := result * self].
        bitProbe := bitProbe bitShift: -1.
        bitProbe > 0 ]
    whileTrue: [result := result squared].
   
    ^result
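
The loop scans the bits of the exponent from the highest to the lowest, squaring between two bits. Here is a small Python sketch of that scan, restricted to non-negative integer exponents (the reciprocal branch is left out, and raised_to is a hypothetical name):

```python
def raised_to(base, exponent):
    """Left-to-right binary exponentiation for exponent >= 0."""
    if exponent < 0:
        raise ValueError("this sketch only handles non-negative exponents")
    # Highest set bit of the exponent, or 0 when the exponent is 0.
    bit_probe = (1 << exponent.bit_length()) >> 1
    result = 1
    while True:
        if exponent & bit_probe:
            result *= base
        bit_probe >>= 1
        if bit_probe == 0:
            return result
        result = result * result  # where squared is called in the post
```

For an exponent of 13 (binary 1101) the successive values of result are base, base^3, base^6 and base^13, so every step except the multiplications by base is a squaring.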

We have a reasonable implementation of Karatsuba, but we cannot just plug it into a naive product:


Integer>>karatsubaPrimeFactorFactorial

    "Recompose the factorial from the prime factors, knowing the power of each prime."
    ^((Integer primesUpTo: self + 1) collect: [:p |
        p raisedToInteger: self - (self sumDigitsInBase: p) // (p - 1)]) karatsubaProduct

Collection>>karatsubaProduct
    ^self inject: 1 into: [:product :element | product karatsubaTimes: element]

Karatsuba is just a drag in this case, because we always multiply a large integer by a small one. The only advantage comes from using squared in raisedToInteger:, and instead of 19136ms, we get:

Smalltalk garbageCollect.
Time millisecondsToRun: [1 to: 3000 do: [:x | x karatsubaPrimeFactorFactorial]].
->  18408 18503
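
As a reminder of what is being computed, the exponent of a prime p in n! is (n - s) // (p - 1), where s is the sum of the digits of n written in base p (Legendre's formula). A minimal Python sketch of the naive recomposition follows; the helper names are invented for the example, and primes_up_to includes its bound, which may differ from Integer primesUpTo:.

```python
def primes_up_to(n):
    """Primes <= n, by a plain sieve of Eratosthenes."""
    is_prime = [True] * (n + 1)
    primes = []
    for p in range(2, n + 1):
        if is_prime[p]:
            primes.append(p)
            for multiple in range(p * p, n + 1, p):
                is_prime[multiple] = False
    return primes


def digit_sum(n, base):
    """Sum of the digits of n written in the given base."""
    total = 0
    while n:
        n, digit = divmod(n, base)
        total += digit
    return total


def prime_exponent_in_factorial(n, p):
    """Exponent of the prime p in n!, by Legendre's formula."""
    return (n - digit_sum(n, p)) // (p - 1)


def prime_factor_factorial(n):
    """n! rebuilt as a left-to-right product of prime powers."""
    product = 1
    for p in primes_up_to(n):
        product *= p ** prime_exponent_in_factorial(n, p)
    return product
```

For n = 10 the exponents of 2, 3, 5 and 7 are 8, 4, 2 and 1. The running product grows large while each new factor stays small, which is the unbalanced situation described above.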

So the second trick is essential: evaluate the product of the terms in a divide and conquer fashion, recursively splitting the numbers to multiply into two groups:

SequenceableCollection>>karatsubaProduct
    "Compute the product of self elements, using Karatsuba multiplication"
    self isEmpty ifTrue: [^1].
    ^self karatsubaProductFrom: 1 to: self size by: 1

SequenceableCollection>>karatsubaProductFrom: startIndex to: stopIndex by: inc
    | nextInc nextIndex |
    (nextIndex := startIndex + inc) > stopIndex ifTrue: [^self at: startIndex].
    nextInc := inc * 2.
    ^(self karatsubaProductFrom: startIndex to: stopIndex by: nextInc) karatsubaTimes:
      (self karatsubaProductFrom: nextIndex to: stopIndex by: nextInc)


Smalltalk garbageCollect.
Time millisecondsToRun: [1 to: 3000 do: [:x | x karatsubaPrimeFactorFactorial]].
-> 12627 12539
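
The split by stride can be hard to picture at first, so here is a hedged Python transcription with 0-based indices (balanced_product and product_by_stride are made-up names, and * stands in for karatsubaTimes:):

```python
def product_by_stride(items, start=0, step=1):
    """Product of items[start], items[start + step], items[start + 2*step], ..."""
    next_index = start + step
    if next_index >= len(items):
        return items[start]  # a single element is left in this group
    # Doubling the stride splits the group into two interleaved halves.
    return (product_by_stride(items, start, 2 * step)
            * product_by_stride(items, next_index, 2 * step))


def balanced_product(items):
    """Product of a sequence, 1 for an empty one."""
    return product_by_stride(items) if len(items) else 1
```

Since the two halves are interleaved rather than contiguous, each gets a similar mix of small and large elements of a sorted input, so the two sub-products have comparable lengths and the top-level multiplications are the balanced ones that Karatsuba likes.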

Better, but there is still a gap to the prime swing variant, which performs the job in less than 7000ms. Let's adapt our strategy and use a third trick: square as many LargeIntegers as possible. One idea is to group the terms having the same power, multiply them together, then raise the product to that power. But if term p is raised to the power of 7, we have to evaluate (p squared squared) * p squared * p. So we store p in three collections:
  • the collection of terms raisedTo: 1;
  • the collection of terms raisedTo: 2;
  • the collection of terms raisedTo: 4.
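
A tiny Python sketch may help to see where each prime lands. It assumes the exponents are already known, and group_by_rank is a hypothetical helper, not part of the Squeak code:

```python
def group_by_rank(primes, powers):
    """Map each power of two (the rank) to the primes whose exponent has that bit set."""
    groups = {}
    rank = 1
    while rank <= max(powers, default=0):
        groups[rank] = [p for p, e in zip(primes, powers) if e & rank]
        rank <<= 1
    return groups
```

With the odd primes of 10!, group_by_rank([3, 5, 7], [4, 2, 1]) gives {1: [7], 2: [5], 4: [3]}, while a prime with exponent 7 appears in the groups 1, 2 and 4.
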
The last point is that we won't bother multiplying by the power of 2, because the fastest variants don't either: a single bitShift: at the end does the job. We come up with this algorithm:

Integer>>optimizedPrimeFactorFactorial
    "This is the optimized version of primeFactorFactorial"

    | powers primes |
    self <= 1 ifTrue: [^1].
    primes := (Integer primesUpTo: self + 1) allButFirst.
    powers := primes collect: [:e | self - (self sumDigitsInBase: e) // (e - 1)].
    ^(primes karatsubaProductPowers: powers) bitShift: self - self bitCount

SequenceableCollection>>karatsubaProductPowers: powers
    "Compute the product of self elements, each raised to the corresponding powers"
    | lastPower rank terms |
    (lastPower := powers size) = 0 ifTrue: [^1].
    rank := 1.
    terms := OrderedCollection new: (powers at: 1) highBit.
    [rank <= (powers at: 1)]
        whileTrue:
            [ [ rank > (powers at: lastPower) ] whileTrue: [lastPower := lastPower - 1].
            terms add: ((((1 to: lastPower)
                select: [:e | ((powers at: e) bitAnd: rank) ~= 0])
                collect: [:i | self at: i])
                    karatsubaProduct raisedToInteger: rank).
            rank := rank bitShift: 1].
    ^terms karatsubaProduct

Smalltalk garbageCollect.
Time millisecondsToRun: [1 to: 3000 do: [:x | x optimizedPrimeFactorFactorial]].
->  8436 8433
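
The final bitShift: by self - self bitCount works because the exponent of 2 in n! is n minus the number of 1 bits of n, which is Legendre's formula in base 2. A quick Python check of that identity, with helper names made up for the occasion:

```python
def exponent_of_two_in_factorial(n):
    """n minus the number of 1 bits of n."""
    return n - bin(n).count("1")


def plain_factorial(n):
    """Reference factorial by repeated multiplication."""
    result = 1
    for k in range(2, n + 1):
        result *= k
    return result


def trailing_zero_bits(m):
    """Number of times 2 divides the positive integer m."""
    return (m & -m).bit_length() - 1
```

For example 10 is 1010 in binary, so the exponent is 10 - 2 = 8, and 10! = 3628800 is indeed 14175 * 2^8 with 14175 odd.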

The optimized prime factor algorithm is about 2.25 times faster than the naive prime factor algorithm on small integers up to 3000 (19136ms versus about 8435ms).
Let's instrument the algorithm to check whether the terms are well balanced. In karatsubaProductPowers:, just before the final product, we add the line
    (terms collect: [:e | e highBit]) inspect.
and evaluate 50000 optimizedPrimeFactorFactorial:
 -> an OrderedCollection(49775 49532 49490 49130 49165 48285 49407 48651 41929 40562 36299 26336 38838 55004 25969)

It's remarkable that we decomposed the factorial of 50000 into terms, most of which have about 50000 bits!
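
To try the same balance check outside Squeak, a Python sketch along these lines could be used. It is an approximation of the algorithm above under a few assumptions: the exponents come from Legendre's repeated-division sum rather than sumDigitsInBase:, plain * replaces Karatsuba, the lastPower shortcut is dropped, and all the names are invented.

```python
def odd_primes_up_to(n):
    """Odd primes <= n, by a plain sieve of Eratosthenes."""
    is_prime = [True] * (n + 1)
    primes = []
    for p in range(3, n + 1, 2):
        if is_prime[p]:
            primes.append(p)
            for multiple in range(p * p, n + 1, 2 * p):
                is_prime[multiple] = False
    return primes


def legendre_exponent(n, p):
    """Exponent of the prime p in n!: n//p + n//p^2 + ..."""
    exponent = 0
    while n:
        n //= p
        exponent += n
    return exponent


def rank_terms(n):
    """One term per rank 1, 2, 4, ...: (product of the primes with that bit set) ** rank."""
    primes = odd_primes_up_to(n)
    powers = [legendre_exponent(n, p) for p in primes]
    terms = []
    rank = 1
    # powers[0] belongs to 3, the smallest odd prime, so it is the largest exponent.
    while powers and rank <= powers[0]:
        group = 1
        for p, e in zip(primes, powers):
            if e & rank:
                group *= p
        terms.append(group ** rank)
        rank <<= 1
    return terms


def optimized_factorial(n):
    """n! as the product of the rank terms, shifted by the power of 2."""
    result = 1
    for term in rank_terms(n):
        result *= term
    return result << (n - bin(n).count("1"))
```

Evaluating [t.bit_length() for t in rank_terms(50000)] gives one bit length per term, which can be compared with the highBit values inspected above.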

The score on a larger integer (times in milliseconds):

[50000 factorial] timeToRun. -> 10288
[50000 primeFactorFactorial] timeToRun. -> 8853
[50000 optimizedPrimeFactorFactorial] timeToRun. -> 1263
[50000 primeSwingFactorial] timeToRun. -> 1004

Now a challenge: write optimizedPrimeFactorFactorial with Xtreams instead of SequenceableCollection!
