Python の cProfile や line_profile でプロファイルを取って遅いコードを改善する

python

Python でヒープソートを実装したのだが、ビルトインの heapq を用いたものと比べて遥かに遅い。

class Heap:
    def __init__(self, values, compare):
        self.heap = values
        self.compare = compare
        for i in reversed(range(len(self.heap) // 2)):
            self.__down(i)

    def push(self, value):
        self.heap.append(value)
        self.__up(len(self.heap) - 1)

    def delete(self, idx):
        if idx >= len(self.heap):
            return
        self.heap[idx], self.heap[len(self.heap) - 1] = \
            self.heap[len(self.heap) - 1], self.heap[idx]
        self.heap.pop()
        self.__down(idx)
        self.__up(idx)

    def pop(self):
        if len(self.heap) == 0:
            return None
        ret = self.heap[0]
        self.delete(0)
        return ret

    def __parent_idx(self, idx):
        return (idx - 1) // 2 if idx > 0 else - 1

    def __left_child_idx(self, idx):
        return (idx * 2) + 1 if (idx * 2) + 1 < len(self.heap) else - 1

    def __right_child_idx(self, idx):
        return (idx * 2) + 2 if (idx * 2) + 2 < len(self.heap) else - 1

    def __up(self, idx):
        while idx > 0:
            if self.compare(self.heap[self.__parent_idx(idx)], self.heap[idx]):
                return
            self.heap[idx], self.heap[self.__parent_idx(idx)] = \
                self.heap[self.__parent_idx(idx)], self.heap[idx]
            idx = self.__parent_idx(idx)

    def __down(self, idx):
        while idx < len(self.heap):
            minIdx = self.__left_child_idx(idx)
            if minIdx == -1:
                return
            if self.__right_child_idx(idx) != -1 and \
                self.compare(
                    self.heap[self.__right_child_idx(idx)],
                    self.heap[minIdx]
                ):
                    minIdx = self.__right_child_idx(idx)

            if self.compare(self.heap[idx], self.heap[minIdx]):
                return

            self.heap[idx], self.heap[minIdx] = self.heap[minIdx], self.heap[idx]
            idx = minIdx

def my_heap_sort(data):
    minheap = Heap([], lambda a, b: a < b)
    for d in data:
        minheap.push(d)
    ret = []
    for i in range(len(data)):
        ret.append(minheap.pop())
    return ret

def builtin_heapq_sort(data):
    heap = []
    for d in data:
        heapq.heappush(heap, d)
    ret = []
    for i in range(len(heap)):
        ret.append(heapq.heappop(heap))
    return ret

まずは cProifile でプロファイルを取ってみる。

import cProfile
import random

...

if __name__ == "__main__":
    x = list(range(10000))
    random.Random(5).shuffle(x)
    assert my_heap_sort(x) == builtin_heapq_sort(x)
    cProfile.run('my_heap_sort(x)')
    cProfile.run('builtin_heapq_sort(x)')

次のような出力が得られた。

$ python --version             
Python 3.11.4

$ python main.py
1330480 function calls in 0.258 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.258    0.258 <string>:1(<module>)
    10000    0.009    0.000    0.227    0.000 main.py:12(delete)
        1    0.000    0.000    0.000    0.000 main.py:2(__init__)
    10000    0.003    0.000    0.230    0.000 main.py:21(pop)
    61063    0.005    0.000    0.005    0.000 main.py:28(__parent_idx)
   116786    0.022    0.000    0.026    0.000 main.py:31(__left_child_idx)
   269835    0.050    0.000    0.059    0.000 main.py:34(__right_child_idx)
    20000    0.013    0.000    0.020    0.000 main.py:37(__up)
    10000    0.115    0.000    0.216    0.000 main.py:45(__down)
        1    0.004    0.004    0.258    0.258 main.py:63(my_heap_sort)
   239381    0.013    0.000    0.013    0.000 main.py:64(<lambda>)
    10000    0.004    0.000    0.024    0.000 main.py:8(push)
        1    0.000    0.000    0.258    0.258 {built-in method builtins.exec}
   553410    0.019    0.000    0.019    0.000 {built-in method builtins.len}
    20000    0.001    0.000    0.001    0.000 {method 'append' of 'list' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
    10000    0.000    0.000    0.000    0.000 {method 'pop' of 'list' objects}

         30005 function calls in 0.005 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.005    0.005 <string>:1(<module>)
        1    0.003    0.003    0.005    0.005 main.py:74(builtin_heapq_sort)
    10000    0.002    0.000    0.002    0.000 {built-in method _heapq.heappop}
    10000    0.001    0.000    0.001    0.000 {built-in method _heapq.heappush}
        1    0.000    0.000    0.005    0.005 {built-in method builtins.exec}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.len}
    10000    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

run() の第二引数に出力ファイル名を渡して SnakeViz などで可視化すると分かりやすい。

$ python -m pip install snakeviz
$ python -m snakeviz my_heap_sort.prof

これにより pop() 時の __down() で時間がかかっていることが分かった。

__right_child_idx() や __left_child_idx() を展開したり、len() の結果を変数に格納したりして関数の呼び出しを最小限にする。

def __down(self, idx):
        heap_len = len(self.heap)
        while idx < heap_len:
            minIdx = (idx * 2) + 1 if (idx * 2) + 1 < heap_len else - 1
            if minIdx == -1:
                return
            rightIdx = (idx * 2) + 2 if (idx * 2) + 2 < heap_len else - 1
            if rightIdx != -1 and \
                self.heap[rightIdx] < self.heap[minIdx]:
                    minIdx = rightIdx
            
            if self.heap[idx] < self.heap[minIdx]:
                return

            self.heap[idx], self.heap[minIdx] = self.heap[minIdx], self.heap[idx]
            idx = minIdx

結果、0.255s → 0.0939s と 60% 以上短縮することができた。

次に line_profiler で __down() の行ごとの時間を取ってみたところ、swap しているところで時間がかかっていることが分かった。


$ pip install line_profiler
$ cat main.py
from line_profiler import profile

@profile
def __down(self, idx):
  ....

$ python -m kernprof -l main.py
$ python -m line_profiler -rmt "main.py.lprof"

Timer unit: 1e-06 s

Total time: 0.150715 s
File: main.py
Function: __down at line 48

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    48                                               @profile
    49                                               def __down(self, idx):
    50     10000       1330.0      0.1      0.9          heap_len = len(self.heap)
    51    116787      10922.0      0.1      7.2          while idx < heap_len:
    52    116786      22549.0      0.2     15.0              minIdx = (idx * 2) + 1 if (idx * 2) + 1 < heap_len else - 1
    53    116786      10685.0      0.1      7.1              if minIdx == -1:
    54      8470        694.0      0.1      0.5                  return
    55    108316      19697.0      0.2     13.1              rightIdx = (idx * 2) + 2 if (idx * 2) + 2 < heap_len else - 1
    56    108316       9848.0      0.1      6.5              if rightIdx != -1 and \
    57    108303      19366.0      0.2     12.8                  self.heap[rightIdx] < self.heap[minIdx]:
    58     53216       4188.0      0.1      2.8                      minIdx = rightIdx
    59                                                       
    60    108316      20421.0      0.2     13.5              if self.heap[idx] < self.heap[minIdx]:
    61      1529        153.0      0.1      0.1                  return
    62                                           
    63    106787      23134.0      0.2     15.3              self.heap[idx], self.heap[minIdx] = self.heap[minIdx], self.heap[idx]
    64    106787       7728.0      0.1      5.1              idx = minIdx

  0.15 seconds - main.py:48 - __down

heapq の実装を見ると処理の最後に1回止まった位置に代入していて、実際そのようにするといくらか速くなった。

しかし、どう頑張っても heapq の 0.005 sec には届かなさそうだと思い、heapq のコードをよく見てみたところ可能なら C実装 が使われるようになっていて、この部分を消すと 0.03 sec程度になった。なるほど。