唯物是真 @Scaled_Wurm

プログラミング(主にPython2.7)とか機械学習とか

最もシンプルで(驚くべき)ソートアルゴリズム?

こういうツイートを見かけたので元の論文を読んだり実装してみたりしました

注意点ですが、このアルゴリズム自体はよく使われるソートアルゴリズムと比べて特に利点があるわけではなくてネタ的な話です。

元論文
arxiv.org

実装(Python)

この記事のソートでは昇順に並べます

ICan’tBelieveItCanSort

というわけで論文で言うところのICan’tBelieveItCanSortを実装してみました。

f:id:sucrose:20211014200951p:plain
Algorithm 1 ICan’tBelieveItCanSort (論文より)
https://arxiv.org/abs/2110.01111

二重ループして\(j\)番目の要素が\(i\)番目の要素よりも大きかったら交換するだけなので単純です。
冒頭のツイートでも書かれていたように、以下で実装している他のソートとひと目見比べてみると交換するときの大小の条件が逆になっているように見えなくもないです。

def i_cant_believe_it_can_sort(A: list):
  n = len(A)
  for i in range(n):
    for j in range(n):
      if A[i] < A[j]:
        A[i], A[j] = A[j], A[i]
  return A

ExchangeSort

論文中にはExchangeSortというのも出てきたのでそれも実装してみました。
ちなみに元の論文では、ExchangeSortの内側の\(i\)からのループを先頭からのループと間違えて実装したときにICan’tBelieveItCanSort)を発見したというふうに書かれていました(その場合は大小の条件が逆なので降順になる)

def exchange_sort(A: list):
  n = len(A)
  for i in range(n):
    for j in range(i + 1, n):
      if A[i] > A[j]:
        A[i], A[j] = A[j], A[i]
  return A

バブルソート

参考用に同様にバブルソートを実装してみました。
隣同士の要素を比べて順番が降順になっていたら入れ替えます。
実装的にはICan’tBelieveItCanSortとどっちがシンプルかというのはなんとも言い難いです

def bubble_sort(A: list):
  n = len(A)
  for i in range(n - 1):
    for j in range(n - 1):
      if A[j] > A[j + 1]:
        A[j], A[j + 1] = A[j + 1], A[j]
  return A

なんでソートできるの?

ICan’tBelieveItCanSortがどう動くか適当な配列を例に説明してみましょう。

最初は[2, 5, 9, 7, 1, 7]とします。

\(j\)番目の要素が\(i\)番目の要素よりも大きかったら交換するので、最初の\(i=0\)のループ後は配列の先頭に最大の要素が入ります。
i=0, j=0, [2, 5, 9, 7, 1, 7]
i=0, j=1, [5, 2, 9, 7, 1, 7]
i=0, j=2, [9, 2, 5, 7, 1, 7]
i=0, j=3, [9, 2, 5, 7, 1, 7]
i=0, j=4, [9, 2, 5, 7, 1, 7]
i=0, j=5, [9, 2, 5, 7, 1, 7]
上の例を見てもわかるように\(i\)番目に最大の要素が来るとそれ以上交換は行われなくなります

\(i=1\)のループでは、最初が[9, 2, 5, 7, 1, 7]
i=1, j=0, [2, 9, 5, 7, 1, 7]
i=1, j=1, [2, 9, 5, 7, 1, 7]
以下略

\(i=0\)のループで先頭に最大の要素があるのでこれ以降\(j\)に関するループは\(j=i-1\)までやれば十分になっています(それ以上やっても交換は行われない)
このことから\( j < i \)となっています。なので最初に大小の比較の条件が逆に見える、という話がありましたが、i番目の方に大きい要素が来るのは配列の後ろの方に大きい要素が来るのと同じなので大小の条件は逆ではなさそうというのがわかります

\(i=2\)のループでは、最初が[2, 9, 5, 7, 1, 7]
i=2, j=0, [2, 9, 5, 7, 1, 7]
i=2, j=1, [2, 5, 9, 7, 1, 7]
以下略

\(i=3\)のループでは、最初が[2, 5, 9, 7, 1, 7]
i=3, j=0, [2, 5, 9, 7, 1, 7]
i=3, j=1, [2, 5, 9, 7, 1, 7]
i=3, j=2, [2, 5, 7, 9, 1, 7]
以下略

\(i=4\)のループでは、最初が[2, 5, 7, 9, 1, 7]
i=4, j=0, [1, 5, 7, 9, 2, 7]
i=4, j=1, [1, 2, 7, 9, 5, 7]
i=4, j=2, [1, 2, 5, 9, 7, 7]
i=4, j=3, [1, 2, 5, 7, 9, 7]
以下略

残りは同じ感じなので省略します。

上で試してみた結果のように、このアルゴリズムでは\(i\)のループが終わった時点では先頭から\(i\)番目までの要素について、次の要素の方が大きいもしくは等しい状態、つまり昇順の状態になります。
なので最後まで\(i\)のループを行うとすべての要素が昇順になることが言えます(論文には証明があります)

挙動を大雑把に理解するために\(i\)に関するそれぞれのループ後の配列をみると、\(i=1\)以降のループでは実質的には挿入ソートと同じ挙動になっていることがわかります。

i=0, [9, 2, 5, 7, 1, 7]
i=1, [2, 9, 5, 7, 1, 7]
i=2, [2, 5, 9, 7, 1, 7]
i=3, [2, 5, 7, 9, 1, 7]
i=4, [1, 2, 5, 7, 9, 7]
i=4, [1, 2, 5, 7, 7, 9]

挿入ソートでは\(i\)番目の要素をそれ以前の要素と見比べて適切な位置に挿入してそれ以降の要素を一つ後ろにずらします。
このアルゴリズムでは\(i\)番目の要素との交換を使って、挿入ソートでいうところの挿入して要素を一つずつ後ろにずらす操作を実現していることになります。
挿入ソート - Wikipedia

なのでコードからは想像しづらいですが\(i=0\)のループで最大値を先頭に持ってきて\(i=1\)以降のループでは挿入ソートのようなことをしているという挙動としては結構わかりやすいことをしていることがわかりました

コードを再掲

def i_cant_believe_it_can_sort(A: list):
  n = len(A)
  for i in range(n):
    for j in range(n):
      if A[i] < A[j]:
        A[i], A[j] = A[j], A[i]
  return A

おまけ

論文中にはIt is difficult to imagine that this algorithm was not discovered before,
but we are unable to find any references to it.
と書かれていますが、検索などを使って調べた人によると今回のソートアルゴリズムと同じものの実装例や質問が結構見つかるらしいです
Some links to discussions and/or accidental discoveries of this algorithm: OP (2... | Hacker News

AtCoder Beginners Selection を SQL (SQLite + Python) で解く

以下のようなツイートを見かけて、AtCoderでSQLが使えることがわかったので試しに AtCoder Beginners Selection の問題を可能な限りSQLを使って解いてみました。できるだけ標準入出力をPython側で行って、その他の計算をSQLite側で行います


zenn.dev

PracticeA - Welcome to AtCoder

標準入出力から読み取った数値を足して、文字列はそのまま出力します。
SQLクエリを作るときに、変数を文字列結合でくっつけてくのは一般的にはあまりよくない気もしますが、今回の用途では特に問題ないと思うのでこの記事中では気にせず使っていきます。

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")

a = input()
b, c = input().split()
s = input()

for i in cur.execute(f"""
  SELECT {a} + {b} + {c}, "{s}"
"""):
  print(*i)

ABC086A - Product

2つの整数をかけた結果が偶数かどうかを答えます。
条件分岐は CASE式を使って CASE WHEN 条件 THEN 値 CASE 条件2 THEN 値2 略 ELSE その他の場合の値 ENDのように書けます

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")



a, b = map(int, input().split())

for i in cur.execute(f"SELECT CASE WHEN {a} * {b} % 2 = 0 THEN \"Even\" ELSE \"Odd\" END"):
  print(i[0])

ABC081A - Placing Marbles

0と1で構成された3文字の文字列に含まれる1の個数を求めれば良いです
いろいろとやりかたはあると思いますが、REPLACEで0を消して1だけになった文字列の長さを返しました

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")



s = input()

for i in cur.execute(f"SELECT LENGTH(REPLACE({s}, \"0\", \"\"))"):
  print(i[0])

ABC081B - Shift only

与えられた複数の整数をすべて割り切れる最大の \(2^k\) の \(k\) を回答します。
CASE式を使って下からi+1ビット目が1ならiを返してそうでないなら適当な大きな値を返す判定を作ります。
面倒なので各ビットについてPython側で生成するとこんな感じです。

cases = ",\n".join([f"CASE WHEN (1 << {i}) & n THEN {i} ELSE 10000 END" for i in range(0, 30)])

ある数について必要なすべてのビットに対して上の判定をして複数引数の場合のMIN関数で最小値を取れば、その数が割り切れる箇所がわかります。
Built-In Scalar SQL Functions
あとはその結果について集計関数の方のMIN関数ですべての数について集計しなおせばすべての整数が割り切れる値がわかります

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;

  CREATE TABLE num(n INTEGER);
""")



N = input()
A = map(int, input().split())

cur.executemany("INSERT INTO num VALUES(?)", map(lambda x: (x,), A))

cases = ",\n".join([f"CASE WHEN (1 << {i}) & n THEN {i} ELSE 10000 END" for i in range(0, 30)])
for i in cur.execute(f"""
SELECT
  MIN(
    MIN(
      {cases}
    )
  )
FROM num
"""):
  print(i[0])

ABC087B - Coins

500円玉がA枚、100円玉がB枚、50円玉がC枚あるときに、合計X円にする方法を答えます
500円玉が0からA枚、100円玉が0からB枚、50円玉が0からC枚ある時の組み合わせの総当りを試して判定すればよいです。
SQLだとfor文がないので、0, 1, 2, ..., Aなどの数列を生成してからJOINして計算します。

SQLiteだと整数列の生成をしてくれるような関数がなさそうなので、WITH RECURSIVEを使って整数列を生成します。
WITH RECURSIVEは最初にデータを生成して、それ以降は生成されたデータに対してある処理をして、更に生成された同じ処理をして、というのを繰り返してデータを生成してくれます。
例えば以下の例だと、最初に0の行を生成して、それに+1して1の行が生成されて、1に対して+1して2の行が生成されて……という感じでデータを生成してくれます。WHEREの条件に引っかかって次の行が生成されなくなったら終了します。

WITH RECURSIVE coin500(i) AS (
  SELECT 0
  UNION
  SELECT i + 1 FROM coin500 WHERE i + 1 <= {A}
)

全体のコード

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")



A = input()
B = input()
C = input()
X = input()

for i in cur.execute(f"""
WITH RECURSIVE coin500(i) AS (
  SELECT 0
  UNION
  SELECT i + 1 FROM coin500 WHERE i + 1 <= {A}
),  coin100(i) AS (
  SELECT 0
  UNION
  SELECT i + 1 FROM coin100 WHERE i + 1 <= {B}
),  coin50(i) AS (
  SELECT 0
  UNION
  SELECT i + 1 FROM coin50 WHERE i + 1 <= {C}
)

SELECT
  COUNT(*)
FROM coin500, coin100, coin50
WHERE coin500.i * 500 + coin100.i * 100 + coin50.i * 50 = {X}
"""):
  print(i[0])

ABC083B - Some Sums

1以上N以下の整数のうち10進法での各桁の数字の合計がA以上B以下である数の総和を答えます。
10000, 1000, 100, 10, 1で割った商を10で割った余りを求めれば各桁の数字が求められます。
あとは1からNまでの値を WITH RECURSIVE で生成して条件を満たす値の合計を計算するだけです。

%%bash
echo '
import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")



N, A, B = map(int, input().split())

for i in cur.execute(f"""
WITH RECURSIVE seq(n) AS (
  SELECT 1
  UNION
  SELECT n + 1 FROM seq WHERE n + 1 <= {N}
)

SELECT
  SUM(n)
FROM seq
WHERE CAST(n / 10000 AS INTEGER) % 10 + CAST(n / 1000 AS INTEGER) % 10 + CAST(n / 100 AS INTEGER) % 10 + CAST(n / 10 AS INTEGER) % 10 + CAST(n / 1 AS INTEGER) % 10
  BETWEEN {A} AND {B}
"""):
  print(i[0])
' > main.py

echo "100 4 16" > input.txt
cat input.txt | python main.py

ABC088B - Card Game for Two

与えられた整数を大きい順にソートして、最大のものから順に1番目、2番め…としたときに、奇数番目の合計から偶数番目の合計を引いて答えればよいです。
以下のようにウィンドウ関数が使えれば簡単に求められるのですがSQLiteのバージョンが古いのか、どこか書き間違えたのか使えなかったです。

SELECT
  SUM(CASE WHEN row_num % 2 = 1 THEN n ELSE 0 END) - SUM(CASE WHEN row_num % 2 = 0 THEN n ELSE 0 END)
FROM (
  SELECT
    n, ROW_NUMBER() OVER (ORDER BY n DESC) AS row_num
  FROM num
)

なので一度テーブルに保存してから計算しました。SQLiteではデフォルトでは、テーブルに行を保存するとrowidという1から順のIDが暗黙的に振られるので、ソート済みの値をテーブルに保存してrowidの偶奇を見て計算します。

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;

  CREATE TABLE num(n INTEGER);
""")



N = input()
A = map(int, input().split())

cur.executemany("INSERT INTO num VALUES(?)", map(lambda x: (x,), A))

cur.execute(f"""
  CREATE TEMPORARY TABLE ordered AS
    SELECT
      n
    FROM num
    ORDER BY n DESC
""")

for i in cur.execute(f"""
  SELECT (SELECT SUM(n) FROM ordered WHERE rowid % 2 = 1) - (SELECT SUM(n) FROM ordered WHERE rowid % 2 = 0)
"""):
  print(i[0])

ABC085B - Kagami Mochi

与えられた複数の整数のうちユニークな個数を答えればよいです。
SQLでは COUNT(DISTINCT) するだけなので簡単です

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;

  CREATE TABLE num(n INTEGER);
""")



N = int(input())
D = []
for i in range(N):
  D.append(input())

cur.executemany("INSERT INTO num VALUES(?)", map(lambda x: (x,), D))

for i in cur.execute(f"""
  SELECT COUNT(DISTINCT n) FROM num
"""):
  print(i[0])

ABC085C - Otoshidama

10000円札、5000円札、1000円札を合計N枚使って合計Y円になる組み合わせを一つ答えます。
10000円札、5000円札の枚数候補を生成してJOINして、残り金額を1000円札で補った場合に条件を満たすか判定。

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")



N, Y = map(int, input().split())

result = list(cur.execute(f"""
WITH RECURSIVE money10000(i) AS (
  SELECT 0
  UNION
  SELECT i + 1 FROM money10000 WHERE i + 1 <= {Y} / 10000 
),  money5000(i) AS (
  SELECT 0
  UNION
  SELECT i + 1 FROM money5000 WHERE i + 1 <= {Y} / 5000
)

SELECT
  money10000.i, money5000.i, ({Y} - money10000.i * 10000 - money5000.i * 5000) / 1000
FROM money10000, money5000
WHERE ({Y} - money10000.i * 10000 - money5000.i * 5000) % 1000 = 0
AND ({Y} - money10000.i * 10000 - money5000.i * 5000) / 1000 >= 0
AND money10000.i + money5000.i + ({Y} - money10000.i * 10000 - money5000.i * 5000) / 1000 = {N}
"""))

if result:
  print(*result[0])
else:
  print(-1, -1, -1)

ABC049C - Daydream

与えられた文字列の末尾から貪欲に"dream", "dreamer", "erase", "eraser"の文字列を可能な限り取り除いていって、最終的に空文字列になるかどうかを判定します。
REGEXPを使って正規表現で判定すれば一瞬だと思ったのですが、SQLiteのバージョンが古くて使えなかったです。

WITH RECURSIVEを使って、末尾から文字列を取り除いた文字列を順に生成していくクエリを書いたのですが、これはTLEになってしまいました。

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")



S = input()

for i in cur.execute(f"""
  WITH RECURSIVE str(s) AS (
    SELECT "{S}"
    UNION
    SELECT
      CASE
        WHEN substr(s, -5) = "dream" THEN substr(s, 1, length(s) - 5)
        WHEN substr(s, -7) = "dreamer" THEN substr(s, 1, length(s) - 7)
        WHEN substr(s, -5) = "erase" THEN substr(s, 1, length(s) - 5)
        WHEN substr(s, -6) = "eraser" THEN substr(s, 1, length(s) - 6)
      END
    FROM str
    WHERE s IS NOT NULL
  )

  SELECT
    CASE
      WHEN COUNT(*) = 1 THEN "YES"
      ELSE "NO"
    END
  FROM str
  WHERE s = ""
"""):
  print(i[0])

最終手段としてSQLですべてやるのは諦めて正規表現はユーザー定義関数で判定することにしました。
SQLiteではPython側で定義した関数をSQLから呼び出すことができます。

以下のようにPythonの関数をcreate_functionで渡せばSQL中から使うことができます

def re_match(s: str):
  return re.match("^(dream|dreamer|erase|eraser)+$", s) is not None

con.create_function("RE_MATCH", 1, re_match)

S = input()
cur.execute(f"""
  SELECT
    CASE
      WHEN RE_MATCH("{S}") = 1 THEN "YES"
      ELSE "NO"
    END
""")

全体のコード

import sqlite3
import re

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")

# SQLiteのバージョンが新しければREGEXPで正規表現が使えるはずだけど使えないので妥協
def re_match(s: str):
  return re.match("^(dream|dreamer|erase|eraser)+$", s) is not None

con.create_function("RE_MATCH", 1, re_match)

S = input()
for i in cur.execute(f"""
  SELECT
    CASE
      WHEN RE_MATCH("{S}") = 1 THEN "YES"
      ELSE "NO"
    END
"""):
  print(i[0])

ABC086C - Traveling

時刻が1進むごとにx軸方向かy軸方向に1移動する。
x, y = 0, 0からスタートするとして、時刻とそのときのx, y座標の組み合わせが複数与えられるので最後まで移動可能かを答える。

時刻の差がマンハッタン距離よりも大きいことと、時刻の差の偶奇とマンハッタン距離の偶奇が一致していればよい(同じ箇所で前後移動を繰り返せば2ずつ消費できる)

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;

  CREATE TABLE points(t INTEGER, x INTEGER, y INTEGER);
""")

N = int(input())
data = []
for i in range(N):
  data.append(input().split())


cur.execute("INSERT INTO points VALUES(?, ?, ?)", (0, 0, 0))
cur.executemany("INSERT INTO points VALUES(?, ?, ?)", data)

for i in cur.execute(f"""
  SELECT
    CASE
      WHEN
        (
          SELECT
            COUNT(*)
          FROM points AS current
          JOIN points AS next
          ON current.rowid = next.rowid - 1
          WHERE
            NOT (
              ABS(current.x - next.x) + ABS(current.y - next.y) <= (next.t - current.t)
              AND (ABS(current.x - next.x) + ABS(current.y - next.y)) % 2 = (next.t - current.t) % 2
            )
        ) > 0
        THEN "No"
      ELSE "Yes"
    END
"""):
  print(i[0])

まとめ

サクッとできるかと思いましたが結構つらかったです。やはりforループ的なものを気軽にかけないのはなかなか大変。
あとはSQLiteのバージョンが古くてウィンドウ関数や正規表現が使えなかったのも苦労しました。

SQLiteに触るのはほとんど初めてだったのですが構文の説明などのドキュメントが難しかったです
www.sqlite.org
SQLiteのロゴ