プログラミングを上達させたい

情報学専攻の大学院→放送局でCMの営業など@大阪→舞台俳優&IT営業@東京

LCAとダブリング(ABC014D問題) with Pascal

ずっとつまづいていたAtCoderBeginnerContest014の問題Dが解けました。のと、それに伴い久々にPascalでガッツリコードを書いたので、メモ。


この問題は、
1.まず隣接リストに情報を入れていく
2.それを基に根付き木を作る(根はどれにしてもよい。今回は一番番号が小さいノードを根としています)
3.ダブリング
4.LCAを求めて、答えを出す
というアルゴリズムで解けます。
上のアルゴリズムを、解説を見ながらJavaで書いてみたものの、TLEになってしまいました。
他のJavaでの正解者を見ても、結構2000msギリギリになっている方が多かったです。
そこで(Javaの高速化とか分からないから)、速い言語で書いちゃえ、という方針に変えました。
しかし、僕が使えるのはPascalとSchemeとJavaのみです。C++は使わないスタイルで生きていきたいので、多分挑戦しません。
SchemeもAtCoderで使った感じでは大分遅いので、Pascalで挑戦することにしました。

とりあえず説明のため、まずはJavaで書いてTLEになったコードを書きます。

import java.util.*;
 
class oneNodeList extends LinkedList<Integer>{
}
 
public class Main {
  public static void main(String[] args){
    Scanner sc = new Scanner(System.in);
 
    int n = sc.nextInt();
    oneNodeList[] rinsetsu = new oneNodeList[n];
    for(int i = 0; i < n; i++){
      rinsetsu[i] = new oneNodeList();
    }
    for(int i = 1; i < n; i++){
      int tan1 = sc.nextInt() - 1;
      int tan2 = sc.nextInt() - 1;
      rinsetsu[tan1].add(tan2);
      rinsetsu[tan2].add(tan1);
    }
 
 
    int[] treedepth = new int[n];
    int[] treeparent = new int[n];
    LinkedList<Integer> indexqueue = new LinkedList<Integer>();
    LinkedList<Integer> depthqueue = new LinkedList<Integer>();
    boolean[] check = new boolean[n];
    for(int i = 1; i < n; i++){
      check[i] = false;
    }
    check[0] = true;
    indexqueue.add(0);
    depthqueue.add(0);
    treedepth[0] = 0;
    treeparent[0] = -1;
    int nowdep;
    int nowind;
    int newind;
    while(!indexqueue.isEmpty()){
      nowdep = depthqueue.poll();
      nowind = indexqueue.poll();
      while(!rinsetsu[nowind].isEmpty()){
        newind = rinsetsu[nowind].poll();
        if(!check[newind]){
          check[newind] = true;
          depthqueue.add(nowdep + 1);
          indexqueue.add(newind);
          treedepth[newind] = nowdep + 1;
          treeparent[newind] = nowind;
        }
      }
    }
 
    int[][] doubling = new int[n][20];
    //doubling[i][j]には、ノードiから2^j分遡ったノードの番号が入る
    for(int i = 0; i < n; i++){
      doubling[i][0] = treeparent[i];
    }
    for(int j = 1; j < 20; j++){
      for(int i = 0; i < n; i++){
        if(doubling[i][j-1] < 0){
          doubling[i][j] = -1;
        }else{
          doubling[i][j] = doubling[doubling[i][j-1]][j-1];
        }
      }
    }
 
    int q = sc.nextInt();
    int katahou, moukata, kata1, kata2, sa, upcount, counter, lcaind;
    for(int casenum = 0; casenum < q; casenum++){
      katahou = sc.nextInt() - 1;
      moukata = sc.nextInt() - 1;
      kata1 = katahou;
      kata2 = moukata;
      //lca を求める
      //まず、深さを同じにする
      if(treedepth[kata1] < treedepth[kata2]){
        int keeping = kata2;
        kata2 = kata1;
        kata1 = keeping;
      }
      if(treedepth[kata1] > treedepth[kata2]){
        while(treedepth[kata1] > treedepth[kata2]){
          if(treedepth[kata1] - treedepth[kata2] == 1){
            kata1 = treeparent[kata1];
            break;
          }else{
            upcount = 1;
            counter = 0;
            sa = treedepth[kata1] - treedepth[kata2];
            while(upcount <= sa){
              upcount *= 2;
              counter++;
            }
            kata1 = doubling[kata1][counter - 1];
          }
        }
      }
      if(treedepth[kata1] != treedepth[kata2]){System.out.println("Error!!!");}
 
      //lcaを求める
      if(kata1 == kata2){
        lcaind = kata1;
      }else{
        while(treeparent[kata1] != treeparent[kata2]){
          upcount = 1;
          while(doubling[kata1][upcount] != doubling[kata2][upcount]){
            upcount++;
          }
          kata1 = doubling[kata1][upcount-1];
          kata2 = doubling[kata2][upcount-1];
        }
        if(treeparent[kata1] < 0){System.out.println("Error!!!");}
        lcaind = treeparent[kata1];
      }
      // System.out.println(lcaind);
      System.out.println(treedepth[katahou] + treedepth[moukata] - 2 * treedepth[lcaind] + 1);
 
    }
 
  }
}

なっが。長い。

ただ、これをPascalで書くとなると、隣接リストがまだ書けないのです。
というわけで、今回はそれを学びました。
とりあえず、別に隣接リストはキューでもスタックでもいいので、今回はスタックを実装しました。
参考にしたのは、コチラのページです。

スタックを実装するのには、record型を使います。要素と、次のリンクを指ししめるポインタ?でやるんですね。type宣言部にこのように書きます。

type
	link = ^data;
	data = record
		ele : Longint;
		next : link;
	end;

そして、こんな風にやっていきます。

(*stack、youso がリンク型の変数とする*)
(*リンク型は上で定義した通り*)

(*リンク型の変数を生成する*)
new(stack); new(youso);

(*5を先頭に追加する=push(5)*)
youso^.next := stack;
youso^.ele := 5;
stack := youso;

(*先頭を取り出して、aという変数に入れる=pop()*)
a := stack^.ele;
stack := stack^.next;

ちなみに、スタック単体に対してではなく、スタックの配列に対して操作をしたいときも、同じようにやります。つまり、

rinlist : array[1..100000] of link;

と宣言した上で、

new(rinlist[i]);
rinlist[i]^.next := nil;

で生成、pushやpopも同じように出来ます。

上のようにして色々新しい技術を使って書いたコードがコチラ。提出したらちゃんとACになりました。すごいぞPascal!

program solve(input, output);
 
type
	link = ^data;
	data = record
		ele : Longint;
		next : link;
	end;
 
function giri2(target : Longint):Longint;
var counter, nownum : Longint;
begin
	counter := -1;
	nownum := 1;
	while nownum <= target do
	begin
		counter := counter + 1;
		nownum := nownum * 2;
	end;
	giri2 := counter;
end;
 
var
	n, q, i, j, sa, golen, keeping, katahou, moukatahou, nowind, nowdep, newind, lca, tan1, tan2, kata1, kata2: Longint;
	youso1, youso2 : link;
	indlist, deplist, list1, list2 : link;
	rinlist : array[1..100000] of link;
	notcheck : array[1..100000] of Boolean;
	(*check は、すでに見たやつにfalseが入る*)
	netsukigi : array[1..100000, 1..2] of Longint;
	(*netsukigi の1の方には深さが、2の方には親の番号が入る*)
	doubling : array[1..100000, 0..18] of Longint;
	(*doubling[i][j]には、iから2^j分登ったところの番号が入る*)
	answers : array[1..100000] of Longint;
 
begin
	readln(n);
	for i := 1 to n do begin
		new(rinlist[i]);
		rinlist[i]^.next := nil;
		notcheck[i] := True;
	end;
	for i := 1 to n-1 do begin
		read(katahou); readln(moukatahou);
		new(youso1);
		youso1^.next := rinlist[katahou];
		youso1^.ele := moukatahou;
		rinlist[katahou] := youso1;
		new(youso2);
		youso2^.next := rinlist[moukatahou];
		youso2^.ele := katahou;
		rinlist[moukatahou] := youso2;
	end;
 
	(*番号1を根とします*)
	new(indlist); new(deplist);
	indlist^.next := nil; deplist^.next := nil;
	new(list1); new(list2);
	list1^.next := indlist;
	list1^.ele := 1;
	indlist := list1;
 
	list2^.next := deplist;
	list2^.ele := 0;
	deplist := list2;
 
	notcheck[1] := False;
	netsukigi[1,1] := 0; netsukigi[1,2] := -1;
 
	while (indlist^.next <> nil) do
	begin
		nowind := indlist^.ele;
		nowdep := deplist^.ele;
		indlist := indlist^.next;
		deplist := deplist^.next;
		while (rinlist[nowind]^.next <> nil) do
		begin
			newind := rinlist[nowind]^.ele;
			rinlist[nowind] := rinlist[nowind]^.next;
			if notcheck[newind] then begin
				notcheck[newind] := False;
				netsukigi[newind][1] := nowdep + 1;
				netsukigi[newind][2] := nowind;
				new(youso1); new(youso2);
				youso1^.next := indlist;
				youso1^.ele := newind;
				indlist := youso1;
				youso2^.next := deplist;
				youso2^.ele := nowdep + 1;
				deplist := youso2;
			end;
		end;
	end;
 
	for i := 1 to n do begin
		doubling[i][0] := netsukigi[i][2];
	end;
 
	for j := 1 to 18 do begin
		for i := 1 to n do begin
			if (doubling[i][j-1] < 0) then begin
				doubling[i][j] := -1;
			end
			else begin
				doubling[i][j] := doubling[doubling[i][j-1]][j-1];
			end;
		end;
	end;
 
	read(q);
	for i := 1 to q do begin
		read(tan1); readln(tan2);
		kata1 := tan1; kata2 := tan2;
		if netsukigi[kata1][1] < netsukigi[kata2][1] then begin
			keeping := kata1; kata1 := kata2; kata2 := keeping;
		end;
		(*kata1 の方がkata2 より深いところにある*)
		if netsukigi[kata1][1] > netsukigi[kata2][1] then begin
			while netsukigi[kata1][1] > netsukigi[kata2][1] do
			begin
				sa := netsukigi[kata1][1] - netsukigi[kata2][1];
				golen := giri2(sa);
				kata1 := doubling[kata1][golen];
			end;
		end;
 
		(*同じ深さに来ているハズ*)
		if kata1 = kata2 then begin
			lca := kata1
		end
		else begin
			while netsukigi[kata1][2] <> netsukigi[kata2][2] do
			begin
				golen := 0;
				while doubling[kata1][golen+1] <> doubling[kata2][golen+1] do
				begin
					golen := golen + 1;
				end;
				kata1 := doubling[kata1][golen]; kata2 := doubling[kata2][golen];
			end;
			lca := netsukigi[kata1][2];
		end;
 
		answers[i] := netsukigi[tan1][1] + netsukigi[tan2][1] - 2 * netsukigi[lca][1] + 1;
	end;
 
	for i := 1 to q do begin
		writeln(answers[i]);
	end;
	
 
end.

ちなみに、実行時間は356msでした。他の方のC++のコードより速かったりと、Pascalのすごさを見せてくれました。

また、上のコードでは、スタックに対するnewやpush、pop、isEmpty(JavaのLinkedListで普段僕が使うもので言うところのnew、push、poll、isEmpty)を毎回ちゃんと書いています。
せっかくなので、この4つの手続き(及び関数)を、procedureやfunctionの形でまとめたものも作ってみました。
また、それについて注意なのですが、

function f(inp : Integer): Integer;

と書くと、引数であるinpは、値だけもらう感じ?になるのですが、次のように

function f(var inp : Integer): Integer;

と書くと、inpはそのまま実際のinpが使われます。

だから、

function f(var inp : Integer): Integer;
begin
inp := inp + 1;
f := inp;

と書くと、実際に関数fの引数に与えた変数の値が、関数適用後に1増えてる、ということになります。
仮引数や実引数?値渡し?参照渡し?みたいな話ですよね。
詳しくは(面倒なので)調べてないのですが、とりあえずこのことは今まで分かっていなかったので、収穫です。

新しく付け加えたpopやらを使うと、stackという名前のスタックを作って、4を入れ、5を入れ、先頭(=5)を取り出して変数aに入れ、空かどうか判定する、というのは、次のように簡単に書けます。

create(stack);
push(stack, 4);
push(stack, 5);
a := pop(stack);
if notEmpty(stack) then writeln('stack is not empty');

スマートですね。よしよし。

さて、ではpushやpopも付け加えたバージョンを以下に貼って、今回は終わりにします。
これからもPascal使おうかなー。
以上です。

program solve(input, output);
 
type
	link = ^data;
	data = record
		ele : Longint;
		next : link;
	end;
 
function giri2(target : Longint):Longint;
var counter, nownum : Longint;
begin
	counter := -1;
	nownum := 1;
	while nownum <= target do
	begin
		counter := counter + 1;
		nownum := nownum * 2;
	end;
	giri2 := counter;
end;
 
procedure push(var list : link; element : Longint);
var newlist : link;
begin
	new(newlist);
	newlist^.next := list;
	newlist^.ele := element;
	list := newlist;
end;
 
function pop(var list : link):Longint;
var res : Longint;
begin
	res := list^.ele;
	list := list^.next;
	pop := res; 
end;
 
function notEmpty(list : link):Boolean;
begin
	if list^.next <> nil then begin
		notEmpty := True;
	end
	else begin
		notEmpty := False;
	end;
end;
 
procedure create(var list : link);
begin
	new(list);
	list^.next := nil;
end;
 
 
var
	n, q, i, j, sa, golen, keeping, katahou, moukatahou, nowind, nowdep, newind, lca, tan1, tan2, kata1, kata2: Longint;
	indlist, deplist : link;
	rinlist : array[1..100000] of link;
	notcheck : array[1..100000] of Boolean;
	(*check は、すでに見たやつにfalseが入る*)
	netsukigi : array[1..100000, 1..2] of Longint;
	(*netsukigi の1の方には深さが、2の方には親の番号が入る*)
	doubling : array[1..100000, 0..18] of Longint;
	(*doubling[i][j]には、iから2^j分登ったところの番号が入る*)
	answers : array[1..100000] of Longint;
 
begin
	readln(n);
	for i := 1 to n do begin
		create(rinlist[i]);
		notcheck[i] := True;
	end;
	for i := 1 to n-1 do begin
		read(katahou); readln(moukatahou);
		push(rinlist[katahou], moukatahou);
		push(rinlist[moukatahou], katahou);
	end;
 
	(*番号1を根とします*)
	create(indlist); create(deplist);
	push(indlist, 1); push(deplist, 0);
 
	notcheck[1] := False;
	netsukigi[1,1] := 0; netsukigi[1,2] := -1;
 
	while notEmpty(indlist) do
	begin
		nowind := pop(indlist);
		nowdep := pop(deplist);
		while notEmpty(rinlist[nowind]) do
		begin
			newind := pop(rinlist[nowind]);
			if notcheck[newind] then begin
				notcheck[newind] := False;
				netsukigi[newind][1] := nowdep + 1;
				netsukigi[newind][2] := nowind;
				push(indlist, newind); push(deplist, nowdep + 1);
			end;
		end;
	end;
 
	for i := 1 to n do begin
		doubling[i][0] := netsukigi[i][2];
	end;
 
	for j := 1 to 18 do begin
		for i := 1 to n do begin
			if (doubling[i][j-1] < 0) then begin
				doubling[i][j] := -1;
			end
			else begin
				doubling[i][j] := doubling[doubling[i][j-1]][j-1];
			end;
		end;
	end;
 
	read(q);
	for i := 1 to q do begin
		read(tan1); readln(tan2);
		kata1 := tan1; kata2 := tan2;
		if netsukigi[kata1][1] < netsukigi[kata2][1] then begin
			keeping := kata1; kata1 := kata2; kata2 := keeping;
		end;
		(*kata1 の方がkata2 より深いところにある*)
		if netsukigi[kata1][1] > netsukigi[kata2][1] then begin
			while netsukigi[kata1][1] > netsukigi[kata2][1] do
			begin
				sa := netsukigi[kata1][1] - netsukigi[kata2][1];
				golen := giri2(sa);
				kata1 := doubling[kata1][golen];
			end;
		end;
 
		(*同じ深さに来ているハズ*)
		if kata1 = kata2 then begin
			lca := kata1
		end
		else begin
			while netsukigi[kata1][2] <> netsukigi[kata2][2] do
			begin
				golen := 0;
				while doubling[kata1][golen+1] <> doubling[kata2][golen+1] do
				begin
					golen := golen + 1;
				end;
				kata1 := doubling[kata1][golen]; kata2 := doubling[kata2][golen];
			end;
			lca := netsukigi[kata1][2];
		end;
 
		answers[i] := netsukigi[tan1][1] + netsukigi[tan2][1] - 2 * netsukigi[lca][1] + 1;
	end;
 
	for i := 1 to q do begin
		writeln(answers[i]);
	end;
	
 
end.