Deep Q-Network (DQN)

"Saat Q-Learning Bertemu Deep Learning"

1. Kenapa butuh Neural Network?

Di Q-Learning biasa, kita pakai Tabel (Q-Table). Kalau state-nya cuma 16 kotak (FrozenLake), tabelnya kecil (16 baris).

Tapi bagaimana kalau main Atari Breakout? Setiap pixel di layar adalah state. Kombinasinya triliunan! Tabelnya tidak akan muat di memori komputer manapun.

Solusi: Ganti Tabel dengan Otak Buatan (Neural Network). Kita tidak perlu menyimpan nilai setiap state. Kita ajarkan NN untuk memperkirakan (Approximation) nilai $Q$ dari gambar layar yang dia lihat.

2. Arsitektur DQN

Input: Gambar Game (State)
Convolutional Neural Network (CNN)
Dense Layers
Output: Q-Values (Kiri, Kanan, Tembak)

3. Kunci Stabilisasi

Melatih NN dengan data RL (yang berubah-ubah terus) itu susah dan tidak stabil. DeepMind (2013) menemukan 2 trik:

  1. Experience Replay: Jangan langsung belajar dari apa yang baru terjadi. Simpan dulu pengalaman (State, Action, Reward, Next State) ke dalam memori ("Replay Buffer"). Lalu saat belajar, ambil sampel acak dari memori ini. Ini memutus korelasi antar data.
  2. Target Network: Gunakan dua otak.
    • Otak Utama (Policy Net): Belajar terus setiap langkah.
    • Otak Target (Target Net): Diam saja, update cuma setiap 1000 langkah. Ini dipakai untuk menghitung target masa depan supaya "tiang gawangnya tidak geser-geser terus".

4. Pseudocode Implementasi (PyTorch)

import torch
import torch.nn as nn
import random

# 1. Definisikan Otak
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        return self.fc(x)

# 2. Training Loop
# ... (Setup Environment & Replay Buffer) ...

def optimize_model():
    if len(memory) < BATCH_SIZE: return
    
    # Ambil sampel acak (Experience Replay)
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    # Hitung Q saat ini (Policy Net)
    state_batch = torch.cat(batch.state)
    q_values = policy_net(state_batch).gather(1, action_batch)

    # Hitung Target Q (Target Net - Stabil)
    with torch.no_grad():
        next_q_values = target_net(next_state_batch).max(1)[0]
        expected_q_values = (next_q_values * GAMMA) + reward_batch

    # Hitung Loss (Huber Loss / MSE)
    loss = criterion(q_values, expected_q_values.unsqueeze(1))

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()