唯物是真 @Scaled_Wurm

プログラミング(主にPython2.7)とか機械学習とか

AtCoder Beginners Selection を SQL (SQLite + Python) で解く

以下のようなツイートを見かけて、AtCoderでSQLが使えることがわかったので試しに AtCoder Beginners Selection の問題を可能な限りSQLを使って解いてみました。できるだけ標準入出力をPython側で行って、その他の計算をSQLite側で行います


zenn.dev

PracticeA - Welcome to AtCoder

標準入出力から読み取った数値を足して、文字列はそのまま出力します。
SQLクエリを作るときに、変数を文字列結合でくっつけてくのは一般的にはあまりよくない気もしますが、今回の用途では特に問題ないと思うのでこの記事中では気にせず使っていきます。

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")

a = input()
b, c = input().split()
s = input()

for i in cur.execute(f"""
  SELECT {a} + {b} + {c}, "{s}"
"""):
  print(*i)

ABC086A - Product

2つの整数をかけた結果が偶数かどうかを答えます。
条件分岐は CASE式を使って CASE WHEN 条件 THEN 値 CASE 条件2 THEN 値2 略 ELSE その他の場合の値 ENDのように書けます

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")



a, b = map(int, input().split())

for i in cur.execute(f"SELECT CASE WHEN {a} * {b} % 2 = 0 THEN \"Even\" ELSE \"Odd\" END"):
  print(i[0])

ABC081A - Placing Marbles

0と1で構成された3文字の文字列に含まれる1の個数を求めれば良いです
いろいろとやりかたはあると思いますが、REPLACEで0を消して1だけになった文字列の長さを返しました

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")



s = input()

for i in cur.execute(f"SELECT LENGTH(REPLACE({s}, \"0\", \"\"))"):
  print(i[0])

ABC081B - Shift only

与えられた複数の整数をすべて割り切れる最大の \(2^k\) の \(k\) を回答します。
CASE式を使って下からi+1ビット目が1ならiを返してそうでないなら適当な大きな値を返す判定を作ります。
面倒なので各ビットについてPython側で生成するとこんな感じです。

cases = ",\n".join([f"CASE WHEN (1 << {i}) & n THEN {i} ELSE 10000 END" for i in range(0, 30)])

ある数について必要なすべてのビットに対して上の判定をして複数引数の場合のMIN関数で最小値を取れば、その数が割り切れる箇所がわかります。
Built-In Scalar SQL Functions
あとはその結果について集計関数の方のMIN関数ですべての数について集計しなおせばすべての整数が割り切れる値がわかります

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;

  CREATE TABLE num(n INTEGER);
""")



N = input()
A = map(int, input().split())

cur.executemany("INSERT INTO num VALUES(?)", map(lambda x: (x,), A))

cases = ",\n".join([f"CASE WHEN (1 << {i}) & n THEN {i} ELSE 10000 END" for i in range(0, 30)])
for i in cur.execute(f"""
SELECT
  MIN(
    MIN(
      {cases}
    )
  )
FROM num
"""):
  print(i[0])

ABC087B - Coins

500円玉がA枚、100円玉がB枚、50円玉がC枚あるときに、合計X円にする方法を答えます
500円玉が0からA枚、100円玉が0からB枚、50円玉が0からC枚ある時の組み合わせの総当りを試して判定すればよいです。
SQLだとfor文がないので、0, 1, 2, ..., Aなどの数列を生成してからJOINして計算します。

SQLiteだと整数列の生成をしてくれるような関数がなさそうなので、WITH RECURSIVEを使って整数列を生成します。
WITH RECURSIVEは最初にデータを生成して、それ以降は生成されたデータに対してある処理をして、更に生成された同じ処理をして、というのを繰り返してデータを生成してくれます。
例えば以下の例だと、最初に0の行を生成して、それに+1して1の行が生成されて、1に対して+1して2の行が生成されて……という感じでデータを生成してくれます。WHEREの条件に引っかかって次の行が生成されなくなったら終了します。

WITH RECURSIVE coin500(i) AS (
  SELECT 0
  UNION
  SELECT i + 1 FROM coin500 WHERE i + 1 <= {A}
)

全体のコード

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")



A = input()
B = input()
C = input()
X = input()

for i in cur.execute(f"""
WITH RECURSIVE coin500(i) AS (
  SELECT 0
  UNION
  SELECT i + 1 FROM coin500 WHERE i + 1 <= {A}
),  coin100(i) AS (
  SELECT 0
  UNION
  SELECT i + 1 FROM coin100 WHERE i + 1 <= {B}
),  coin50(i) AS (
  SELECT 0
  UNION
  SELECT i + 1 FROM coin50 WHERE i + 1 <= {C}
)

SELECT
  COUNT(*)
FROM coin500, coin100, coin50
WHERE coin500.i * 500 + coin100.i * 100 + coin50.i * 50 = {X}
"""):
  print(i[0])

ABC083B - Some Sums

1以上N以下の整数のうち10進法での各桁の数字の合計がA以上B以下である数の総和を答えます。
10000, 1000, 100, 10, 1で割った商を10で割った余りを求めれば各桁の数字が求められます。
あとは1からNまでの値を WITH RECURSIVE で生成して条件を満たす値の合計を計算するだけです。

%%bash
echo '
import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")



N, A, B = map(int, input().split())

for i in cur.execute(f"""
WITH RECURSIVE seq(n) AS (
  SELECT 1
  UNION
  SELECT n + 1 FROM seq WHERE n + 1 <= {N}
)

SELECT
  SUM(n)
FROM seq
WHERE CAST(n / 10000 AS INTEGER) % 10 + CAST(n / 1000 AS INTEGER) % 10 + CAST(n / 100 AS INTEGER) % 10 + CAST(n / 10 AS INTEGER) % 10 + CAST(n / 1 AS INTEGER) % 10
  BETWEEN {A} AND {B}
"""):
  print(i[0])
' > main.py

echo "100 4 16" > input.txt
cat input.txt | python main.py

ABC088B - Card Game for Two

与えられた整数を大きい順にソートして、最大のものから順に1番目、2番め…としたときに、奇数番目の合計から偶数番目の合計を引いて答えればよいです。
以下のようにウィンドウ関数が使えれば簡単に求められるのですがSQLiteのバージョンが古いのか、どこか書き間違えたのか使えなかったです。

SELECT
  SUM(CASE WHEN row_num % 2 = 1 THEN n ELSE 0 END) - SUM(CASE WHEN row_num % 2 = 0 THEN n ELSE 0 END)
FROM (
  SELECT
    n, ROW_NUMBER() OVER (ORDER BY n DESC) AS row_num
  FROM num
)

なので一度テーブルに保存してから計算しました。SQLiteではデフォルトでは、テーブルに行を保存するとrowidという1から順のIDが暗黙的に振られるので、ソート済みの値をテーブルに保存してrowidの偶奇を見て計算します。

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;

  CREATE TABLE num(n INTEGER);
""")



N = input()
A = map(int, input().split())

cur.executemany("INSERT INTO num VALUES(?)", map(lambda x: (x,), A))

cur.execute(f"""
  CREATE TEMPORARY TABLE ordered AS
    SELECT
      n
    FROM num
    ORDER BY n DESC
""")

for i in cur.execute(f"""
  SELECT (SELECT SUM(n) FROM ordered WHERE rowid % 2 = 1) - (SELECT SUM(n) FROM ordered WHERE rowid % 2 = 0)
"""):
  print(i[0])

ABC085B - Kagami Mochi

与えられた複数の整数のうちユニークな個数を答えればよいです。
SQLでは COUNT(DISTINCT) するだけなので簡単です

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;

  CREATE TABLE num(n INTEGER);
""")



N = int(input())
D = []
for i in range(N):
  D.append(input())

cur.executemany("INSERT INTO num VALUES(?)", map(lambda x: (x,), D))

for i in cur.execute(f"""
  SELECT COUNT(DISTINCT n) FROM num
"""):
  print(i[0])

ABC085C - Otoshidama

10000円札、5000円札、1000円札を合計N枚使って合計Y円になる組み合わせを一つ答えます。
10000円札、5000円札の枚数候補を生成してJOINして、残り金額を1000円札で補った場合に条件を満たすか判定。

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")



N, Y = map(int, input().split())

result = list(cur.execute(f"""
WITH RECURSIVE money10000(i) AS (
  SELECT 0
  UNION
  SELECT i + 1 FROM money10000 WHERE i + 1 <= {Y} / 10000 
),  money5000(i) AS (
  SELECT 0
  UNION
  SELECT i + 1 FROM money5000 WHERE i + 1 <= {Y} / 5000
)

SELECT
  money10000.i, money5000.i, ({Y} - money10000.i * 10000 - money5000.i * 5000) / 1000
FROM money10000, money5000
WHERE ({Y} - money10000.i * 10000 - money5000.i * 5000) % 1000 = 0
AND ({Y} - money10000.i * 10000 - money5000.i * 5000) / 1000 >= 0
AND money10000.i + money5000.i + ({Y} - money10000.i * 10000 - money5000.i * 5000) / 1000 = {N}
"""))

if result:
  print(*result[0])
else:
  print(-1, -1, -1)

ABC049C - Daydream

与えられた文字列の末尾から貪欲に"dream", "dreamer", "erase", "eraser"の文字列を可能な限り取り除いていって、最終的に空文字列になるかどうかを判定します。
REGEXPを使って正規表現で判定すれば一瞬だと思ったのですが、SQLiteのバージョンが古くて使えなかったです。

WITH RECURSIVEを使って、末尾から文字列を取り除いた文字列を順に生成していくクエリを書いたのですが、これはTLEになってしまいました。

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")



S = input()

for i in cur.execute(f"""
  WITH RECURSIVE str(s) AS (
    SELECT "{S}"
    UNION
    SELECT
      CASE
        WHEN substr(s, -5) = "dream" THEN substr(s, 1, length(s) - 5)
        WHEN substr(s, -7) = "dreamer" THEN substr(s, 1, length(s) - 7)
        WHEN substr(s, -5) = "erase" THEN substr(s, 1, length(s) - 5)
        WHEN substr(s, -6) = "eraser" THEN substr(s, 1, length(s) - 6)
      END
    FROM str
    WHERE s IS NOT NULL
  )

  SELECT
    CASE
      WHEN COUNT(*) = 1 THEN "YES"
      ELSE "NO"
    END
  FROM str
  WHERE s = ""
"""):
  print(i[0])

最終手段としてSQLですべてやるのは諦めて正規表現はユーザー定義関数で判定することにしました。
SQLiteではPython側で定義した関数をSQLから呼び出すことができます。

以下のようにPythonの関数をcreate_functionで渡せばSQL中から使うことができます

def re_match(s: str):
  return re.match("^(dream|dreamer|erase|eraser)+$", s) is not None

con.create_function("RE_MATCH", 1, re_match)

S = input()
cur.execute(f"""
  SELECT
    CASE
      WHEN RE_MATCH("{S}") = 1 THEN "YES"
      ELSE "NO"
    END
""")

全体のコード

import sqlite3
import re

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;
""")

# SQLiteのバージョンが新しければREGEXPで正規表現が使えるはずだけど使えないので妥協
def re_match(s: str):
  return re.match("^(dream|dreamer|erase|eraser)+$", s) is not None

con.create_function("RE_MATCH", 1, re_match)

S = input()
for i in cur.execute(f"""
  SELECT
    CASE
      WHEN RE_MATCH("{S}") = 1 THEN "YES"
      ELSE "NO"
    END
"""):
  print(i[0])

ABC086C - Traveling

時刻が1進むごとにx軸方向かy軸方向に1移動する。
x, y = 0, 0からスタートするとして、時刻とそのときのx, y座標の組み合わせが複数与えられるので最後まで移動可能かを答える。

時刻の差がマンハッタン距離よりも大きいことと、時刻の差の偶奇とマンハッタン距離の偶奇が一致していればよい(同じ箇所で前後移動を繰り返せば2ずつ消費できる)

import sqlite3

con = sqlite3.connect(":memory:", isolation_level=None)
cur = con.cursor()
cur.executescript("""
  PRAGMA trusted_schema = OFF;
  PRAGMA journal_mode = OFF;
  PRAGMA synchronous = OFF;
  PRAGMA temp_store = memory;
  PRAGMA secure_delete = OFF;

  CREATE TABLE points(t INTEGER, x INTEGER, y INTEGER);
""")

N = int(input())
data = []
for i in range(N):
  data.append(input().split())


cur.execute("INSERT INTO points VALUES(?, ?, ?)", (0, 0, 0))
cur.executemany("INSERT INTO points VALUES(?, ?, ?)", data)

for i in cur.execute(f"""
  SELECT
    CASE
      WHEN
        (
          SELECT
            COUNT(*)
          FROM points AS current
          JOIN points AS next
          ON current.rowid = next.rowid - 1
          WHERE
            NOT (
              ABS(current.x - next.x) + ABS(current.y - next.y) <= (next.t - current.t)
              AND (ABS(current.x - next.x) + ABS(current.y - next.y)) % 2 = (next.t - current.t) % 2
            )
        ) > 0
        THEN "No"
      ELSE "Yes"
    END
"""):
  print(i[0])

まとめ

サクッとできるかと思いましたが結構つらかったです。やはりforループ的なものを気軽にかけないのはなかなか大変。
あとはSQLiteのバージョンが古くてウィンドウ関数や正規表現が使えなかったのも苦労しました。

SQLiteに触るのはほとんど初めてだったのですが構文の説明などのドキュメントが難しかったです
www.sqlite.org
SQLiteのロゴ