唯物是真 @Scaled_Wurm

プログラミング(主にPython2.7)とか機械学習とか

PythonでpaizaオンラインハッカソンVol.1に挑戦した #paizahack_01

一応解けたけど、Twitterを見てるとPythonでテストケース3を0.3秒台とかで解いている人がいて、どんな解き方をしているのか気になります

問題設定

\(N\)個の商品の値段の要素を持つ配列中から2つの要素を選んで和を計算する
\(D\)個の上限の値それぞれについて、その上限以下でできるだけ大きい和を答える

解法

\(1 \leq N \leq 200000\)と\(N\)が非常に大きいので、値段の配列中のすべてのペアについて和を計算する\(O(N^2)\)のアルゴリズムだと遅すぎる
またPythonだと2分探索を使った\(O(ND \log N)\)のアルゴリズムでも結構時間がかかる

以下\(O(ND + N \log N)\)のアルゴリズムによる結果

結果:
 テストケース1:success 0.09秒
 テストケース2:success 0.32秒
 テストケース3:success 3.03秒

mugenenさんの採点結果[100点 CTOに昇進しました!]|paizaオンラインハッカソンVol.1

まず値段の配列をソートしてから考える。
現在注目している値段の和(上限のギリギリ下)をheadとtailの2つの添字の要素の和とすると、tailを1減らした時には和が減少するのでheadは現在のheadと同じかそれより大きい添字になる
上限それぞれについて、headとtailの2つの添字の要素の和と上限との大小関係を見て、しゃくとり法っぽく(?)headとtailを動かしながら最大を計算していく
具体的にはheadが先頭の要素、tailが末尾の要素を指している状態からスタートして次のループを行えばよい

  • 現在の和が上限よりも多ければtailを減らして、上限以下ならばheadを増やしていく。

これによって和がちょうど上限の境界の上下になるペアを順番に辿って行くことができるので最大が求められる
\(D\)個の上限について、headとtailはたかだか値段の配列の要素数しか変更されないので\(O(ND)\)、事前に値段をソートしているのでまとめると\(O(ND + N \log N)\)

以下のソースコードでは最初に全部の入力を読み込むのと、最初のtailを求めるのに二分探索を使うのと、配列から重複要素を取り除くのと3つの方法で定数倍の高速化をしている

import bisect
import os

stdin = iter(os.read(0, 10000000).splitlines())


N, D = map(int, next(stdin).split())

p = [int(next(stdin)) for i in xrange(N)]

p.sort()

uniq = list(set(p))
uniq.sort()

for i in xrange(D):
    m = int(next(stdin))
    result = 0

    head = 0
    tail = bisect.bisect(uniq, m - uniq[0]) - 1
    while head < tail:
        s = uniq[head] + uniq[tail]
        if s > m:
            tail -= 1
        else:
            head += 1
            result = max(result, s)
    idx = bisect.bisect(p, m / 2) - 1
    if idx > 0:
        result = max(result, p[idx] + p[idx - 1])
    print result

追記(2013-12-07)

上限に一致したら終了するという処理を入れただけでだいぶ速くなった

結果:
 テストケース1:success 0.09秒
 テストケース2:success 0.10秒
 テストケース3:success 0.50秒

mugenenさんの採点結果[100点 CTOに昇進しました!]|paizaオンラインハッカソンVol.1
import bisect
import os

stdin = iter(os.read(0, 10000000).splitlines())


N, D = map(int, next(stdin).split())

p = [int(next(stdin)) for i in xrange(N)]

p.sort()

uniq = list(set(p))
uniq.sort()

for i in xrange(D):
    m = int(next(stdin))
    result = 0

    head = 0
    tail = bisect.bisect(uniq, m - uniq[0]) - 1
    while head < tail:
        s = uniq[head] + uniq[tail]
        if s == m:
            result = s
            break
        elif s > m:
            tail -= 1
        else:
            head += 1
            result = max(result, s)
    else:
        idx = bisect.bisect(p, m / 2) - 1
        if idx > 0:
            result = max(result, p[idx] + p[idx - 1])
    print result

見つけた解説記事

以下のブログ記事辺りを参考にすると、ソートとか入出力とかでもっと工夫が必要そうですね