# 修正済み
# 時間がかかる場合は、繰り返し回数を減らしてください。
# 効率が悪いと思うので、適宜変更してください。
# Enumerate every 3x3 board as a row of nine cells coded 0 (empty),
# 1 (player 1) or 2 (player 2). expand.grid() auto-names the columns
# Var1..Var9; later code addresses cells by position only.
# (The original built this grid twice and the second statement contained a
# stray "= =" that prevented the file from parsing.)
temp.state <- expand.grid(rep(list(0:2), 9))

# Player 1 moves first, so a legal board has either as many 1s as 2s or
# exactly one more 1 than 2s. Everything else is dropped.
n.ones <- rowSums(temp.state == 1)
n.twos <- rowSums(temp.state == 2)
too.many.twos <- n.ones < n.twos
too.many.ones <- (n.ones - 1) > n.twos
omitTwo <- which(too.many.twos)
omitOne <- which(too.many.ones)
omitUniq <- unique(c(omitOne, omitTwo))

# Filter with a logical mask rather than a negative index: `x[-integer(0), ]`
# would select no rows at all if the omit set were ever empty.
state <- temp.state[!(too.many.ones | too.many.twos), ]
# Convert the board table to a plain matrix once; apply() on a data.frame
# would redo this coercion on every call.
state.mat <- as.matrix(state)

# poss.act[[i]]: positions (1-9) of the empty cells of board i, i.e. the
# moves that are still available there.
poss.act <- apply(state.mat, 1, function(x) which(x == 0))

# Cells are numbered column-wise on the 3x3 grid (see temp.win). win.idx
# holds the eight winning lines, one per row: the three rows of temp.win,
# its three columns, the main diagonal and the anti-diagonal.
temp.win <- matrix(1:9, 3)
win.idx <- rbind(temp.win, t(temp.win),
                 diag(temp.win), diag(temp.win[3:1, ]))

# Row indices of `state` in which exactly `n.needed` cells of a winning line
# hold `mark`. Lines are scanned in win.idx order and the hits concatenated,
# so a board matching several lines appears several times.
line.hits <- function(mark, n.needed) {
  unlist(lapply(seq_len(nrow(win.idx)), function(i.win) {
    on.line <- state.mat[, win.idx[i.win, ], drop = FALSE] == mark
    which(rowSums(on.line) == n.needed)
  }))
}

idx1 <- line.hits(1, 3)   # player 1 has three in a line
idxCM <- line.hits(1, 2)  # player 1 has exactly two cells of some line
idx2 <- line.hits(2, 3)   # player 2 has three in a line

# Boards with no empty cell. A full board that also contains a winning line
# is listed here too, so callers must test for a win before a tie.
n0 <- lengths(poss.act)
tie <- which(n0 == 0)
# Per-state tables, one list element per row of `state`, each holding one
# entry per available move of that board (poss.act):
#   Q         - action-value estimates, start at 0
#   V         - zero-initialised table of the same shape
#   rew.sum   - running sum of rewards per (state, action)
#   rew.count - number of rewards accumulated per (state, action)
#   policy    - uniform probability over the available moves
zero.by.state <- lapply(seq_len(nrow(state)), function(i.state) {
  rep(0, length(poss.act[[i.state]]))
})
Q <- zero.by.state
V <- zero.by.state
rew.sum <- zero.by.state
rew.count <- zero.by.state
policy <- lapply(zero.by.state, function(z) {
  rep(1 / length(z), length(z))
})

# Rewards handed out at the end of a game.
R.W <- 10   # win
R.T <- 5    # tie
R.L <- -10  # loss

# Learning hyper-parameters.
gamma <- 1
epsilon <- 0.05
eta <- 1  # multiplies the Q-values inside the softmax move selection
# Classify a board by its row index in the state table.
#
# Arguments:
#   st.idx - index of the board to classify
#   idx1   - indices of boards counted as a win
#   idx2   - indices of boards counted as a loss
#   tie    - indices of boards counted as a tie
#
# Returns a list with
#   rew    - R.W, R.L or R.T (read from the enclosing environment) for a
#            terminal board, 0 otherwise
#   term   - TRUE when the board is terminal
#   result - "win", "lose", "tie" or "not terminal"
#
# Membership is tested in the order win, lose, tie; the first match decides.
ck.result <- function(st.idx, idx1, idx2, tie) {
  if (st.idx %in% idx1) {
    return(list(rew = R.W, term = TRUE, result = "win"))
  }
  if (st.idx %in% idx2) {
    return(list(rew = R.L, term = TRUE, result = "lose"))
  }
  if (st.idx %in% tie) {
    return(list(rew = R.T, term = TRUE, result = "tie"))
  }
  list(rew = 0, term = FALSE, result = "not terminal")
}
n.rep <- 10000
# game.hist[i] is 1 when game i did not end in a loss, 0 otherwise.
game.hist <- rep(0, n.rep)

# Encode every board in `state` as a base-3 integer once, so the row index of
# the current board is found with a single match() instead of a row-wise scan
# of the whole table on every move. All values are small integers, so the
# comparison on doubles is exact.
cell.weight <- 3^(0:8)
state.key <- as.vector(as.matrix(state) %*% cell.weight)

# main loop
for (i.rep in seq_len(n.rep)) {
  st.idx <- 1  # assumes row 1 of `state` is the empty board
  state.temp <- rep(0, 9)
  state.hist1 <- c()
  state.hist2 <- c()
  repeat {
    # playing game
    # Player 1 (the learner): softmax over the current Q-values.
    acts <- poss.act[[st.idx]]
    if (length(acts) == 1) {
      # sample(x, 1) with a single number x would draw from 1:x, so the only
      # remaining move is taken directly.
      act1 <- acts
    } else {
      # Subtracting the maximum leaves the probabilities unchanged but keeps
      # exp() from overflowing when eta is large.
      q.scaled <- eta * Q[[st.idx]]
      p.act <- exp(q.scaled - max(q.scaled))
      p.act <- p.act / sum(p.act)
      act1 <- sample(acts, 1, prob = p.act)
    }
    state.hist1 <- rbind(state.hist1, c(st.idx, act1))
    state.temp[act1] <- 1
    st.idx <- match(sum(state.temp * cell.weight), state.key)
    res <- ck.result(st.idx, idx1, idx2, tie)
    if (res$term) {
      rew <- res$rew
      break
    }

    # Player 2 (the opponent): samples from the fixed `policy` table.
    act2 <- sample(poss.act[[st.idx]], 1, prob = policy[[st.idx]])
    state.hist2 <- rbind(state.hist2, c(st.idx, act2))
    state.temp[act2] <- 2
    st.idx <- match(sum(state.temp * cell.weight), state.key)
    res <- ck.result(st.idx, idx1, idx2, tie)
    if (res$term) {
      rew <- res$rew
      break
    }
  }

  # update Q & policy
  game.hist[i.rep] <- res$result != "lose"
  if (i.rep %% 100 == 0) {
    print(i.rep)
  }
  # Every (state, action) pair player 1 visited in this game receives the
  # final reward; Q is the running mean of those rewards.
  for (i.st in seq_len(nrow(state.hist1))) {
    hist.st <- state.hist1[i.st, 1]
    act.idx <- which(poss.act[[hist.st]] == state.hist1[i.st, 2])
    rew.sum[[hist.st]][act.idx] <- rew.sum[[hist.st]][act.idx] + rew
    rew.count[[hist.st]][act.idx] <- rew.count[[hist.st]][act.idx] + 1
    Q[[hist.st]][act.idx] <-
      rew.sum[[hist.st]][act.idx] / rew.count[[hist.st]][act.idx]
  }
}
# plotting results
# Smooth the per-game outcome indicator with a 400-game simple moving average
# and draw it as a line.
library(pracma)
game.hist.smooth <- movavg(game.hist, n = 400, type = "s")
plot(game.hist.smooth, type = "l")
# (Stray "Related" text left over from the source web page; not part of the script.)