以下のようなツイートを見かけて、AtCoderでSQLが使えることがわかったので試しに AtCoder Beginners Selection の問題を可能な限りSQLを使って解いてみました。できるだけ標準入出力をPython側で行って、その他の計算をSQLite側で行います
昨日のD問題でPythonだと解けないと嘆いている皆さん、
— zat (@zat22859390) 2021年9月5日
Pythonの特権、SQLで解くという技もありますよ。
AtCoderでSQLが使えるのはPython/PyPyだけ!https://t.co/S3aCbGNhti
zenn.dev
PracticeA - Welcome to AtCoder
標準入出力から読み取った数値を足して、文字列はそのまま出力します。
SQLクエリを作るときに、変数を文字列結合でくっつけてくのは一般的にはあまりよくない気もしますが、今回の用途では特に問題ないと思うのでこの記事中では気にせず使っていきます。
import sqlite3 con = sqlite3.connect(":memory:", isolation_level=None) cur = con.cursor() cur.executescript(""" PRAGMA trusted_schema = OFF; PRAGMA journal_mode = OFF; PRAGMA synchronous = OFF; PRAGMA temp_store = memory; PRAGMA secure_delete = OFF; """) a = input() b, c = input().split() s = input() for i in cur.execute(f""" SELECT {a} + {b} + {c}, "{s}" """): print(*i)
ABC086A - Product
2つの整数をかけた結果が偶数かどうかを答えます。
条件分岐は CASE式を使って CASE WHEN 条件 THEN 値 CASE 条件2 THEN 値2 略 ELSE その他の場合の値 END
のように書けます
import sqlite3 con = sqlite3.connect(":memory:", isolation_level=None) cur = con.cursor() cur.executescript(""" PRAGMA trusted_schema = OFF; PRAGMA journal_mode = OFF; PRAGMA synchronous = OFF; PRAGMA temp_store = memory; PRAGMA secure_delete = OFF; """) a, b = map(int, input().split()) for i in cur.execute(f"SELECT CASE WHEN {a} * {b} % 2 = 0 THEN \"Even\" ELSE \"Odd\" END"): print(i[0])
ABC081A - Placing Marbles
0と1で構成された3文字の文字列に含まれる1の個数を求めれば良いです
いろいろとやりかたはあると思いますが、REPLACEで0を消して1だけになった文字列の長さを返しました
import sqlite3 con = sqlite3.connect(":memory:", isolation_level=None) cur = con.cursor() cur.executescript(""" PRAGMA trusted_schema = OFF; PRAGMA journal_mode = OFF; PRAGMA synchronous = OFF; PRAGMA temp_store = memory; PRAGMA secure_delete = OFF; """) s = input() for i in cur.execute(f"SELECT LENGTH(REPLACE({s}, \"0\", \"\"))"): print(i[0])
ABC081B - Shift only
与えられた複数の整数をすべて割り切れる最大の \(2^k\) の \(k\) を回答します。
CASE式を使って下からi+1ビット目が1ならiを返してそうでないなら適当な大きな値を返す判定を作ります。
面倒なので各ビットについてPython側で生成するとこんな感じです。
cases = ",\n".join([f"CASE WHEN (1 << {i}) & n THEN {i} ELSE 10000 END" for i in range(0, 30)])
ある数について必要なすべてのビットに対して上の判定をして複数引数の場合のMIN関数で最小値を取れば、その数が割り切れる箇所がわかります。
Built-In Scalar SQL Functions
あとはその結果について集計関数の方のMIN関数ですべての数について集計しなおせばすべての整数が割り切れる値がわかります
import sqlite3 con = sqlite3.connect(":memory:", isolation_level=None) cur = con.cursor() cur.executescript(""" PRAGMA trusted_schema = OFF; PRAGMA journal_mode = OFF; PRAGMA synchronous = OFF; PRAGMA temp_store = memory; PRAGMA secure_delete = OFF; CREATE TABLE num(n INTEGER); """) N = input() A = map(int, input().split()) cur.executemany("INSERT INTO num VALUES(?)", map(lambda x: (x,), A)) cases = ",\n".join([f"CASE WHEN (1 << {i}) & n THEN {i} ELSE 10000 END" for i in range(0, 30)]) for i in cur.execute(f""" SELECT MIN( MIN( {cases} ) ) FROM num """): print(i[0])
ABC087B - Coins
500円玉がA枚、100円玉がB枚、50円玉がC枚あるときに、合計X円にする方法を答えます
500円玉が0からA枚、100円玉が0からB枚、50円玉が0からC枚ある時の組み合わせの総当りを試して判定すればよいです。
SQLだとfor文がないので、0, 1, 2, ..., Aなどの数列を生成してからJOINして計算します。
SQLiteだと整数列の生成をしてくれるような関数がなさそうなので、WITH RECURSIVE
を使って整数列を生成します。
WITH RECURSIVE
は最初にデータを生成して、それ以降は生成されたデータに対してある処理をして、更に生成された同じ処理をして、というのを繰り返してデータを生成してくれます。
例えば以下の例だと、最初に0の行を生成して、それに+1して1の行が生成されて、1に対して+1して2の行が生成されて……という感じでデータを生成してくれます。WHEREの条件に引っかかって次の行が生成されなくなったら終了します。
WITH RECURSIVE coin500(i) AS ( SELECT 0 UNION SELECT i + 1 FROM coin500 WHERE i + 1 <= {A} )
全体のコード
import sqlite3 con = sqlite3.connect(":memory:", isolation_level=None) cur = con.cursor() cur.executescript(""" PRAGMA trusted_schema = OFF; PRAGMA journal_mode = OFF; PRAGMA synchronous = OFF; PRAGMA temp_store = memory; PRAGMA secure_delete = OFF; """) A = input() B = input() C = input() X = input() for i in cur.execute(f""" WITH RECURSIVE coin500(i) AS ( SELECT 0 UNION SELECT i + 1 FROM coin500 WHERE i + 1 <= {A} ), coin100(i) AS ( SELECT 0 UNION SELECT i + 1 FROM coin100 WHERE i + 1 <= {B} ), coin50(i) AS ( SELECT 0 UNION SELECT i + 1 FROM coin50 WHERE i + 1 <= {C} ) SELECT COUNT(*) FROM coin500, coin100, coin50 WHERE coin500.i * 500 + coin100.i * 100 + coin50.i * 50 = {X} """): print(i[0])
ABC083B - Some Sums
1以上N以下の整数のうち10進法での各桁の数字の合計がA以上B以下である数の総和を答えます。
10000, 1000, 100, 10, 1で割った商を10で割った余りを求めれば各桁の数字が求められます。
あとは1からNまでの値を WITH RECURSIVE で生成して条件を満たす値の合計を計算するだけです。
%%bash echo ' import sqlite3 con = sqlite3.connect(":memory:", isolation_level=None) cur = con.cursor() cur.executescript(""" PRAGMA trusted_schema = OFF; PRAGMA journal_mode = OFF; PRAGMA synchronous = OFF; PRAGMA temp_store = memory; PRAGMA secure_delete = OFF; """) N, A, B = map(int, input().split()) for i in cur.execute(f""" WITH RECURSIVE seq(n) AS ( SELECT 1 UNION SELECT n + 1 FROM seq WHERE n + 1 <= {N} ) SELECT SUM(n) FROM seq WHERE CAST(n / 10000 AS INTEGER) % 10 + CAST(n / 1000 AS INTEGER) % 10 + CAST(n / 100 AS INTEGER) % 10 + CAST(n / 10 AS INTEGER) % 10 + CAST(n / 1 AS INTEGER) % 10 BETWEEN {A} AND {B} """): print(i[0]) ' > main.py echo "100 4 16" > input.txt cat input.txt | python main.py
ABC088B - Card Game for Two
与えられた整数を大きい順にソートして、最大のものから順に1番目、2番め…としたときに、奇数番目の合計から偶数番目の合計を引いて答えればよいです。
以下のようにウィンドウ関数が使えれば簡単に求められるのですがSQLiteのバージョンが古いのか、どこか書き間違えたのか使えなかったです。
SELECT SUM(CASE WHEN row_num % 2 = 1 THEN n ELSE 0 END) - SUM(CASE WHEN row_num % 2 = 0 THEN n ELSE 0 END) FROM ( SELECT n, ROW_NUMBER() OVER (ORDER BY n DESC) AS row_num FROM num )
なので一度テーブルに保存してから計算しました。SQLiteではデフォルトでは、テーブルに行を保存するとrowidという1から順のIDが暗黙的に振られるので、ソート済みの値をテーブルに保存してrowidの偶奇を見て計算します。
import sqlite3 con = sqlite3.connect(":memory:", isolation_level=None) cur = con.cursor() cur.executescript(""" PRAGMA trusted_schema = OFF; PRAGMA journal_mode = OFF; PRAGMA synchronous = OFF; PRAGMA temp_store = memory; PRAGMA secure_delete = OFF; CREATE TABLE num(n INTEGER); """) N = input() A = map(int, input().split()) cur.executemany("INSERT INTO num VALUES(?)", map(lambda x: (x,), A)) cur.execute(f""" CREATE TEMPORARY TABLE ordered AS SELECT n FROM num ORDER BY n DESC """) for i in cur.execute(f""" SELECT (SELECT SUM(n) FROM ordered WHERE rowid % 2 = 1) - (SELECT SUM(n) FROM ordered WHERE rowid % 2 = 0) """): print(i[0])
ABC085B - Kagami Mochi
与えられた複数の整数のうちユニークな個数を答えればよいです。
SQLでは COUNT(DISTINCT) するだけなので簡単です
import sqlite3 con = sqlite3.connect(":memory:", isolation_level=None) cur = con.cursor() cur.executescript(""" PRAGMA trusted_schema = OFF; PRAGMA journal_mode = OFF; PRAGMA synchronous = OFF; PRAGMA temp_store = memory; PRAGMA secure_delete = OFF; CREATE TABLE num(n INTEGER); """) N = int(input()) D = [] for i in range(N): D.append(input()) cur.executemany("INSERT INTO num VALUES(?)", map(lambda x: (x,), D)) for i in cur.execute(f""" SELECT COUNT(DISTINCT n) FROM num """): print(i[0])
ABC085C - Otoshidama
10000円札、5000円札、1000円札を合計N枚使って合計Y円になる組み合わせを一つ答えます。
10000円札、5000円札の枚数候補を生成してJOINして、残り金額を1000円札で補った場合に条件を満たすか判定。
import sqlite3 con = sqlite3.connect(":memory:", isolation_level=None) cur = con.cursor() cur.executescript(""" PRAGMA trusted_schema = OFF; PRAGMA journal_mode = OFF; PRAGMA synchronous = OFF; PRAGMA temp_store = memory; PRAGMA secure_delete = OFF; """) N, Y = map(int, input().split()) result = list(cur.execute(f""" WITH RECURSIVE money10000(i) AS ( SELECT 0 UNION SELECT i + 1 FROM money10000 WHERE i + 1 <= {Y} / 10000 ), money5000(i) AS ( SELECT 0 UNION SELECT i + 1 FROM money5000 WHERE i + 1 <= {Y} / 5000 ) SELECT money10000.i, money5000.i, ({Y} - money10000.i * 10000 - money5000.i * 5000) / 1000 FROM money10000, money5000 WHERE ({Y} - money10000.i * 10000 - money5000.i * 5000) % 1000 = 0 AND ({Y} - money10000.i * 10000 - money5000.i * 5000) / 1000 >= 0 AND money10000.i + money5000.i + ({Y} - money10000.i * 10000 - money5000.i * 5000) / 1000 = {N} """)) if result: print(*result[0]) else: print(-1, -1, -1)
ABC049C - Daydream
与えられた文字列の末尾から貪欲に"dream", "dreamer", "erase", "eraser"の文字列を可能な限り取り除いていって、最終的に空文字列になるかどうかを判定します。
REGEXPを使って正規表現で判定すれば一瞬だと思ったのですが、SQLiteのバージョンが古くて使えなかったです。
WITH RECURSIVEを使って、末尾から文字列を取り除いた文字列を順に生成していくクエリを書いたのですが、これはTLEになってしまいました。
import sqlite3 con = sqlite3.connect(":memory:", isolation_level=None) cur = con.cursor() cur.executescript(""" PRAGMA trusted_schema = OFF; PRAGMA journal_mode = OFF; PRAGMA synchronous = OFF; PRAGMA temp_store = memory; PRAGMA secure_delete = OFF; """) S = input() for i in cur.execute(f""" WITH RECURSIVE str(s) AS ( SELECT "{S}" UNION SELECT CASE WHEN substr(s, -5) = "dream" THEN substr(s, 1, length(s) - 5) WHEN substr(s, -7) = "dreamer" THEN substr(s, 1, length(s) - 7) WHEN substr(s, -5) = "erase" THEN substr(s, 1, length(s) - 5) WHEN substr(s, -6) = "eraser" THEN substr(s, 1, length(s) - 6) END FROM str WHERE s IS NOT NULL ) SELECT CASE WHEN COUNT(*) = 1 THEN "YES" ELSE "NO" END FROM str WHERE s = "" """): print(i[0])
最終手段としてSQLですべてやるのは諦めて正規表現はユーザー定義関数で判定することにしました。
SQLiteではPython側で定義した関数をSQLから呼び出すことができます。
以下のようにPythonの関数をcreate_functionで渡せばSQL中から使うことができます
def re_match(s: str): return re.match("^(dream|dreamer|erase|eraser)+$", s) is not None con.create_function("RE_MATCH", 1, re_match) S = input() cur.execute(f""" SELECT CASE WHEN RE_MATCH("{S}") = 1 THEN "YES" ELSE "NO" END """)
全体のコード
import sqlite3 import re con = sqlite3.connect(":memory:", isolation_level=None) cur = con.cursor() cur.executescript(""" PRAGMA trusted_schema = OFF; PRAGMA journal_mode = OFF; PRAGMA synchronous = OFF; PRAGMA temp_store = memory; PRAGMA secure_delete = OFF; """) # SQLiteのバージョンが新しければREGEXPで正規表現が使えるはずだけど使えないので妥協 def re_match(s: str): return re.match("^(dream|dreamer|erase|eraser)+$", s) is not None con.create_function("RE_MATCH", 1, re_match) S = input() for i in cur.execute(f""" SELECT CASE WHEN RE_MATCH("{S}") = 1 THEN "YES" ELSE "NO" END """): print(i[0])
ABC086C - Traveling
時刻が1進むごとにx軸方向かy軸方向に1移動する。
x, y = 0, 0からスタートするとして、時刻とそのときのx, y座標の組み合わせが複数与えられるので最後まで移動可能かを答える。
時刻の差がマンハッタン距離よりも大きいことと、時刻の差の偶奇とマンハッタン距離の偶奇が一致していればよい(同じ箇所で前後移動を繰り返せば2ずつ消費できる)
import sqlite3 con = sqlite3.connect(":memory:", isolation_level=None) cur = con.cursor() cur.executescript(""" PRAGMA trusted_schema = OFF; PRAGMA journal_mode = OFF; PRAGMA synchronous = OFF; PRAGMA temp_store = memory; PRAGMA secure_delete = OFF; CREATE TABLE points(t INTEGER, x INTEGER, y INTEGER); """) N = int(input()) data = [] for i in range(N): data.append(input().split()) cur.execute("INSERT INTO points VALUES(?, ?, ?)", (0, 0, 0)) cur.executemany("INSERT INTO points VALUES(?, ?, ?)", data) for i in cur.execute(f""" SELECT CASE WHEN ( SELECT COUNT(*) FROM points AS current JOIN points AS next ON current.rowid = next.rowid - 1 WHERE NOT ( ABS(current.x - next.x) + ABS(current.y - next.y) <= (next.t - current.t) AND (ABS(current.x - next.x) + ABS(current.y - next.y)) % 2 = (next.t - current.t) % 2 ) ) > 0 THEN "No" ELSE "Yes" END """): print(i[0])
まとめ
サクッとできるかと思いましたが結構つらかったです。やはりforループ的なものを気軽にかけないのはなかなか大変。
あとはSQLiteのバージョンが古くてウィンドウ関数や正規表現が使えなかったのも苦労しました。
SQLiteに触るのはほとんど初めてだったのですが構文の説明などのドキュメントが難しかったです
www.sqlite.org