認知情報解析演習 石とりゲーム(bugあり)

# Pile templates: each is a 4 x 5 matrix whose column k describes a pile
# holding k - 1 stones. Pile 1 codes empty/stone as 0/1, pile 2 as 10/11 and
# pile 3 as 100/101.
pile.pattern <- c(rep(0, 4), c(1, 0, 0, 0), c(1, 1, 0, 0), c(1, 1, 1, 0), rep(1, 4))
col1 <- matrix(pile.pattern, nrow = 4, byrow = TRUE)
col2 <- col1 + 10
col3 <- col1 + 100

# Action codes for a pile with `n.stone` stones: unit, 2 * unit, ...,
# n.stone * unit. An empty pile offers no action and yields NULL.
pile.moves <- function(n.stone, unit) {
  if (n.stone == 0) {
    return(NULL)
  }
  seq_len(n.stone) * unit
}

# Enumerate all 5 * 5 * 5 board configurations (pile 3 varies fastest).
# For configuration number `counter`:
#   state.temp[[counter]]  4 x 3 board matrix
#   act.list[[counter]]    action codes available on that board
#   Q1 / Q2[[counter]]     zero-initialised value per available action
act.list <- list()
state.temp <- list()
Q1 <- list()
Q2 <- list()
counter <- 0
for (i.c1 in 1:5) {
  act1 <- pile.moves(sum(col1[, i.c1] == 1), 1)
  for (i.c2 in 1:5) {
    act2 <- pile.moves(sum(col2[, i.c2] == 11), 11)
    for (i.c3 in 1:5) {
      act3 <- pile.moves(sum(col3[, i.c3] == 101), 101)
      counter <- counter + 1
      moves <- c(act1, act2, act3)
      state.temp[[counter]] <- cbind(col1[, i.c1], col2[, i.c2], col3[, i.c3])
      # For the all-empty board `moves` is NULL, so this assignment leaves the
      # slot empty and act.list[[1]] reads back as NULL.
      act.list[[counter]] <- moves
      Q1[[counter]] <- numeric(length(moves))
      Q2[[counter]] <- numeric(length(moves))
    }
  }
}
# Remove stones from one pile of the board.
#
# act      : action code. 1..k removes that many stones from pile 1 (column 1),
#            11, 22, ... removes act %% 10 stones from pile 2, and
#            101, 202, ... removes act %% 100 stones from pile 3.
#            -99 means "no action"; the board is returned unchanged.
# st.shape : board matrix with one column per pile; stones are coded
#            1 / 11 / 101 and empty cells 0 / 10 / 100 in columns 1 / 2 / 3.
#
# Returns the board after the move. Stops with an error when the action asks
# for fewer than one stone or more stones than the pile holds.
rm.stone <- function(act, st.shape) {
  if (act == -99) {
    return(st.shape)
  }
  if (act > 100) {
    pile <- 3
    n.remove <- act %% 100
    stone <- 101
    empty <- 100
  } else if (act > 10) {
    pile <- 2
    n.remove <- act %% 10
    stone <- 11
    empty <- 10
  } else {
    pile <- 1
    n.remove <- act
    stone <- 1
    empty <- 0
  }
  n.stone <- sum(st.shape[, pile] == stone)
  if (n.remove < 1 || n.remove > n.stone) {
    stop("action ", act, " cannot be applied: pile ", pile, " holds ",
         n.stone, " stone(s)", call. = FALSE)
  }
  # Stones occupy the last n.stone rows of the column; clear the first
  # n.remove of those rows.
  start <- nrow(st.shape) - n.stone + 1
  st.shape[start:(start + n.remove - 1), pile] <- empty
  st.shape
}

# Look up the index of a board configuration.
#
# st.shape   : board matrix to look up.
# state.temp : list of board matrices to search.
#
# Returns the index of the first element of `state.temp` whose entries all
# equal those of `st.shape`. Stops with an error when no element matches, so a
# failed lookup can never fall through to an unrelated `state.idx` variable
# defined outside the function.
id.state <- function(st.shape, state.temp) {
  for (i.st in seq_along(state.temp)) {
    if (all(st.shape == state.temp[[i.st]])) {
      return(i.st)
    }
  }
  stop("board configuration not found in state.temp", call. = FALSE)
}

# Choose an action for the current state with a softmax over action values.
#
# Q       : list of numeric vectors; Q[[s]] holds one value per action
#           available in state s.
# act.vec : action codes available in the current state, or NULL when there
#           is none.
# eta     : accepted for interface compatibility and never evaluated here.
#           NOTE(review): the visible callers pass a symbol that is not
#           defined in the visible part of the file, so forcing it would fail.
# st.idx  : index into Q for the current state. Defaults to the variable
#           `state.idx` found in the function's enclosing environment, which
#           is what the body read directly before this argument existed.
#
# Returns list(act, act.idx): the chosen action code and its position in
# `act.vec`; both are -99 when `act.vec` is NULL.
ck.act <- function(Q, act.vec, eta, st.idx = state.idx) {
  if (is.null(act.vec)) {
    return(list(act = -99, act.idx = -99))
  }
  if (length(act.vec) == 1) {
    # Handled separately: sample() on a length-one numeric would draw from
    # 1:act.vec rather than return act.vec itself.
    return(list(act = act.vec, act.idx = 1L))
  }
  q.val <- Q[[st.idx]]
  # Subtracting the maximum leaves the softmax unchanged but keeps exp() from
  # overflowing to Inf (which would give NaN probabilities).
  w <- exp(q.val - max(q.val))
  act.idx <- sample.int(length(act.vec), 1, prob = w / sum(w))
  list(act = act.vec[act.idx], act.idx = act.idx)
}


# Self-play training of two tabular action-value learners: Q1 belongs to the
# player who moves on odd plies, Q2 to the player who moves on even plies.
gamma <- 1      # weight of the next action value in the update target
alpha <- 0.1    # step size of each update
n.rep <- 10000  # number of episodes
# `eta` is handed to ck.act() on every call; give it a value only when none
# is defined yet.
if (!exists("eta")) {
  eta <- 1
}

for (i.rep in seq_len(n.rep)) {
  # first action: every episode starts in state 125 with player 1 to move
  state.idx <- 125
  counter <- 1
  st.shape <- state.temp[[state.idx]]
  res.act <- ck.act(Q1, act.list[[state.idx]], eta)
  act <- res.act$act
  act.idx <- res.act$act.idx
  # Last (state, action index) chosen by each player; NA until that player
  # has moved. Each table is updated at its owner's own previous move, not at
  # the opponent's.
  prev.state <- c(state.idx, NA)
  prev.act <- c(act.idx, NA)

  # 2nd to last
  while (state.idx != 1) {
    counter <- counter + 1
    st.shape <- rm.stone(act, st.shape)
    state.idx <- id.state(st.shape, state.temp)
    mover <- if (counter %% 2 == 1) 1 else 2  # player who has to move now
    if (mover == 1) {
      res.act <- ck.act(Q1, act.list[[state.idx]], eta)
    } else {
      res.act <- ck.act(Q2, act.list[[state.idx]], eta)
    }
    act <- res.act$act
    act.idx <- res.act$act.idx

    if (state.idx == 1) {
      # State 1 reached: the player who has to move now gets -1 and the
      # player who made the previous move gets +1.
      if (mover == 1) {
        rew1 <- -1
        rew2 <- 1
      } else {
        rew1 <- 1
        rew2 <- -1
      }
      if (!is.na(prev.state[1])) {
        s1 <- prev.state[1]
        a1 <- prev.act[1]
        # The whole right-hand side stays in one expression: a continuation
        # line that starts with `+` would be parsed as a separate statement.
        Q1[[s1]][a1] <- Q1[[s1]][a1] + alpha * (rew1 - Q1[[s1]][a1])
      }
      if (!is.na(prev.state[2])) {
        s2 <- prev.state[2]
        a2 <- prev.act[2]
        Q2[[s2]][a2] <- Q2[[s2]][a2] + alpha * (rew2 - Q2[[s2]][a2])
      }
    } else {
      rew1 <- 0
      rew2 <- 0
      # Move the mover's previous action value toward the value of the
      # action it has just chosen in the new state.
      if (mover == 1) {
        s1 <- prev.state[1]
        a1 <- prev.act[1]
        Q1[[s1]][a1] <- Q1[[s1]][a1] +
          alpha * (rew1 + gamma * Q1[[state.idx]][act.idx] - Q1[[s1]][a1])
      } else if (!is.na(prev.state[2])) {
        s2 <- prev.state[2]
        a2 <- prev.act[2]
        Q2[[s2]][a2] <- Q2[[s2]][a2] +
          alpha * (rew2 + gamma * Q2[[state.idx]][act.idx] - Q2[[s2]][a2])
      }
      prev.state[mover] <- state.idx
      prev.act[mover] <- act.idx
    }
  }
}