clojureの再帰とスタックオーバーフロー(remove-first)

たまたまcljureのMLを見ていた時に、nickiktさんという人が

「"scheme-remove-first"という関数を
書いたのだけれど、自分の実装ではstack overflowしてしまうので気に入らない。
誰かだれかもっと良いバージョンを書いてみてくれ

https://groups.google.com/group/clojure/browse_thread/thread/9e400d1381b11665?pli=1

というようなことを書いていたメールを発見しました。
scheme-remove-firstでやりたいことはこんな感じです。

(remove-first 2 '(1 2 3 1 2 3 1 2 3)) ; => (1 3 1 2 3 1 2 3)

名前の通りですね。

実際にnickiktさんが書いた関数はこんな感じでした。

remove-first

(defn remove-first [syb lst]
 (if (empty? lst)
   '()
   (if (= (first lst) syb)
     (rest lst)
     (cons (first lst) (remove-first syb (rest lst))))))

これは明らかによくありませんね。
consとrestでschemeのように再帰させていくかたちですが、スタックを食いつぶしてしまいます。
とりあえずは実装の方法を変えないままにコードを見易くしていきましょう。

clojureのdestructuring(分配束縛)を使って一気に複数の値を束縛できます。これを使いましょう。
ついでに、eとcollという風に引数の名前も変えました。

destructuringを使う。

(let [[x & xs] [1 2 3]]
     (list x xs)) ; => (1 (2 3))

(defn remove-first2 [e coll] 
 (if (empty? coll)
   '()
   (let [[x & coll*] coll]
   (if (= e x) 
     coll*
     (cons x (remove-first2 e coll*))))))

実はclojureにはリストの残りの値を取り出す関数が2つあります。nextとrestです。
違いは、残りの要素が空の場合の結果です。nextはnilを返してrestは空リストを返します。

(next [1]) ; => nil
(rest [1]) ; => ()

もちろん、nilは偽としてあつかわれますね?
ついでに、when-letマクロも使ってみましょう。
これはletによる値の束縛と条件分岐を同時に行なうマクロです。

(when-let [[x & xs] nil]
  (println "success!"))
;; nil

このwhen-letとdestructuringを組み合わせることでもっと簡単に書けるようになります。
(もちろん、今のままではstack overflowすることには変わりません)

when-letを使う

(defn remove-first3 [e coll] 
  (when-let [[x & coll*] coll]
    (if (= e x) coll* (cons x (remove-first3 e coll*)))))

記述がコンパクトになりました。ifの部分を1行にしてしまっても問題ないですね。
わかりやすい記述ができたところで、そろそろ本題のstack overflowをなくしていきましょう。
遅延シークエンスを作るようにして途中で計算を打ちきれるようにします。

stack overflowをなくす

遅延シークエンスを作るにはlazy-seqを使いましょう。
pythonjavascriptを使っている人にとってのyieldみたいな感じです。

(defn remove-first* [e coll]
  (lazy-seq
   (when-let [[x & xs] coll]
     (if (= e x) xs (cons x (remove-first* e xs))))))

たしかにきれいに短く書けたのですが、clojureの文法を使って短くできただけで
結局のところremove-firstの処理を1から全て実装してしまっています。もっと楽はできないでしょうか?

clojureの標準ライブラリに詳しい人は、「そんなことしなくてもclojure.core/split-withを使えばもっと簡単に書けるよ」
と言うかもしれません。こんな感じですね。

(defn remove-first+ [e coll]
  (let [[left right] (split-with #(not (= e %)) coll)]
    (lazy-cat left (next right))))

確かにsplit-withを使うことで簡単に書けます。しかし問題があります。遅いんです。

(let [xs (range 1 100000)]
  ;;(time (dotimes [i 1000] (remove-first3 99999 xs))) ;;もちろんStackOverflow
    (time (dotimes [i 1000] (remove-first+ 99999 xs)))
    (time (dotimes [i 1000] (remove-first* 99999 xs))))

;; "Elapsed time: 4.098972 msecs"
;; "Elapsed time: 0.818267 msecs"

自分の手で書いた関数に比べてだいぶ遅いですね。
安易なsplit-withの利用は避けるべきです。
どうにかできないでしょうか?

split-withが遅延シークエンスを返さないから問題なのです。
split-withに近い手続きが書けたなら、
split-withを使った時のように関数を使って楽に書くことができそうです。

結局のところ2つに分割した内の残りの部分をひとつ取り除ければ良いのです。
split-withのように条件式をとって、それに対応しなくなったところで止まる手続きを考えましょう。
CPS風に、条件に合致しなくなったところで呼ぶ関数contも一緒にとってあげることにします。

split-with-cont

(defn split-with-cont [p coll cont]
  (lazy-seq 
   (when-let [[x & xs] coll]
     (if (not (p x)) (cont coll) (cons x (split-with-cont p xs cont))))))

ほとんどremove-first+と変わらないような定義ですね。
これを利用して速い関数を作ることができれば、
clojureの関数と文法を使ってより良いremove-firstが作れたと言えるような気がします。

(defn remove-first+2 [e coll]
  (split-with-cont #(not (= e %)) coll next))

nextを渡してあげれば良いわけです。
ただし、split-with-contはsplit-withのように過ぎ去った過去*1をみることができません。

あとは実行時間の比較です。

(let [xs (range 1 100000)]
  (time (dotimes [i 1000] (remove-first+ 99999 xs)))
  (time (dotimes [i 1000] (remove-first+2 99999 xs)))
  (time (dotimes [i 1000] (remove-first* 99999 xs))))

;; "Elapsed time: 4.665175 msecs"
;; "Elapsed time: 0.850143 msecs"
;; "Elapsed time: 0.758958 msecs"

少しおそくなったものの結構良い速度なんじゃないでしょうか?

*1:分配束縛した際の左側のリスト