Python の cProfile や line_profile でプロファイルを取って遅いコードを改善する
pythonPython でヒープソートを実装したのだが、ビルトインの heapq を用いたものと比べて遥かに遅い。
Python の組み込みコンテナ、collections.deque と heapq - sambaiz-net
class Heap:
def __init__(self, values, compare):
self.heap = values
self.compare = compare
for i in reversed(range(len(self.heap) // 2)):
self.__down(i)
def push(self, value):
self.heap.append(value)
self.__up(len(self.heap) - 1)
def delete(self, idx):
if idx >= len(self.heap):
return
self.heap[idx], self.heap[len(self.heap) - 1] = \
self.heap[len(self.heap) - 1], self.heap[idx]
self.heap.pop()
self.__down(idx)
self.__up(idx)
def pop(self):
if len(self.heap) == 0:
return None
ret = self.heap[0]
self.delete(0)
return ret
def __parent_idx(self, idx):
return (idx - 1) // 2 if idx > 0 else - 1
def __left_child_idx(self, idx):
return (idx * 2) + 1 if (idx * 2) + 1 < len(self.heap) else - 1
def __right_child_idx(self, idx):
return (idx * 2) + 2 if (idx * 2) + 2 < len(self.heap) else - 1
def __up(self, idx):
while idx > 0:
if self.compare(self.heap[self.__parent_idx(idx)], self.heap[idx]):
return
self.heap[idx], self.heap[self.__parent_idx(idx)] = \
self.heap[self.__parent_idx(idx)], self.heap[idx]
idx = self.__parent_idx(idx)
def __down(self, idx):
while idx < len(self.heap):
minIdx = self.__left_child_idx(idx)
if minIdx == -1:
return
if self.__right_child_idx(idx) != -1 and \
self.compare(
self.heap[self.__right_child_idx(idx)],
self.heap[minIdx]
):
minIdx = self.__right_child_idx(idx)
if self.compare(self.heap[idx], self.heap[minIdx]):
return
self.heap[idx], self.heap[minIdx] = self.heap[minIdx], self.heap[idx]
idx = minIdx
def my_heap_sort(data):
minheap = Heap([], lambda a, b: a < b)
for d in data:
minheap.push(d)
ret = []
for i in range(len(data)):
ret.append(minheap.pop())
return ret
def builtin_heapq_sort(data):
heap = []
for d in data:
heapq.heappush(heap, d)
ret = []
for i in range(len(heap)):
ret.append(heapq.heappop(heap))
return ret
まずは cProifile でプロファイルを取ってみる。
import cProfile
import random
...
if __name__ == "__main__":
x = list(range(10000))
random.Random(5).shuffle(x)
assert my_heap_sort(x) == builtin_heapq_sort(x)
cProfile.run('my_heap_sort(x)')
cProfile.run('builtin_heapq_sort(x)')
次のような出力が得られた。
$ python --version
Python 3.11.4
$ python main.py
1330480 function calls in 0.258 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 0.258 0.258 <string>:1(<module>)
10000 0.009 0.000 0.227 0.000 main.py:12(delete)
1 0.000 0.000 0.000 0.000 main.py:2(__init__)
10000 0.003 0.000 0.230 0.000 main.py:21(pop)
61063 0.005 0.000 0.005 0.000 main.py:28(__parent_idx)
116786 0.022 0.000 0.026 0.000 main.py:31(__left_child_idx)
269835 0.050 0.000 0.059 0.000 main.py:34(__right_child_idx)
20000 0.013 0.000 0.020 0.000 main.py:37(__up)
10000 0.115 0.000 0.216 0.000 main.py:45(__down)
1 0.004 0.004 0.258 0.258 main.py:63(my_heap_sort)
239381 0.013 0.000 0.013 0.000 main.py:64(<lambda>)
10000 0.004 0.000 0.024 0.000 main.py:8(push)
1 0.000 0.000 0.258 0.258 {built-in method builtins.exec}
553410 0.019 0.000 0.019 0.000 {built-in method builtins.len}
20000 0.001 0.000 0.001 0.000 {method 'append' of 'list' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
10000 0.000 0.000 0.000 0.000 {method 'pop' of 'list' objects}
30005 function calls in 0.005 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 0.005 0.005 <string>:1(<module>)
1 0.003 0.003 0.005 0.005 main.py:74(builtin_heapq_sort)
10000 0.002 0.000 0.002 0.000 {built-in method _heapq.heappop}
10000 0.001 0.000 0.001 0.000 {built-in method _heapq.heappush}
1 0.000 0.000 0.005 0.005 {built-in method builtins.exec}
1 0.000 0.000 0.000 0.000 {built-in method builtins.len}
10000 0.000 0.000 0.000 0.000 {method 'append' of 'list' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
run() の第二引数に出力ファイル名を渡して SnakeViz などで可視化すると分かりやすい。
$ python -m pip install snakeviz
$ python -m snakeviz my_heap_sort.prof
これにより pop() 時の __down() で時間がかかっていることが分かった。
__right_child_idx() や __left_child_idx() を展開したり、len() の結果を変数に格納したりして関数の呼び出しを最小限にする。
def __down(self, idx):
heap_len = len(self.heap)
while idx < heap_len:
minIdx = (idx * 2) + 1 if (idx * 2) + 1 < heap_len else - 1
if minIdx == -1:
return
rightIdx = (idx * 2) + 2 if (idx * 2) + 2 < heap_len else - 1
if rightIdx != -1 and \
self.heap[rightIdx] < self.heap[minIdx]:
minIdx = rightIdx
if self.heap[idx] < self.heap[minIdx]:
return
self.heap[idx], self.heap[minIdx] = self.heap[minIdx], self.heap[idx]
idx = minIdx
結果、0.255s → 0.0939s と 60% 以上短縮することができた。
次に line_profiler で __down() の行ごとの時間を取ってみたところ、swap しているところで時間がかかっていることが分かった。
$ pip install line_profiler
$ cat main.py
from line_profiler import profile
@profile
def __down(self, idx):
....
$ python -m kernprof -l main.py
$ python -m line_profiler -rmt "main.py.lprof"
Timer unit: 1e-06 s
Total time: 0.150715 s
File: main.py
Function: __down at line 48
Line # Hits Time Per Hit % Time Line Contents
==============================================================
48 @profile
49 def __down(self, idx):
50 10000 1330.0 0.1 0.9 heap_len = len(self.heap)
51 116787 10922.0 0.1 7.2 while idx < heap_len:
52 116786 22549.0 0.2 15.0 minIdx = (idx * 2) + 1 if (idx * 2) + 1 < heap_len else - 1
53 116786 10685.0 0.1 7.1 if minIdx == -1:
54 8470 694.0 0.1 0.5 return
55 108316 19697.0 0.2 13.1 rightIdx = (idx * 2) + 2 if (idx * 2) + 2 < heap_len else - 1
56 108316 9848.0 0.1 6.5 if rightIdx != -1 and \
57 108303 19366.0 0.2 12.8 self.heap[rightIdx] < self.heap[minIdx]:
58 53216 4188.0 0.1 2.8 minIdx = rightIdx
59
60 108316 20421.0 0.2 13.5 if self.heap[idx] < self.heap[minIdx]:
61 1529 153.0 0.1 0.1 return
62
63 106787 23134.0 0.2 15.3 self.heap[idx], self.heap[minIdx] = self.heap[minIdx], self.heap[idx]
64 106787 7728.0 0.1 5.1 idx = minIdx
0.15 seconds - main.py:48 - __down
heapq の実装を見ると処理の最後に1回止まった位置に代入していて、実際そのようにするといくらか速くなった。
しかし、どう頑張っても heapq の 0.005 sec には届かなさそうだと思い、heapq のコードをよく見てみたところ可能なら C実装 が使われるようになっていて、この部分を消すと 0.03 sec程度になった。なるほど。