  • Information on Double DQN and analysis of source code (Python) (2)
    python source code 2023. 9. 20. 22:17

    Analysis of the example from Reference 2.

    class DeepQNetworkAgent {
      /// The Q-network used to estimate the action values.
      var qNet: DeepQNetwork // Online network, used to estimate action values
      
      /// The copy of the Q-network updated less frequently to stabilize the
      /// training process.
      var targetQNet: DeepQNetwork // Target network, moved toward qNet via updateTargetQNet(tau:)
      
      /// The optimizer used to train the Q-network.
      let optimizer: Adam<DeepQNetwork>
      let replayBuffer: ReplayBuffer
      
      /// The discount factor that measures how much weight to give to future rewards.
      let discount: Float
      /// The minimum replay buffer size before the training starts.
      let minBufferSize: Int
      let doubleDQN: Bool // If true, compute targets with Double DQN in train(batchSize:)
      let device: Device
    
      init(
        qNet: DeepQNetwork,
        targetQNet: DeepQNetwork,
        optimizer: Adam<DeepQNetwork>,
        replayBuffer: ReplayBuffer,
        discount: Float,
        minBufferSize: Int,
        doubleDQN: Bool,
        device: Device
      ) {
        self.qNet = qNet // Online network, used for action selection in getAction
        self.targetQNet = targetQNet
        self.optimizer = optimizer
        self.replayBuffer = replayBuffer
        self.discount = discount
        self.minBufferSize = minBufferSize
        self.doubleDQN = doubleDQN
        self.device = device
    
        // Copy Q-network to Target Q-network before training
        updateTargetQNet(tau: 1)
      }
    
      func getAction(state: Tensor<Float>, epsilon: Float) -> Tensor<Int32> {
        // With probability epsilon, take a random action (CartPole has two actions: 0 and 1)
        if Float(np.random.uniform()).unwrapped() < epsilon {
          return Tensor<Int32>(numpy: np.array(np.random.randint(0, 2), dtype: np.int32))!
        } else {
          // Neural network input needs to be 2D
          let tfState = Tensor<Float>(numpy: np.expand_dims(state.makeNumpyArray(), axis: 0))!
          let qValues = qNet(tfState)[0]
          return Tensor<Int32>(qValues[1].scalarized() > qValues[0].scalarized() ? 1 : 0, on: device)
        }
      }
    
    
      func train(batchSize: Int) -> Float {
        // Don't train if replay buffer is too small
        if replayBuffer.count >= minBufferSize {
          let (tfStateBatch, tfActionBatch, tfRewardBatch, tfNextStateBatch, tfIsDoneBatch) =
            replayBuffer.sample(batchSize: batchSize)
    
          let (loss, gradients) = valueWithGradient(at: qNet) { qNet -> Tensor<Float> in
            // Compute prediction batch
            let npActionBatch = tfActionBatch.makeNumpyArray()
            let npFullIndices = np.stack(
              [np.arange(batchSize, dtype: np.int32), npActionBatch], axis: 1)
            let tfFullIndices = Tensor<Int32>(numpy: npFullIndices)!
            let stateQValueBatch = qNet(tfStateBatch)
            let predictionBatch = stateQValueBatch.dimensionGathering(atIndices: tfFullIndices)
    
            // Compute target batch
            let nextStateQValueBatch: Tensor<Float>
            if self.doubleDQN == true {
              // Double DQN
              // Use qNet (the online network) to select the argmax action for the next state
              let npNextStateActionBatch = self.qNet(tfNextStateBatch).argmax(squeezingAxis: 1)
                .makeNumpyArray()
                
              let npNextStateFullIndices = np.stack(
                [np.arange(batchSize, dtype: np.int32), npNextStateActionBatch], axis: 1)
              let tfNextStateFullIndices = Tensor<Int32>(numpy: npNextStateFullIndices)!
              // Use targetQNet to evaluate the Q-value of that selected action for the next state
              nextStateQValueBatch = self.targetQNet(tfNextStateBatch).dimensionGathering(
                atIndices: tfNextStateFullIndices)
                
            } else {
              // DQN
              nextStateQValueBatch = self.targetQNet(tfNextStateBatch).max(squeezingAxes: 1)
            }
            let targetBatch: Tensor<Float> =
              tfRewardBatch + self.discount * (1 - Tensor<Float>(tfIsDoneBatch)) * nextStateQValueBatch
    
            return huberLoss(
              predicted: predictionBatch,
              expected: targetBatch,
              delta: 1
            )
          }
          optimizer.update(&qNet, along: gradients)
    
          return loss.scalarized()
        }
        return 0
      }
    
      func updateTargetQNet(tau: Float) {
        self.targetQNet.l1.weight =
          tau * Tensor<Float>(self.qNet.l1.weight) + (1 - tau) * self.targetQNet.l1.weight
        self.targetQNet.l1.bias =
          tau * Tensor<Float>(self.qNet.l1.bias) + (1 - tau) * self.targetQNet.l1.bias
        self.targetQNet.l2.weight =
          tau * Tensor<Float>(self.qNet.l2.weight) + (1 - tau) * self.targetQNet.l2.weight
        self.targetQNet.l2.bias =
          tau * Tensor<Float>(self.qNet.l2.bias) + (1 - tau) * self.targetQNet.l2.bias
      }
    }

     

    In the code above:

    qNet : selects the action that maximizes the Q-value for the next state

    targetQNet : evaluates the next-state Q-value of that selected action

    Afterwards, the target network is updated (via updateTargetQNet).
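
    This is the soft update implemented in updateTargetQNet above: each target parameter is blended toward the corresponding qNet parameter with weight tau. Below is a minimal NumPy sketch of the same idea (the parameter lists are made-up placeholders, not the actual layer weights):

    import numpy as np

    def soft_update(target_params, online_params, tau):
        # Blend each online parameter into the corresponding target parameter.
        # tau = 1 copies the online network exactly (as done once in init);
        # a small tau moves the target network only slightly, which stabilizes training.
        return [tau * w + (1.0 - tau) * tw
                for w, tw in zip(online_params, target_params)]

    # Made-up example parameters
    online = [np.array([1.0, 2.0]), np.array([0.5])]
    target = [np.array([0.0, 0.0]), np.array([0.0])]
    target = soft_update(target, online, tau=0.1)  # -> [array([0.1, 0.2]), array([0.05])]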

     

     

    In Double DQN, when computing the Q-values from which the maximum is taken, the action that maximizes the Q-value for the next state is selected by the DQN (online) network rather than by the target network, and the Q-value of that action is then evaluated by the target network.

     

    DQN network -> selects the action whose Q-value is maximal at (t+1)
    target network -> calculates the Q-value of that action
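
    To make the difference concrete, here is a small NumPy sketch of the two target computations (same batch layout as the Swift code above: rows are transitions, columns are actions; an illustrative sketch, not the original implementation):

    import numpy as np

    def dqn_targets(q_target_next, rewards, dones, discount):
        # Vanilla DQN: the target network both selects and evaluates
        # the next-state action through a single max.
        return rewards + discount * (1.0 - dones) * q_target_next.max(axis=1)

    def double_dqn_targets(q_online_next, q_target_next, rewards, dones, discount):
        # Double DQN: the online network selects the greedy next-state action...
        next_actions = q_online_next.argmax(axis=1)
        # ...and the target network evaluates the Q-value of that action.
        next_q = q_target_next[np.arange(len(next_actions)), next_actions]
        return rewards + discount * (1.0 - dones) * next_q

    # Tiny made-up batch: 2 transitions, 2 actions
    q_online_next = np.array([[0.2, 0.9], [0.7, 0.1]])
    q_target_next = np.array([[0.3, 0.4], [0.6, 0.5]])
    rewards = np.array([1.0, 1.0])
    dones = np.array([0.0, 1.0])  # the second transition is terminal
    print(double_dqn_targets(q_online_next, q_target_next, rewards, dones, 0.99))
    # -> [1.396 1.   ]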


    References:

     

    https://www.kaggle.com/code/abechanta/dqn-and-its-successors-with-openai-gym-1

     

    DQN and its successors with OpenAI Gym (1)

     

     

     

    Double DQN examples that create two separate models:

     

    https://github.com/rlcode/reinforcement-learning/blob/master/2-cartpole/2-double-dqn/cartpole_ddqn.py

     

    https://aggregata.de/en/blog/reinforcement-learning/d2qn/

     

     
