# 認知情報解析 実習問題

# 修正済み
# 時間がかかる場合は、繰り返し回数を減らしてください。
# 効率が悪いと思うので、適宜変更してください。

# Enumerate every assignment of {0 = empty, 1 = player 1, 2 = player 2} to the
# nine cells (3^9 rows). expand.grid varies the first factor fastest starting
# from 0, so row 1 is the empty board.
# (The original built this table twice; the second call contained a stray
# `= =` that prevented the script from parsing.)
temp.state = expand.grid(rep(list(0:2), 9))
names(temp.state) = paste0("loc", 1:9)

# Keep only boards whose mark counts are consistent with player 1 moving
# first: the number of 1s equals the number of 2s, or exceeds it by one.
n.ones = rowSums(temp.state == 1)
n.twos = rowSums(temp.state == 2)
omitTwo = which(n.ones < n.twos)
omitOne = which((n.ones - 1) > n.twos)
omitUniq = unique(c(omitOne, omitTwo))
# Positive indexing instead of `-omitUniq`: a negative index built from an
# empty vector would select zero rows rather than all of them.
state = temp.state[setdiff(seq_len(nrow(temp.state)), omitUniq), ]

# poss.act[[i]]: positions of the empty cells of board i (the legal moves).
# lapply always yields a list, whereas apply() would silently collapse to a
# matrix if every row produced the same number of moves.
state.mat = as.matrix(state)
poss.act = lapply(seq_len(nrow(state.mat)),
                  function(i) which(state.mat[i, ] == 0))
names(poss.act) = rownames(state.mat)

# Cell numbering of the 3x3 board (filled column by column).
temp.win = matrix(1:9, nrow = 3)

# win.idx: one winning line per row (8 x 3) -- the three rows of temp.win,
# then its three columns, then the main diagonal and the anti-diagonal.
win.idx = rbind(temp.win,
                t(temp.win),
                diag(temp.win),
                diag(temp.win[3:1, ]),
                deparse.level = 0)

# For a given mark (1 or 2) and a required count, return the row indices of
# `state` in which some winning line holds exactly that many of the mark.
# Lines are scanned in win.idx order and the hits are concatenated, so a board
# matching several lines appears once per matching line.
line.hits = function(mark, n.needed) {
  board = as.matrix(state)
  unlist(lapply(seq_len(nrow(win.idx)), function(k) {
    which(rowSums(board[, win.idx[k, ], drop = FALSE] == mark) == n.needed)
  }))
}

idx1 = line.hits(1, 3)   # player 1 has three in a line
idxCM = line.hits(1, 2)  # player 1 has exactly two marks on some line
idx2 = line.hits(2, 3)   # player 2 has three in a line

# n0: number of empty cells per board; tie: boards with no empty cell left.
# (A full board may also appear in idx1/idx2; ck.result tests those first.)
n0 = apply(state, 1, function(cells) sum(cells == 0))
tie = which(n0 == 0)

# Per-state tables, one list element per row of `state`, each a vector with
# one entry per legal action of that state (length 0 when no cell is empty).
n.act = unname(lengths(poss.act[seq_len(nrow(state))]))
zero.table = lapply(n.act, function(k) rep(0, k))

Q = zero.table          # action-value estimates
V = zero.table          # allocated with the same shape as Q
rew.sum = zero.table    # accumulated reward per (state, action)
rew.count = zero.table  # visit count per (state, action)
# policy: equal probability for every legal action
policy = lapply(n.act, function(k) rep(1 / k, k))

# Terminal rewards handed out by ck.result
R.W = 10     # win
R.T = 5      # tie
R.L = -10    # lose

# Learning settings
eta = 1          # multiplier on Q inside the softmax action selection
gamma = 1        # NOTE(review): presumably a discount factor -- not read in the code shown here
epsilon = 0.05   # NOTE(review): presumably an exploration rate -- not read in the code shown here

# Classify a board by its row index in the state table.
#
# Args:
#   st.idx: index of the board to classify (a single value).
#   idx1:   indices counted as a win.
#   idx2:   indices counted as a loss.
#   tie:    indices counted as a tie.
#
# Returns a list with
#   rew:    R.W, R.L or R.T for win / lose / tie (read from the enclosing
#           environment), 0 otherwise;
#   term:   TRUE when the board is terminal;
#   result: "win", "lose", "tie" or "not terminal".
#
# Membership is tested in the order win, lose, tie, so an index present in
# several sets is reported by the first one that matches.
ck.result <- function(st.idx, idx1, idx2, tie) {
  if (st.idx %in% idx1) {
    return(list(rew = R.W, term = TRUE, result = "win"))
  }
  if (st.idx %in% idx2) {
    return(list(rew = R.L, term = TRUE, result = "lose"))
  }
  if (st.idx %in% tie) {
    return(list(rew = R.T, term = TRUE, result = "tie"))
  }
  list(rew = 0, term = FALSE, result = "not terminal")
}

n.rep = 10000
# game.hist[i] is 1 when game i did not end in a loss for player 1, else 0.
game.hist = rep(0, n.rep)

# Base-3 key of every row of `state`, so the row index of a board can be found
# with a single match() instead of scanning the whole data.frame with apply()
# after every move. Keys are small integers (< 3^9) and therefore exact in
# double precision.
cell.weight = 3^(0:8)
state.key = as.vector(as.matrix(state) %*% cell.weight)

# main loop
for (i.rep in seq_len(n.rep)) {
  st.idx = 1                 # row 1 of `state` is used as the starting board
  state.temp = rep(0, 9)     # current board, all cells empty
  state.hist1 = c()          # (state index, action) pairs of player 1
  state.hist2 = c()          # (state index, action) pairs of player 2
  repeat {
    # playing game
    # Player 1: softmax over Q. The single-move case is handled separately
    # because sample(x, 1) with a length-one numeric x draws from 1:x.
    if (length(poss.act[[st.idx]]) == 1) {
      act1 = poss.act[[st.idx]]
    } else {
      p.act = exp(eta * Q[[st.idx]]) / sum(exp(eta * Q[[st.idx]]))
      act1 = sample(poss.act[[st.idx]], 1, prob = p.act)
    }
    state.hist1 = rbind(state.hist1, c(st.idx, act1))
    state.temp[act1] = 1
    st.idx = match(sum(state.temp * cell.weight), state.key)
    res = ck.result(st.idx, idx1, idx2, tie)
    if (res$term) {
      rew = res$rew
      break
    }
    # Player 2: draws from the fixed `policy` table; Q is not consulted here
    # (the original computed a softmax at this point but never used it).
    act2 = sample(poss.act[[st.idx]], 1, prob = policy[[st.idx]])
    state.hist2 = rbind(state.hist2, c(st.idx, act2))
    state.temp[act2] = 2
    st.idx = match(sum(state.temp * cell.weight), state.key)
    res = ck.result(st.idx, idx1, idx2, tie)
    if (res$term) {
      rew = res$rew
      break
    }
  }
  # update Q & policy
  game.hist[i.rep] = res$result != "lose"
  n.st = nrow(state.hist1)
  #print(res$result)
  if (i.rep %% 100 == 0) {
    print(i.rep)
  }
  # Every (state, action) pair player 1 visited in this game receives the
  # game's final reward; Q is the running mean of those rewards.
  for (i.st in seq_len(n.st)) {
    vis.st = state.hist1[i.st, 1]
    act.idx = which(poss.act[[vis.st]] == state.hist1[i.st, 2])
    rew.sum[[vis.st]][act.idx] = rew.sum[[vis.st]][act.idx] + rew
    rew.count[[vis.st]][act.idx] = rew.count[[vis.st]][act.idx] + 1
    Q[[vis.st]][act.idx] = rew.sum[[vis.st]][act.idx] / rew.count[[vis.st]][act.idx]
  }
}

# plotting results
library(pracma)
# Simple moving average (window of 400 games) of the not-lost indicator,
# drawn as a line.
game.hist.smooth = movavg(
  game.hist,
  400,
  type = "s"
)
plot(game.hist.smooth, type = "l")