library(tidyverse)
library(gridExtra)
#' Simulate an epsilon-greedy agent on a Gaussian multi-armed bandit
#'
#' For every replication a fresh bandit is drawn: the true action values are
#' standard-normal draws and each reward is the true value of the chosen arm
#' plus standard-normal noise. Action values are estimated by the running
#' sample average of the rewards obtained from each arm.
#'
#' @param nTrial Number of trials (pulls) per replication.
#' @param nRep Number of independent replications (bandit problems).
#' @param epsilon Probability of picking an arm uniformly at random instead
#'   of the arm with the highest current estimate.
#' @param nArm Number of arms; defaults to 10.
#' @return A list with two `nTrial` x `nRep` numeric matrices: `opt`
#'   (1 if the optimal arm was chosen on that trial, 0 otherwise) and `rew`
#'   (the reward received on that trial).
epGreedy <- function(nTrial, nRep, epsilon, nArm = 10) {
  rew.hist <- opt.hist <- matrix(0, nrow = nTrial, ncol = nRep)
  # seq_len() (rather than 1:n) keeps the loops empty when a count is zero
  for (i_rep in seq_len(nRep)) {
    Q.true <- rnorm(nArm)
    opt.ID <- which.max(Q.true)
    Q.est <- Q.cum <- act.count <- rep(0, nArm)
    for (i_trial in seq_len(nTrial)) {
      if (runif(1) < epsilon) {
        # explore: uniform random arm
        action <- sample.int(nArm, 1)
      } else {
        # exploit: which.max() resolves ties in favour of the lowest index
        action <- which.max(Q.est)
      }
      reward <- rnorm(1) + Q.true[action]
      rew.hist[i_trial, i_rep] <- reward
      opt.hist[i_trial, i_rep] <- action == opt.ID
      # sample-average update of the estimate for the chosen arm
      act.count[action] <- act.count[action] + 1
      Q.cum[action] <- Q.cum[action] + reward
      Q.est[action] <- Q.cum[action] / act.count[action]
    }
  }
  list(opt = opt.hist, rew = rew.hist)
}
# Experiment 1: epsilon-greedy agents with three exploration rates
n.Trial <- 1000
n.Rep <- 2000
type1 <- epGreedy(n.Trial, n.Rep, 0.0)
type2 <- epGreedy(n.Trial, n.Rep, 0.01)
type3 <- epGreedy(n.Trial, n.Rep, 0.1)

# Long-format table with one row per (trial, replication, condition).
# as.vector() flattens each nTrial x nRep matrix column by column, so the
# trial index simply repeats 1..n.Trial for every replication and condition.
# The sizes must come from n.Trial: no top-level `nTrial` variable exists.
res.all <- tibble(
  trial = rep(1:n.Trial, n.Rep * 3),
  rew = c(as.vector(type1$rew), as.vector(type2$rew), as.vector(type3$rew)),
  opt = c(as.vector(type1$opt), as.vector(type2$opt), as.vector(type3$opt)),
  condition = c(rep("0.00", n.Trial * n.Rep),
                rep("0.01", n.Trial * n.Rep),
                rep("0.10", n.Trial * n.Rep))
)

# Smoothed learning curves: reward (top) and share of optimal choices (bottom)
p1 <- ggplot(res.all) +
  geom_smooth(aes(x = trial, y = rew, color = condition)) +
  ylab("Average Reward")
p2 <- ggplot(res.all) +
  geom_smooth(aes(x = trial, y = opt, color = condition)) +
  ylab("Prop. Optimal Action")
grid.arrange(p1, p2, nrow = 2)
#' Simulate a softmax (Boltzmann) agent on a Gaussian multi-armed bandit
#'
#' For every replication a fresh bandit is drawn: the true action values are
#' standard-normal draws and each reward is the true value of the chosen arm
#' plus standard-normal noise. On each trial an arm is sampled with
#' probability proportional to `exp(Q.est / tau)`, where `Q.est` holds the
#' running sample averages of the rewards obtained from each arm.
#'
#' @param nTrial Number of trials (pulls) per replication.
#' @param nRep Number of independent replications (bandit problems).
#' @param tau Temperature of the softmax; the estimates are divided by it
#'   before exponentiation.
#' @param nArm Number of arms; defaults to 10.
#' @return A list with two `nTrial` x `nRep` numeric matrices: `opt`
#'   (1 if the optimal arm was chosen on that trial, 0 otherwise) and `rew`
#'   (the reward received on that trial).
softmax <- function(nTrial, nRep, tau, nArm = 10) {
  rew.hist <- opt.hist <- matrix(0, nrow = nTrial, ncol = nRep)
  # seq_len() (rather than 1:n) keeps the loops empty when a count is zero
  for (i_rep in seq_len(nRep)) {
    Q.true <- rnorm(nArm)
    opt.ID <- which.max(Q.true)
    Q.est <- Q.cum <- act.count <- rep(0, nArm)
    for (i_trial in seq_len(nTrial)) {
      # Softmax over Q.est / tau. Subtracting the largest logit leaves the
      # probabilities unchanged but prevents exp() from overflowing to Inf
      # when tau is small.
      logit <- Q.est / tau
      weight <- exp(logit - max(logit))
      p <- weight / sum(weight)
      action <- sample.int(nArm, 1, prob = p)
      reward <- rnorm(1) + Q.true[action]
      rew.hist[i_trial, i_rep] <- reward
      opt.hist[i_trial, i_rep] <- action == opt.ID
      # sample-average update of the estimate for the chosen arm
      act.count[action] <- act.count[action] + 1
      Q.cum[action] <- Q.cum[action] + reward
      Q.est[action] <- Q.cum[action] / act.count[action]
    }
  }
  list(opt = opt.hist, rew = rew.hist)
}
# Experiment 2: greedy vs. epsilon-greedy vs. softmax action selection
n.Trial <- 1000
n.Rep <- 1000
eG00 <- epGreedy(n.Trial, n.Rep, 0.0)
eG01 <- epGreedy(n.Trial, n.Rep, 0.1)
sm <- softmax(n.Trial, n.Rep, 0.1)

# Long-format table with one row per (trial, replication, condition).
# c() flattens each result matrix column by column, so the trial index
# cycles through 1..n.Trial once per replication and condition.
res.all <- tibble(
  trial = rep(seq_len(n.Trial), times = 3 * n.Rep),
  rew = c(eG00$rew, eG01$rew, sm$rew),
  opt = c(eG00$opt, eG01$opt, sm$opt),
  condition = rep(
    c("epsilon = 0.0", "epsilon = 0.1", "softmax"),
    each = n.Trial * n.Rep
  )
)

# Smoothed learning curves: reward (top) and share of optimal choices (bottom)
p1 <- ggplot(res.all, aes(x = trial, y = rew, color = condition)) +
  geom_smooth() +
  labs(y = "Average Reward")
p2 <- ggplot(res.all, aes(x = trial, y = opt, color = condition)) +
  geom_smooth() +
  labs(y = "Prop. Optimal Action")
grid.arrange(p1, p2, nrow = 2)
# RL example: iterative policy evaluation on a 5 x 5 gridworld.
# States 1..25 are laid out column by column; the final print() reverses the
# row order, so "north" (state + 1) points towards the top of the output.
gamma <- 0.9
nRep <- 1000
V <- numeric(25)

# action probabilities: every state picks each of the 4 actions with 1/4
P <- matrix(0.25, nrow = 25, ncol = 4)

# deterministic transition matrix (rows = states, columns = actions);
# a move that would leave the grid keeps the agent in place
north <- c(2:25, 25)
north[seq(5, 25, by = 5)] <- seq(5, 25, by = 5)
south <- c(1, 1:24)
south[seq(1, 21, by = 5)] <- seq(1, 21, by = 5)
east <- c(6:25, 21:25)
west <- c(1:5, 1:20)
trM <- cbind(north, east, south, west)
# two special states jump to a fixed successor whatever the action
trM[10, ] <- 6
trM[20, ] <- 18

# reward matrix: -1 for a move that leaves the state unchanged (1:25 is
# recycled down each column of trM), 0 otherwise, and fixed bonuses for
# every action taken in the two special states
R <- matrix(0, nrow = 25, ncol = 4)
R[trM == 1:25] <- -1
R[10, ] <- 10
R[20, ] <- 5

# state values by a fixed number of synchronous sweeps: each sweep backs up
# all 25 states at once from the previous sweep's values
for (i_rep in seq_len(nRep)) {
  V.old <- V
  V <- rowSums(P * (R + gamma * V.old[trM]))
}
print(matrix(V, nrow = 5)[5:1, ])
# Exercise 1: transition table with 14 rows (states 1..14) and one column
# per action; the successor index 15 is an extra state beyond the 14 rows.
# A move that would leave the grid maps the state onto itself.
north <- c(1:3, 15, 1:10)
south <- c(5:15, 12:14)
east <- replace(2:15, c(3, 7, 11), c(3, 7, 11))
west <- replace(c(15, 1:13), c(4, 8, 12), c(4, 8, 12))
trM <- cbind(north, east, south, west)
# reward per step and probability of each action
P <- 0.25
R <- -1
# Policy improvement on the 14-state gridworld.
# States 1..14 are updated below; state 15 is never updated, so its value
# stays at 0.
north <- c(1:3, 15, 1:10)
east <- 2:15
east[c(3, 7, 11)] <- c(3, 7, 11)
south <- c(5:15, 12:14)
west <- c(15, 1:13)
west[c(4, 8, 12)] <- c(4, 8, 12)
trM <- cbind(north, east, south, west)

# reward per step, probability of each action, initial state values
R <- -1
P <- 0.25
V <- rep(0, 15)
gamma <- 1
tol <- 1e-10

# start from a random deterministic policy (one action index per state)
bestP <- sample(1:4, 14, replace = TRUE)
stable <- FALSE
counter <- 0
while (!stable) {
  counter <- counter + 1
  # Iterative policy evaluation with in-place sweeps. Every action is
  # weighted by P = 0.25, so V is evaluated with equal action weights rather
  # than under bestP. delta must be reset on every pass, otherwise the
  # evaluation is skipped after the first one.
  delta <- Inf
  while (delta > tol) {
    delta <- 0
    for (i_state in 1:14) {
      v <- V[i_state]
      V[i_state] <- sum(P * (R + gamma * V[trM[i_state, ]]))
      delta <- max(delta, abs(v - V[i_state]))
    }
  }
  # Policy improvement: act greedily with respect to V. The policy is stable
  # only if NO state changed its action, so start from TRUE and clear the
  # flag on the first change.
  stable <- TRUE
  for (i_state in 1:14) {
    b <- bestP[i_state]
    q <- V[trM[i_state, ]]
    # Successor values that differ by less than sqrt(tol) count as ties and
    # the first tied action wins, so numerical noise left over from the
    # evaluation cannot flip the policy between passes.
    bestP[i_state] <- which(q >= max(q) - sqrt(tol))[1]
    if (bestP[i_state] != b) {
      stable <- FALSE
    }
  }
}
# Related