RL Sutton & Barto Ch.5

########################################
#   ch5.1 Monte Carlo policy evaluation 
########################################
#
# reference: http://waxworksmath.com
#
########################################
# function to calc. value of cards
# Score a blackjack hand.
#
# adj.cards: numeric vector of card values in which an ace is encoded as 1.
# Returns c(total, flag). When the hand holds an ace and the plain sum is at
# most 11, one ace is counted as 11 (total = sum + 10) and flag is 1
# ("usable ace"); otherwise total is the plain sum and flag is 2.
card.value<-function(adj.cards) {
  hand.total <- sum(adj.cards)
  has.soft.ace <- any(adj.cards == 1) & hand.total <= 11
  if (has.soft.ace) {
    # Promote one ace from 1 to 11.
    return(c(hand.total + 10, 1))
  }
  c(hand.total, 2)
}

# function to calc. reward
# Reward of a finished game from the player's point of view.
#
# p.val: player's final hand total.
# d.val: dealer's final hand total.
# Returns -1 (loss), 0 (draw) or 1 (win). A player total above 21 is a loss
# even when the dealer total is also above 21.
#
# The result is returned visibly; the previous version ended in an assignment,
# so its value was invisible when called at the console.
calc.reward<-function(p.val,d.val) {
  if (p.val > 21) return(-1)   # player bust
  if (d.val > 21) return(1)    # dealer bust
  if (p.val == d.val) return(0)
  if (p.val > d.val) 1 else -1
}

# main function
# Monte Carlo evaluation of a fixed blackjack policy.
#
# policy:  the player hits while the hand total is below this value.
# maxIter: number of simulated games.
# Returns a 10 x 10 x 2 array of average rewards indexed by
#   [dealer's first card (ace = 1, ..., 10),
#    player total - 11 (totals 12..21),
#    usable-ace flag from card.value (1 = usable, 2 = not usable)].
# Cells that were never visited are NaN (0 / 0).
#
# Relies on card.value() and calc.reward() defined in this file.
BJ_MC_fixedPolicy<-function(policy=20,maxIter=1e6){
  rew.sum <- array(0, dim = c(10, 10, 2))
  rew.count <- array(0, dim = c(10, 10, 2))
  for (i_play in seq_len(maxIter)) {
    # Shuffle a 52-card deck; ranks above 10 (face cards) count as 10.
    cards <- sample(rep(1:13, 4))
    player <- pmin(cards[1:2], 10)
    dealer <- pmin(cards[3:4], 10)
    cards <- cards[-(1:4)]
    d.val <- card.value(dealer)
    p.val <- card.value(player)

    # One row per player state: dealer's first card, player total, ace flag.
    # Kept as a matrix from the start so a single-state game needs no
    # special casing.
    state.hist <- matrix(c(dealer[1], p.val), nrow = 1)

    # Player's turn.
    while (p.val[1] < policy) {
      player <- c(player, min(cards[1], 10))
      cards <- cards[-1]
      p.val <- card.value(player)
      state.hist <- rbind(state.hist, c(dealer[1], p.val))
    }

    # Dealer's turn: hit below 17.
    while (d.val[1] < 17) {
      dealer <- c(dealer, min(cards[1], 10))
      cards <- cards[-1]
      d.val <- card.value(dealer)
    }

    rew <- calc.reward(p.val[1], d.val[1])

    # Credit the game's reward to every visited state with a total in 12..21.
    for (i_state in seq_len(nrow(state.hist))) {
      p.sum <- state.hist[i_state, 2]
      if (p.sum > 11 && p.sum < 22) {
        d.idx <- state.hist[i_state, 1]
        s.idx <- p.sum - 11
        a.idx <- state.hist[i_state, 3]
        rew.sum[d.idx, s.idx, a.idx] <- rew.sum[d.idx, s.idx, a.idx] + rew
        rew.count[d.idx, s.idx, a.idx] <- rew.count[d.idx, s.idx, a.idx] + 1
      }
    }
  }
  rew.sum / rew.count
}

# function to plot results
# Plot the two value surfaces returned by BJ_MC_fixedPolicy side by side.
#
# V: 10 x 10 x 2 array; slice [, , 1] is drawn as "with usable Ace" and
#    slice [, , 2] as "without usable Ace". Rows are placed on the x axis
#    (labelled A, 2..10) and columns on the y axis (labelled 12..21).
# Returns V invisibly. The caller's par(mfrow) setting is restored on exit.
plot.BJ_MC<-function(V){
  old.par <- par(mfrow = c(1, 2))
  on.exit(par(old.par), add = TRUE)

  # image() maps a 10-wide matrix onto [0, 1], so ticks go at 10 even steps.
  ticks <- seq(0, 1, length.out = 10)
  panel.titles <- c("with usable Ace", "without usable Ace")
  for (k in 1:2) {
    image(V[, , k], main = panel.titles[k], xaxt = "n", yaxt = "n",
          xlab = "Dealer showing", ylab = "Player sum")
    axis(1, at = ticks, labels = c("A", paste(2:10)))
    axis(2, at = ticks, labels = 12:21)
  }
  invisible(V)
}

> V=BJ_MC_fixedPolicy(17)
> plot.BJ_MC(V)

RL_exp5_1

########################################
#   ch5.3 Monte Carlo exploring starts
########################################
# function to calc. value of cards
# Score a blackjack hand.
#
# adj.cards: numeric vector of card values in which an ace is encoded as 1.
# Returns c(total, flag): flag is 1 when an ace is counted as 11 (the plain
# sum is at most 11 and the hand holds an ace), in which case total is the
# plain sum plus 10; otherwise flag is 2 and total is the plain sum.
card.value<-function(adj.cards) {
  plain.sum <- sum(adj.cards)
  soft <- any(adj.cards == 1) & plain.sum <= 11
  ace.bonus <- if (soft) 10 else 0
  ace.flag <- if (soft) 1 else 2
  c(plain.sum + ace.bonus, ace.flag)
}

# function to calc. reward
# Reward of a finished game from the player's point of view.
#
# p.val: player's final hand total.
# d.val: dealer's final hand total.
# Returns -1 (loss), 0 (draw) or 1 (win). A player total above 21 is a loss
# even when the dealer total is also above 21.
#
# The result is returned visibly; the previous version ended in an assignment,
# so its value was invisible when called at the console.
calc.reward<-function(p.val,d.val) {
  if (p.val > 21) return(-1)   # player bust
  if (d.val > 21) return(1)    # dealer bust
  if (p.val == d.val) return(0)
  if (p.val > d.val) 1 else -1
}

# main function
# Monte Carlo control with exploring starts for blackjack.
#
# maxIter: number of simulated games.
# Returns list(policy, V, Q) (unnamed, positional):
#   policy: 10 x 10 x 2 array, 1 = hit, 0 = stay
#   V:      10 x 10 x 2 array, reward averaged over both actions
#           (visit-weighted; this is not max over actions of Q)
#   Q:      10 x 10 x 2 x 2 array, last index 1 = stay, 2 = hit
# The first three indices are
#   [dealer's first card (ace = 1, ..., 10),
#    player total - 11 (totals 12..21),
#    usable-ace flag from card.value (1 = usable, 2 = not usable)].
# Counts start at 1, so unvisited cells evaluate to 0 instead of NaN.
#
# Relies on card.value() and calc.reward() defined in this file.
BJ_MC<-function(maxIter=1e6){
  rew.sum <- array(0, dim = c(10, 10, 2, 2))
  rew.count <- array(1, dim = c(10, 10, 2, 2))
  Q <- rew.sum / rew.count
  V <- array(0, dim = c(10, 10, 2))
  # policy: 1 = hit, 0 = stay
  policy <- array(sample(0:1, 10 * 10 * 2, replace = TRUE), dim = c(10, 10, 2))
  for (i_play in seq_len(maxIter)) {
    # initial draw from a 52-card deck with face cards already valued 10
    cards <- sample(c(rep(1:10, 4), rep(10, 12)))
    player <- cards[1:2]
    dealer <- cards[3:4]
    cards <- cards[-(1:4)]
    d.val <- card.value(dealer)
    p.val <- card.value(player)

    # Below 12 the player always draws; those totals are not tracked states.
    while (p.val[1] < 12) {
      player <- c(player, cards[1])
      cards <- cards[-1]
      p.val <- card.value(player)
    }

    # Exploring start: the first action is chosen uniformly at random.
    action <- sample(0:1, 1)
    # One row per (state, action): dealer card, player total, ace flag,
    # action index (action + 1).
    state.hist <- matrix(c(dealer[1], p.val, action + 1), nrow = 1)

    # player's action
    while (action == 1 && p.val[1] < 22) {
      player <- c(player, cards[1])
      cards <- cards[-1]
      p.val <- card.value(player)
      if (p.val[1] < 22) {
        # Pick the action for the new state first, then record the pair, so
        # the state is stored with the action actually taken in it (it used
        # to be stored with the preceding "hit").
        action <- policy[dealer[1], (p.val[1] - 11), p.val[2]]
        state.hist <- rbind(state.hist, c(dealer[1], p.val, action + 1))
      }
      # A bust total (> 21) is not a tracked state and is not recorded.
    }

    # dealer's action
    while (d.val[1] < 17) {
      dealer <- c(dealer, cards[1])
      cards <- cards[-1]
      d.val <- card.value(dealer)
    }
    rew <- calc.reward(p.val[1], d.val[1])

    # Every recorded row has a total in 12..21, so no range filter is needed.
    for (i_state in seq_len(nrow(state.hist))) {
      ind <- state.hist[i_state, ] - c(0, 11, 0, 0)
      rew.sum[ind[1], ind[2], ind[3], ind[4]] <-
        rew.sum[ind[1], ind[2], ind[3], ind[4]] + rew
      rew.count[ind[1], ind[2], ind[3], ind[4]] <-
        rew.count[ind[1], ind[2], ind[3], ind[4]] + 1
    }

    # Greedy policy improvement, once per game (the policy is not read while
    # the rewards above are being credited).
    Q <- rew.sum / rew.count
    policy[, , 1] <- Q[, , 1, 1] < Q[, , 1, 2]
    policy[, , 2] <- Q[, , 2, 1] < Q[, , 2, 2]
  }
  V[, , 1] <- (rew.sum[, , 1, 1] + rew.sum[, , 1, 2]) /
    (rew.count[, , 1, 1] + rew.count[, , 1, 2])
  V[, , 2] <- (rew.sum[, , 2, 1] + rew.sum[, , 2, 2]) /
    (rew.count[, , 2, 1] + rew.count[, , 2, 2])
  list(policy, V, Q)
}


# function to plot results
# Plot the result of BJ_MC as a 2 x 2 grid of heat maps.
#
# V: list as returned by BJ_MC; V[[1]] is the policy array and V[[2]] the
#    value array, each 10 x 10 x 2. Slice [, , 1] is drawn as "with usable
#    Ace" and slice [, , 2] as "without usable Ace". Rows are placed on the
#    x axis (labelled A, 2..10) and columns on the y axis (labelled 12..21).
# Returns V invisibly. The caller's par(mfrow) setting is restored on exit.
plot.BJ_MC2<-function(V){
  old.par <- par(mfrow = c(2, 2))
  on.exit(par(old.par), add = TRUE)

  # image() maps a 10-wide matrix onto [0, 1], so ticks go at 10 even steps.
  ticks <- seq(0, 1, length.out = 10)
  draw.panel <- function(z, title) {
    image(z, main = title, xaxt = "n", yaxt = "n",
          xlab = "Dealer showing", ylab = "Player sum")
    axis(1, at = ticks, labels = c("A", paste(2:10)))
    axis(2, at = ticks, labels = 12:21)
  }

  draw.panel(V[[2]][, , 1], "Utility with usable Ace")
  draw.panel(V[[2]][, , 2], "Utility without usable Ace")
  draw.panel(V[[1]][, , 1], "Policy with usable Ace")
  draw.panel(V[[1]][, , 2], "Policy without usable Ace")
  invisible(V)
}

# Run the exploring-starts simulation for 5e6 games and plot the result.
V <- BJ_MC(5e6)
plot.BJ_MC2(V)

RL_exp5_3

Leave a Reply