Educational DP Contest
Educational DP Contest - AtCoder
AtCoder is a programming contest site for anyone from beginners to experts. We hold weekly programming contests online.
問題
B - Frog 2
AtCoder is a programming contest site for anyone from beginners to experts. We hold weekly programming contests online.
方針
Frog 1の強化版。計算量は\(O(NK)\)かかる。
解答
#input
n, k = map(int, input().split())
h = list(map(int, input().split()))
#pypy3では通るがpython3ではTLE
#output
dp = [0]*n
dp[1] = abs(h[1]-h[0])
for i in range(2, n):
temp = abs(h[i]-h[i-1]) + dp[i-1]
for j in range(2, k+1):
if 0 <= i-j <= n-1:
temp = min(temp, dp[i-j]+abs(h[i]-h[i-j]))
dp[i] = temp
print(dp[-1])
愚直にやると上のようになる。これはPython3ではTLEになる(PyPy3では通る)。
numpyを用いるとPython3でも通る。
#version 2
#input
n, k = map(int, input().split())
h = np.array(list(map(int, input().split())))
#output
import numpy as np
dp = np.zeros(n, np.int64)
dp[1] = abs(h[1]-h[0])
for i in range(1, n):
temp = abs(h[i-1]-h[i]) + dp[i-1]
low_lim = max(0, i-k)
temp2 = abs(h[i]-h[low_lim:i])
temp3 = min(temp2 + dp[low_lim:i])
temp = min(temp, temp3)
dp[i] = temp
print(dp[-1])
それほど早くはない。これはせっかくnumpyの配列を用いているのに、一つ一つの要素を比べているからで、以下のようにすると少し早くなる。
#atcoder template
def main():
import sys
input = sys.stdin.readline
#文字列入力の時は上記はerrorとなる。
#ここにコード
#input
import numpy as np
n, k = map(int, input().split())
h = np.array(list(map(int, input().split())))
#output
dp = np.zeros(n, np.int64)
dp[1] = abs(h[1]-h[0])
for i in range(1, n):
low_lim = max(0, i-k)
dp[i] = np.amin(np.abs(h[i]-h[low_lim:i])+dp[low_lim:i])
print(dp[-1])
#N = 1のときなどcorner caseを確認!
if __name__ == "__main__":
main()
それでも決して早くない。np.minimumを用いたのが以下。numbaも入れている。
import sys
import numpy as np
def solve(s):
#s = input()
#sl = list(input().split())
n, k = s[:2]
h = s[2:]
#ここにコード
dp = np.full(n, 10**10, np.int64)
dp[0] = 0
for i in range(n):
dp[i:i+k+1] = np.minimum(dp[i:i+k+1], dp[i]+np.abs(h[i:i+k+1]-h[i]))
#N = 1のときなどcorner caseを確認!
return dp[-1]#answer
def main():
stdin = np.fromstring(open(0).read(), dtype=np.int64, sep=' ')
# stdin = np.fromstring(IN, dtype=np.int64, sep=' ')
print(solve(stdin))
def cc_export():
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', '(i8[:],)')(solve)
cc.compile()
if __name__ == "__main__":
if sys.argv[-1] == 'ONLINE_JUDGE':
cc_export()
sys.exit(0)
from my_module import solve # type: ignore
main()
提出結果
Submission #31632219 - Educational DP Contest
AtCoder is a programming contest site for anyone from beginners to experts. We hold weekly programming contests online.
感想
numpy.minimumを使いこなしたい。
コメント