Implementing AlphaZero
Introduction
AlphaZero, DeepMind's groundbreaking reinforcement learning algorithm, demonstrated that a single algorithm could achieve superhuman performance in chess, shogi, and Go through self-play alone. I've been working on implementing AlphaZero from scratch to understand the deep principles behind Monte Carlo Tree Search (MCTS) and deep reinforcement learning. This post documents my journey and key insights.
Core Components
AlphaZero combines three main components in a beautifully elegant way:
- Deep Neural Network that outputs both policy (move probabilities) and value (position evaluation)
- Monte Carlo Tree Search (MCTS) that uses the neural network to guide exploration
- Self-Play Training where the algorithm learns by playing against itself
Neural Network Architecture
The neural network takes a board representation as input and outputs both a policy vector (probability distribution over moves) and a scalar value (expected outcome). AlphaZero uses a ResNet-based architecture consisting of a convolution block, a series of residual blocks, and two heads (policy and value).
The initial convolution block lifts the raw board representation into a higher-dimensional feature space, allowing the network to extract a rich set of local patterns from the input. The subsequent series of residual blocks repeatedly refines this representation. The channel dimension remains constant, reflecting the idea that each block makes incremental improvements to the same underlying feature space rather than re-encoding it from scratch.
The policy head projects the shared representation into a vector of move logits, one per entry in the 4,672-move action space. Masking out illegal moves and applying a softmax yields a probability distribution over legal actions. The value head projects the same representation into a single scalar in [-1, 1], estimating the expected game outcome from the perspective of the side to move.
```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """
    Input/Output: (B, C, 8, 8)
    """
    def __init__(self, channels: int = 256):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + identity  # skip connection: the block learns a residual update
        out = self.relu(out)
        return out


class ConvBlock(nn.Module):
    """
    (B, 18, 8, 8) -> (B, C, 8, 8)
    """
    def __init__(self, in_channels: int = 18, channels: int = 64):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class PolicyHead(nn.Module):
    """
    (B, C, 8, 8) -> (B, n_moves) raw move logits
    """
    def __init__(self, channels: int = 64, n_moves: int = 4672):
        super().__init__()
        self.conv = nn.Conv2d(channels, 2, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(2)
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(2 * 8 * 8, n_moves)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)  # (B, 2, 8, 8)
        x = self.bn(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)  # (B, 2 * 8 * 8)
        logits = self.fc(x)  # (B, n_moves)
        return logits


class ValueHead(nn.Module):
    """
    (B, C, 8, 8) -> (B, 1) where each value is in [-1, 1]

    This is the scalar (tanh) head. The self-play and training code later in
    this post assume the 3-class (loss/draw/win) variant described under
    Key Challenges, which also provides logits_to_expected_value().
    """
    def __init__(self, channels: int = 64, hidden: int = 256):
        super().__init__()
        self.conv = nn.Conv2d(channels, 1, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(1)
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(1 * 8 * 8, hidden)
        self.fc2 = nn.Linear(hidden, 1)
        self.tanh = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)  # (B, 1, 8, 8)
        x = self.bn(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)  # (B, 1 * 8 * 8)
        x = self.fc1(x)  # (B, hidden)
        x = self.relu(x)
        x = self.fc2(x)  # (B, 1)
        v = self.tanh(x)  # (B, 1) in [-1, 1]
        return v


class AlphaZeroChessNet(nn.Module):
    """
    Full model:
    Input: (B, 18, 8, 8)
    Output: policy logits - (B, n_moves), value - (B, 1)
    """
    def __init__(self, channels: int = 64, n_blocks: int = 19, n_moves: int = 4672, value_hidden: int = 256):
        super().__init__()
        self.stem = ConvBlock(in_channels=18, channels=channels)
        self.residuals = nn.ModuleList([ResidualBlock(channels) for _ in range(n_blocks)])
        self.policy_head = PolicyHead(channels=channels, n_moves=n_moves)
        self.value_head = ValueHead(channels=channels, hidden=value_hidden)

    def forward(self, x: torch.Tensor):
        x = self.stem(x)
        for block in self.residuals:
            x = block(x)
        policy_logits = self.policy_head(x)
        value = self.value_head(x)
        return policy_logits, value

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
```
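The default `n_moves=4672` comes from AlphaZero's chess move encoding: each of the 64 origin squares has 73 move types, namely 56 queen-style moves (8 directions, 1 to 7 squares), 8 knight moves, and 9 underpromotions (3 directions times knight, bishop, or rook). The small sketch below just checks that arithmetic together with the flattened input sizes of the two heads' first linear layers; the helper name `action_space_size` is hypothetical.

```python
# Hypothetical helper: sizes implied by the AlphaZero chess move encoding
# and by the 8x8 board used by the network above.
BOARD_SQUARES = 8 * 8


def action_space_size() -> int:
    queen_moves = 8 * 7        # 8 directions, 1..7 squares each
    knight_moves = 8           # 8 L-shaped jumps
    underpromotions = 3 * 3    # 3 directions x {knight, bishop, rook}
    move_types = queen_moves + knight_moves + underpromotions  # 73
    return BOARD_SQUARES * move_types


policy_fc_in = 2 * BOARD_SQUARES  # PolicyHead: 2 channels after the 1x1 conv
value_fc_in = 1 * BOARD_SQUARES   # ValueHead: 1 channel after the 1x1 conv

print(action_space_size(), policy_fc_in, value_fc_in)  # 4672 128 64
```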
Monte Carlo Tree Search
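MCTS is the search algorithm used to choose a move. It runs many simulations from the current position; each simulation walks down the search tree to a leaf, evaluates it, and grows the tree by expanding that leaf. Moves that lead to better outcomes are visited more often, and the final move is chosen from the root visit counts: the most-visited move when playing for strength, or a move sampled in proportion to the visit counts during self-play.
In AlphaZero, each simulation consists of three phases: Selection, Expansion, and Backpropagation. The random rollout phase of classic MCTS is replaced by a neural network evaluation of the leaf.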
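Selection: Starting at the root, repeatedly move to the most promising child until reaching a leaf (a node that has not been expanded yet, or a terminal position). This phase creates no new nodes; it records the path from the root to the leaf so that the leaf's value can be backed up later. At each step we pick the child that maximizes the PUCT score Q + U, where Q = W / N (W is the total value accumulated from all simulations that have passed through the child, and N is its visit count), and U = c_puct * P * sqrt(parent_visit_count) / (1 + child_visit_count), where P is the prior probability of the move from the neural network. Q must be measured from the perspective of the player choosing the move, so if a child's statistics are stored from the opponent's point of view, its Q is negated.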
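The selection rule fits in a few lines. The standalone sketch below uses a hypothetical `Child` dataclass and `select_child` function (not from the repository) and assumes, as in the `_backpropagate` method of the MCTS code, that each child stores W and N from the perspective of the side to move at the child, i.e. the opponent of the player choosing. The parent therefore negates the child's Q before adding the exploration term U.

```python
import math
from dataclasses import dataclass
from typing import Dict


@dataclass
class Child:
    prior: float            # P(s, a) from the policy head
    visits: int = 0         # N(s, a)
    value_sum: float = 0.0  # W, from the perspective of the side to move at the child

    def q(self) -> float:
        return self.value_sum / self.visits if self.visits else 0.0


def puct_score(parent_visits: int, child: Child, c_puct: float = 1.5) -> float:
    # Negate: a position that is good for the opponent is bad for the mover.
    q = -child.q()
    u = c_puct * child.prior * math.sqrt(parent_visits) / (1 + child.visits)
    return q + u


def select_child(children: Dict[str, Child], c_puct: float = 1.5) -> str:
    parent_visits = sum(c.visits for c in children.values())
    return max(children, key=lambda m: puct_score(parent_visits, children[m], c_puct))


children = {
    "e2e4": Child(prior=0.6, visits=10, value_sum=4.0),   # Q = +0.4 for the opponent
    "d2d4": Child(prior=0.3, visits=5, value_sum=-2.0),   # Q = -0.4 for the opponent
    "g1f3": Child(prior=0.1),                             # unvisited
}
print(select_child(children))  # d2d4
```

Without the negation, the same statistics would make `e2e4` (the move the opponent likes) the top choice.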
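Expansion: Evaluate the selected leaf with the neural network and create its children: one child per legal move, so a position with k legal moves gets k children. Each child's prior P is the policy network's probability for the corresponding move.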
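Expansion needs a proper probability for each legal move. A minimal sketch, assuming the network's raw policy logits have already been looked up for each legal move: take a softmax over the legal moves only, and at the root mix in Dirichlet noise as AlphaZero does, P = (1 - ε) * p + ε * η with η ~ Dir(α) (α = 0.3 for chess, ε = 0.25). The function name `expand_priors` and its dict-based inputs are hypothetical; the Dirichlet sample is drawn as normalized Gamma variates using only the standard library.

```python
import math
import random
from typing import Dict, Optional


def expand_priors(logits: Dict[str, float],
                  add_noise: bool = False,
                  alpha: float = 0.3,
                  epsilon: float = 0.25,
                  rng: Optional[random.Random] = None) -> Dict[str, float]:
    """Softmax over legal-move logits, optionally mixed with root Dirichlet noise."""
    if rng is None:
        rng = random.Random()
    m = max(logits.values())  # subtract the max for numerical stability
    exps = {move: math.exp(z - m) for move, z in logits.items()}
    total = sum(exps.values())
    priors = {move: e / total for move, e in exps.items()}
    if add_noise:
        # A Dir(alpha, ..., alpha) sample is a vector of Gamma(alpha, 1) draws, normalized.
        gammas = {move: rng.gammavariate(alpha, 1.0) for move in priors}
        g_total = sum(gammas.values())
        priors = {move: (1 - epsilon) * p + epsilon * gammas[move] / g_total
                  for move, p in priors.items()}
    return priors


legal_logits = {"e2e4": 2.0, "d2d4": 1.0, "g1f3": -1.0}  # negative logits are fine
print(expand_priors(legal_logits))
print(expand_priors(legal_logits, add_noise=True, rng=random.Random(0)))
```

The softmax is what turns arbitrary-signed logits into non-negative priors; dividing raw logits by their sum would not.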
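Backpropagation: Propagate the leaf's value back along the recorded path, incrementing N and adding the value to W at every node on the way to the root. Because the players alternate, the value's sign is flipped at each ply. If the leaf is a terminal position (game over), the value is the actual result: +1, -1, or 0 for a draw, from the perspective of the side to move. If the leaf is not terminal, the value comes from the neural network. These network estimates only guide the search; when the network is trained, its value target is the actual winner of the game, so the model is not trained on its own predictions.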
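The sign flip is easy to get wrong, so here is a small standalone sketch with a hypothetical `Stats` class, mirroring the `_backpropagate` method in the code: the leaf receives the value from the perspective of its side to move, its parent receives the negated value, and so on up to the root.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Stats:
    N: int = 0
    W: float = 0.0

    @property
    def Q(self) -> float:
        return self.W / self.N if self.N else 0.0


def backpropagate(path: List[Stats], value: float) -> None:
    """path[0] is the root, path[-1] the leaf; value is from the leaf's side to move."""
    for node in reversed(path):
        node.N += 1
        node.W += value
        value = -value  # the parent's side to move is the opponent


# Root (White to move) -> child (Black to move) -> leaf (White to move) in which
# White has been checkmated, so the leaf's value is -1 for its side to move.
path = [Stats(), Stats(), Stats()]
backpropagate(path, -1.0)
print([node.W for node in path])  # [-1.0, 1.0, -1.0]
```

The child where Black is to move ends up with Q = +1 (good for Black), and the root with Q = -1 (bad for White).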
```python
import math
import random
import uuid
from typing import Dict, List, Optional, Tuple

import chess
import numpy as np

# board_to_input_planes() and AlphaZeroMoveIndexer are project helpers that are
# not shown in this post: the first turns a chess.Board into an (18, 8, 8)
# array, the second maps a chess.Move to an index into the 4672 policy logits.


class MCTSNode:
    """
    Class to represent a node in the MCTS tree
    """
    def __init__(self, board: chess.Board, parent: Optional["MCTSNode"] = None,
                 parent_move: Optional[chess.Move] = None):
        self.board = board
        self.parent = parent
        self.parent_move = parent_move
        self.children: Dict[chess.Move, "MCTSNode"] = {}
        self.is_expanded = False
        # Node statistics, from the perspective of the side to move at this node
        self.N = 0  # visit count of this node
        self.W = 0.0  # total value
        self.Q = 0.0  # self.W / self.N
        # Edge priors P(s,a): the prior of the move leading here is stored in parent.P[move]
        self.P: Dict[chess.Move, float] = {}  # only meaningful for nodes with children


class MCTS:
    """
    Class to perform MCTS search
    """
    def __init__(self, c_puct: float = 1.5, n_simulations: int = 200, batch_size: int = 1,
                 inference_queue=None, inference_result_queue=None, worker_id: int = 0):
        self.c_puct = c_puct
        self.n_simulations = n_simulations
        self.batch_size = batch_size
        self.move_indexer = AlphaZeroMoveIndexer()
        self.inference_queue = inference_queue
        self.response_queue = inference_result_queue  # dedicated Queue per worker
        self.worker_id = worker_id
        self.pending_inferences = {}  # request_id -> (pending_leaves, batch_tensor)

    def search(self, root_board: chess.Board) -> Tuple[Optional[chess.Move], Dict[chess.Move, float]]:
        root = MCTSNode(root_board.copy(stack=False))
        pending_leaves = []  # List of (leaf, path) tuples for non-terminal nodes
        for sim in range(self.n_simulations):
            leaf, path = self._select(root)
            if leaf.board.is_game_over():
                value = self._get_terminal_value(leaf.board)
                self._backpropagate(path, value)
            else:
                # Statistics only change after a batch is evaluated, so without
                # virtual loss the same leaf can be selected several times per batch.
                pending_leaves.append((leaf, path))
            # Evaluate batch when full or at end of simulations (also reached when the
            # last simulation hit a terminal node, so pending leaves are never dropped)
            if len(pending_leaves) >= self.batch_size or sim == self.n_simulations - 1:
                if pending_leaves:
                    self._batch_evaluate_and_backpropagate(pending_leaves)
                    pending_leaves = []
        if not root.children:
            legal_moves = list(root_board.legal_moves)
            move = random.choice(legal_moves) if legal_moves else None
            return move, {}
        move_visits = {move: child.N for move, child in root.children.items()}
        total_visits = sum(move_visits.values())
        if total_visits == 0:
            legal_moves = list(root_board.legal_moves)
            move = random.choice(legal_moves) if legal_moves else None
            return move, {}
        move_probs = {m: n / total_visits for m, n in move_visits.items()}
        # Sample a move in proportion to visit counts instead of always picking the most visited
        moves = list(move_probs.keys())
        probs = list(move_probs.values())
        sampled_move = random.choices(moves, weights=probs, k=1)[0]
        return sampled_move, move_probs

    def _select(self, root: MCTSNode) -> Tuple[MCTSNode, List[MCTSNode]]:
        """
        Returns a leaf node and the path from root to leaf.
        """
        node = root
        path = [node]
        while node.is_expanded and not node.board.is_game_over():
            if not node.children:
                break
            _, node = self._select_child(node)
            path.append(node)
        return node, path

    def _select_child(self, node: MCTSNode) -> Tuple[chess.Move, MCTSNode]:
        assert node.children, "Tried to select a child from a node with no children."
        assert node.P, "Node must have P values set (should be expanded)"
        # Small epsilon so that, before any child is visited, U is still proportional
        # to P and the highest-prior move is tried first
        N_sum = sum(child.N for child in node.children.values()) + 1e-8
        best_score = -float("inf")
        best_move: Optional[chess.Move] = None
        best_child: Optional[MCTSNode] = None
        for move, child in node.children.items():
            # child.Q is from the perspective of the side to move at the child
            # (the opponent), so negate it for the player choosing at this node
            Q = -child.Q
            P = node.P[move]  # prior probability of the move from the neural net
            U = self.c_puct * P * (math.sqrt(N_sum) / (1.0 + child.N))
            score = Q + U
            if score > best_score:
                best_score = score
                best_move = move
                best_child = child
        # Should always be set given non-empty children
        assert best_move is not None and best_child is not None
        return best_move, best_child

    def _get_terminal_value(self, board: chess.Board) -> float:
        """Get value for terminal board position, from the perspective of the side to move."""
        result = board.result()  # e.g: "1-0", "0-1", "1/2-1/2"
        if result == "1-0":
            return 1.0 if board.turn == chess.WHITE else -1.0
        if result == "0-1":
            return 1.0 if board.turn == chess.BLACK else -1.0
        return 0.0

    def _batch_evaluate_and_backpropagate(self, pending_leaves: List[Tuple[MCTSNode, List[MCTSNode]]]) -> None:
        """Evaluate a batch of leaves and backpropagate results."""
        if not pending_leaves:
            return
        board_tensors = [board_to_input_planes(leaf.board) for leaf, _ in pending_leaves]
        # Stack into batch: (batch_size, 18, 8, 8)
        batch_tensor = np.stack(board_tensors)
        # Send inference request with worker_id for routing the response
        request_id = str(uuid.uuid4())
        self.inference_queue.put((self.worker_id, batch_tensor, request_id))
        # Block on the dedicated response queue (no polling); each process has its own queue
        response_request_id, policy_logits_batch, value_batch = self.response_queue.get()
        assert response_request_id == request_id, f"Request ID mismatch: {response_request_id} != {request_id}"
        for i, (leaf, path) in enumerate(pending_leaves):
            board = leaf.board
            policy_logits = policy_logits_batch[i]  # (4672,)
            value = float(value_batch[i].squeeze().item())  # side-to-move perspective
            # Leaves sent here are non-terminal, so they have at least one legal move
            legal_moves = list(board.legal_moves)
            legal_logits = []
            for move in legal_moves:
                move_idx = self.move_indexer.encode(move)  # encode each move to a number
                assert move_idx is not None, "Move index is None"
                legal_logits.append(float(policy_logits[move_idx].item()))  # policy logit for the move
            # Softmax over the legal moves only (illegal moves are masked out).
            # Raw logits can be negative, so they cannot be used as priors directly.
            logits = np.array(legal_logits, dtype=np.float64)
            priors = np.exp(logits - logits.max())
            priors /= priors.sum()
            # Expand node: one child per legal move
            leaf.is_expanded = True
            leaf.P = {}
            leaf.children = {}
            for move, p in zip(legal_moves, priors):
                leaf.P[move] = float(p)
                child_board = board.copy(stack=False)
                child_board.push(move)
                leaf.children[move] = MCTSNode(child_board, parent=leaf, parent_move=move)
            # Add Dirichlet noise to the root node only (for exploration)
            if leaf.parent is None:
                dirichlet_alpha = 0.3  # value AlphaZero used for chess
                epsilon = 0.25  # mixing parameter
                dirichlet_noise = np.random.dirichlet([dirichlet_alpha] * len(legal_moves))
                # A convex mix of two distributions still sums to 1
                for move, eta in zip(legal_moves, dirichlet_noise):
                    leaf.P[move] = (1 - epsilon) * leaf.P[move] + epsilon * float(eta)
            self._backpropagate(path, value)

    def _backpropagate(self, path: List[MCTSNode], value: float) -> None:
        # value is from the perspective of the side to move at the leaf (path[-1])
        for node in reversed(path):
            node.N += 1
            node.W += value
            node.Q = node.W / node.N
            value = -value  # flip perspective when moving up one ply
```
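With `batch_size=64`, `search` selects up to 64 leaves before any of them is evaluated, and because no statistics change in between, PUCT keeps returning the same leaf. A common remedy in batched MCTS implementations, not used in the code above, is virtual loss: while a leaf is pending, every edge on its path is temporarily counted as visited and lost by the player who chose it, which steers later selections elsewhere, and the virtual loss is reverted once the real value arrives. The sketch below is a hypothetical, self-contained illustration with its own tiny `Node` class and a virtual-loss magnitude of 1. It keeps the convention used above (W stored from the side to move at the node, negated by the parent), so a loss for the chooser means adding to the child's W.

```python
import math
from typing import Dict, List


class Node:
    """Hypothetical minimal node; W is from the perspective of the side to move here."""

    def __init__(self, prior: float):
        self.prior = prior
        self.N = 0
        self.W = 0.0
        self.children: Dict[str, "Node"] = {}

    def q(self) -> float:
        return self.W / self.N if self.N else 0.0


def select_leaf(root: Node, c_puct: float = 1.5) -> List[Node]:
    """PUCT descent from the root to a leaf; returns the path."""
    path = [root]
    node = root
    while node.children:
        # max(1, ...) keeps the prior term active before any child has been visited
        sqrt_n = math.sqrt(max(1, sum(c.N for c in node.children.values())))
        node = max(node.children.values(),
                   key=lambda c: -c.q() + c_puct * c.prior * sqrt_n / (1 + c.N))
        path.append(node)
    return path


def add_virtual_loss(path: List[Node], vl: float = 1.0) -> None:
    """Count each pending edge as visited and lost by the player who chose it."""
    for node in path[1:]:  # the root is not chosen by anyone
        node.N += 1
        node.W += vl       # +vl for the child's side to move = a loss for the chooser


def revert_virtual_loss(path: List[Node], vl: float = 1.0) -> None:
    """Undo add_virtual_loss once the real evaluation has arrived."""
    for node in path[1:]:
        node.N -= 1
        node.W -= vl


root = Node(prior=1.0)
root.children = {"a": Node(prior=0.6), "b": Node(prior=0.4)}

first = select_leaf(root)
again = select_leaf(root)     # nothing changed yet, so the same leaf comes back
add_virtual_loss(first)
second = select_leaf(root)    # the pending leaf now looks bad, so "b" is chosen
revert_virtual_loss(first)    # then back up the real value as usual
print(first[-1] is again[-1], first[-1] is second[-1])  # True False
```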
Self-Play Training Loop
The core idea of self-play is to iteratively improve a neural network by having it play games against itself. At the start of training, the network is randomly initialized. As it generates experience through self-play and learns from it, the network progressively improves its playing strength. The self-play training loop consists of three key components:
1. Data Generation (Self-Play): Games are generated by running Monte Carlo Tree Search (MCTS), with the neural network providing policy and value estimates to guide the search. Whenever the search needs a neural network evaluation, the request is enqueued and the self-play worker blocks until the inference result is returned.
2. Model Inference: A dedicated inference process batches the queued requests, evaluates them with the neural network, and returns the policy and value predictions to the MCTS workers.
3. Model Training: The data generated from self-play (game states, MCTS-derived policy targets, and game outcomes) is used to train the neural network. The policy head is trained to match the MCTS visit distribution, which is stronger than the raw network policy, and the value head is trained to predict the final game outcome.
These three components run in separate processes and communicate asynchronously via queues, allowing self-play, inference, and training to proceed in parallel.
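For reference, the original AlphaZero objective for one position is l = (z - v)^2 - Σ_a π(a) log p(a), plus L2 weight regularization, where π is the MCTS visit distribution at the root, z is the final result (+1, 0, or -1) from the perspective of the side to move, and (p, v) are the network outputs. The plain-Python sketch below computes the targets and this loss for a single position; the function names are hypothetical, and the `train_batch` function in this post swaps in a KL-divergence policy loss and a 3-class soft cross-entropy value loss instead.

```python
import math
from typing import Dict


def policy_target(visit_counts: Dict[str, int]) -> Dict[str, float]:
    """pi(a) = N(s, a) / sum_b N(s, b)."""
    total = sum(visit_counts.values())
    return {move: n / total for move, n in visit_counts.items()}


def value_target(result: str, white_to_move: bool) -> float:
    """Game result ("1-0", "0-1", "1/2-1/2") from the perspective of the side to move."""
    z_white = {"1-0": 1.0, "0-1": -1.0}.get(result, 0.0)
    return z_white if white_to_move else -z_white


def alphazero_loss(pi: Dict[str, float], p: Dict[str, float], z: float, v: float) -> float:
    """(z - v)^2 - sum_a pi(a) * log p(a); L2 regularization omitted."""
    value_loss = (z - v) ** 2
    policy_loss = -sum(pi_a * math.log(p[move]) for move, pi_a in pi.items() if pi_a > 0)
    return value_loss + policy_loss


pi = policy_target({"e2e4": 150, "d2d4": 50})
z = value_target("0-1", white_to_move=True)  # White to move, Black won: -1.0
loss = alphazero_loss(pi, p={"e2e4": 0.5, "d2d4": 0.5}, z=z, v=0.0)
print(pi, z, round(loss, 4))  # {'e2e4': 0.75, 'd2d4': 0.25} -1.0 1.6931
```

KL divergence and cross-entropy against π differ only by the entropy of π, which does not depend on the network, so both give the same policy gradients.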
```python
import multiprocessing as mp
import os
import queue
import time
import traceback
from multiprocessing import Process, Queue

import chess
import numpy as np
import torch
import torch.nn.functional as F
from torch.amp import autocast

# Uses AlphaZeroChessNet and MCTS from the sections above, plus the project
# helpers board_to_input_planes and AlphaZeroMoveIndexer (not shown in this post).


def sample_move_with_temperature(probs: dict, temperature: float):
    """Sample a move from MCTS probability distribution with temperature scaling.

    Args:
        probs: dict mapping moves to visit count probabilities
        temperature: temperature for sampling (1.0 = proportional to probs,
            0 = deterministic/argmax, >1 = more uniform)

    Returns:
        Selected move
    """
    moves = list(probs.keys())
    visit_probs = np.array([probs[m] for m in moves], dtype=np.float64)
    if temperature == 0:
        # Deterministic: pick highest probability move
        return moves[np.argmax(visit_probs)]
    # Apply temperature: p_i^(1/T) / sum(p_j^(1/T))
    scaled = visit_probs ** (1.0 / temperature)
    scaled /= scaled.sum()
    return moves[np.random.choice(len(moves), p=scaled)]


def mcts_worker_self_play(worker_id, inference_queue, response_queues, result_queue, c_puct, n_simulations):
    """Worker process that generates board positions via self-play"""
    response_queue = response_queues[worker_id]
    mcts = MCTS(c_puct=c_puct, n_simulations=n_simulations, batch_size=64,
                inference_queue=inference_queue, inference_result_queue=response_queue, worker_id=worker_id)
    while True:
        try:
            board = chess.Board()
            game_history = []
            move_count = 0
            while not board.is_game_over():
                _, probs = mcts.search(board)
                # Temperature scaling: high temperature early for exploration, low later for strength
                # First 30 plies (~15 moves): temperature=1.5, flatter than the visit distribution
                # After that: temperature=0.1 for near-deterministic play
                if move_count < 30:
                    temperature = 1.5
                else:
                    temperature = 0.1
                move = sample_move_with_temperature(probs, temperature)
                game_history.append((board.copy(), move, probs))
                board.push(move)
                move_count += 1
            result = board.result()
            if result == "1-0":
                final_winner = 1.0  # White won
            elif result == "0-1":
                final_winner = -1.0  # Black won
            else:
                final_winner = 0.0  # Draw
            # Value target from the perspective of the side to move in each position
            game_data = [
                (board_pos, move, probs,
                 final_winner if board_pos.turn == chess.WHITE else -final_winner)
                for board_pos, move, probs in game_history
            ]
            # Retry until every position has been handed to the trainer
            while game_data:
                remaining = []
                for item in game_data:
                    try:
                        result_queue.put_nowait(item)
                    except queue.Full:
                        remaining.append(item)
                if remaining:
                    game_data = remaining
                    time.sleep(0.01)
                else:
                    break
        except queue.Full:
            time.sleep(0.1)
        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"Worker error: {e}")
            time.sleep(0.1)


@torch.no_grad()
def inference_worker(inference_queue, response_queues, weight_update_queue, model_state_dict, device_str,
                     max_batch_size=512, max_wait_ms=20):
    """Single GPU worker that batches and processes inference requests for maximum GPU utilization"""
    device = torch.device(device_str)
    model = AlphaZeroChessNet(channels=256, n_blocks=20, n_moves=4672).to(device)
    model.load_state_dict(model_state_dict)
    model.eval()
    while True:
        try:
            # Pick up the latest weights from the trainer, if any
            try:
                new_state_dict = weight_update_queue.get_nowait()
                model.load_state_dict(new_state_dict)
            except queue.Empty:
                pass
            pending_requests = []  # List of (worker_id, batch_tensor, request_id, batch_size)
            total_samples = 0
            # Get first request (blocking, with a timeout so weight updates are still polled)
            try:
                worker_id, batch_tensor, request_id = inference_queue.get(timeout=0.1)
                batch_size = batch_tensor.shape[0]
                pending_requests.append((worker_id, batch_tensor, request_id, batch_size))
                total_samples += batch_size
            except queue.Empty:
                continue
            # Collect more requests (non-blocking) to fill up the batch
            deadline = time.time() + max_wait_ms / 1000.0
            while total_samples < max_batch_size and time.time() < deadline:
                try:
                    worker_id, batch_tensor, request_id = inference_queue.get_nowait()
                    batch_size = batch_tensor.shape[0]
                    pending_requests.append((worker_id, batch_tensor, request_id, batch_size))
                    total_samples += batch_size
                except queue.Empty:
                    break
            if not pending_requests:
                continue
            all_tensors = [req[1] for req in pending_requests]
            mega_batch = np.concatenate(all_tensors, axis=0)
            mega_batch_tensor = torch.from_numpy(mega_batch).to(device)
            policy_logits, value_logits = model(mega_batch_tensor)
            # Convert 3-class value logits to an expected value for MCTS. This assumes the
            # 3-class value head from Key Challenges, not the scalar ValueHead shown earlier.
            value = model.value_head.logits_to_expected_value(value_logits)
            # Move results to the CPU for inter-process communication
            policy_logits_cpu = policy_logits.cpu()
            value_cpu = value.cpu()
            offset = 0
            # Extract each worker's slice and send it back on that worker's own queue
            for worker_id, _, request_id, batch_size in pending_requests:
                policy_slice = policy_logits_cpu[offset:offset + batch_size]
                value_slice = value_cpu[offset:offset + batch_size]
                offset += batch_size
                response_queues[worker_id].put((request_id, policy_slice, value_slice))
        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"Inference worker error: {e}")
            traceback.print_exc()
            time.sleep(0.1)


CHECKPOINT_STEPS = 500


def train_batch(batch_buffer, model, optimizer, device, move_indexer, step):
    """Train on a batch of MCTS results"""
    if len(batch_buffer) == 0:
        return None
    boards, _, probs_list, winners = zip(*batch_buffer)
    board_tensors = [board_to_input_planes(board) for board in boards]
    curr_batch = np.stack(board_tensors)
    curr_batch = torch.from_numpy(curr_batch).to(device)
    # Note: float16 autocast on CUDA is usually paired with torch.amp.GradScaler
    # to avoid gradient underflow; this loop does not use one.
    with autocast(device_type=device.type):
        # (batch_size, 4672), (batch_size, 3) with the 3-class value head
        policy_logits, value_logits = model(curr_batch)
    # Convert MCTS probability distributions to target tensors
    policy_targets = []
    for prob_dict in probs_list:
        # Create target distribution: (4672,) tensor with probabilities
        target = torch.zeros(4672, dtype=torch.float32, device=device)
        for move, prob in prob_dict.items():
            move_idx = move_indexer.encode(move)
            if move_idx is not None and 0 <= move_idx < 4672:
                target[move_idx] = prob
        policy_targets.append(target)
    policy_targets = torch.stack(policy_targets)  # (batch_size, 4672)
    # Policy loss: KL divergence to the MCTS visit distribution
    # (same gradients as cross-entropy, since the target entropy is constant)
    policy_loss = F.kl_div(
        F.log_softmax(policy_logits.float(), dim=1),
        policy_targets,
        reduction='batchmean'
    )
    # Value loss: soft cross-entropy with smoothed targets (anti-collapse)
    # Instead of one-hot, use soft targets to keep gradients alive
    # (class order [loss, draw, win]):
    #   Win  -> [0.05, 0.10, 0.85]
    #   Draw -> [0.15, 0.70, 0.15]
    #   Loss -> [0.85, 0.10, 0.05]
    soft_targets_map = {
        1.0: torch.tensor([0.05, 0.10, 0.85], device=device),  # Win
        0.0: torch.tensor([0.15, 0.70, 0.15], device=device),  # Draw
        -1.0: torch.tensor([0.85, 0.10, 0.05], device=device),  # Loss
    }
    value_soft_targets = torch.stack([soft_targets_map[w] for w in winners])  # (B, 3)
    # Soft cross-entropy: -sum(target * log_softmax(logits))
    # With sample weighting: decisive games weighted 3x higher
    log_probs = F.log_softmax(value_logits.float(), dim=1)
    sample_weights = torch.tensor([3.0 if w != 0.0 else 1.0 for w in winners], device=device)
    value_loss = -(value_soft_targets * log_probs).sum(dim=1)  # (B,)
    value_loss = (value_loss * sample_weights).mean() / sample_weights.mean()  # Weighted mean
    # Combined loss
    loss = policy_loss + value_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
    # Compute value prediction stats for monitoring
    with torch.no_grad():
        value_probs = F.softmax(value_logits.float(), dim=1)
        pred_loss_pct = (value_probs[:, 0].mean() * 100).item()
        pred_draw_pct = (value_probs[:, 1].mean() * 100).item()
        pred_win_pct = (value_probs[:, 2].mean() * 100).item()
    print(f"Step {step}: Policy loss: {policy_loss.item():.4f}, Value loss: {value_loss.item():.4f}, "
          f"Total loss: {loss.item():.4f} | Value preds: L={pred_loss_pct:.1f}% "
          f"D={pred_draw_pct:.1f}% W={pred_win_pct:.1f}%", flush=True)
    if step % CHECKPOINT_STEPS == 0:
        checkpoint_path = f"checkpoint_{step}.pt"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}", flush=True)
    return model_state_dict


def simple_train_supervised(batch_size: int = 256, num_workers: int = 4, checkpoint_path: str = "checkpoint_8000.pt"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AlphaZeroChessNet(channels=256, n_blocks=20, n_moves=4672).to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}...", flush=True)
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"✓ Successfully loaded {checkpoint_path}", flush=True)
    move_indexer = AlphaZeroMoveIndexer()
    # Create shared queues
    result_queue = Queue(maxsize=10000)
    # Note: nothing in this listing reads weight_update_queue; the inference
    # worker receives new weights through inference_weight_update_queue.
    weight_update_queue = Queue(maxsize=num_workers * 2)
    inference_weight_update_queue = Queue(maxsize=2)
    inference_queue = Queue(maxsize=10000)
    # Create per-worker response queues
    response_queues = [Queue(maxsize=100) for _ in range(num_workers)]
    # Start single inference worker (GPU)
    # Convert to CPU before sending through multiprocessing
    model_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
    inference_worker_process = Process(target=inference_worker,
                                       args=(inference_queue, response_queues, inference_weight_update_queue,
                                             model_state_dict, str(device)))
    inference_worker_process.start()
    # Start MCTS worker processes (CPU)
    workers = []
    for worker_id in range(num_workers):
        p = Process(target=mcts_worker_self_play,
                    args=(worker_id, inference_queue, response_queues, result_queue,
                          1.5, 200))
        p.start()
        workers.append(p)
    batch_buffer = []
    step = 0
    try:
        while True:
            try:
                board, move, probs, winner = result_queue.get(timeout=0.1)
                batch_buffer.append((board, move, probs, winner))
                if len(batch_buffer) >= batch_size:
                    step += 1
                    updated_state_dict = train_batch(batch_buffer, model, optimizer, device, move_indexer, step)
                    if updated_state_dict is not None:
                        try:
                            inference_weight_update_queue.put_nowait(updated_state_dict)
                        except queue.Full:
                            pass
                        for _ in range(num_workers):
                            try:
                                weight_update_queue.put_nowait(updated_state_dict)
                            except queue.Full:
                                pass  # Skip if queue full (workers will get next update)
                    batch_buffer = []
            except queue.Empty:
                # No new positions right now: train on whatever has been collected
                if len(batch_buffer) > 0:
                    step += 1
                    updated_state_dict = train_batch(batch_buffer, model, optimizer, device, move_indexer, step)
                    if updated_state_dict is not None:
                        try:
                            inference_weight_update_queue.put_nowait(updated_state_dict)
                        except queue.Full:
                            pass
                        for _ in range(num_workers):
                            try:
                                weight_update_queue.put_nowait(updated_state_dict)
                            except queue.Full:
                                pass  # Skip if queue full (workers will get next update)
                    batch_buffer = []
    except KeyboardInterrupt:
        print("Shutting down...")
        inference_worker_process.terminate()
        inference_worker_process.join()
        for p in workers:
            p.terminate()
            p.join()
        print("Workers terminated")


if __name__ == "__main__":
    # Minimal entry point. The parent moves the model to the GPU before starting
    # the workers, and CUDA cannot be re-initialized in a forked child, so use "spawn".
    mp.set_start_method("spawn")
    simple_train_supervised()
```
Key Challenges
Implementing AlphaZero from scratch revealed several interesting challenges:
1. MCTS Parallelization and Decoupling GPU from Self-Play
Running hundreds of MCTS simulations for every move of every game (200 per move here) is computationally expensive. I parallelized self-play with multiprocessing: each worker process plays its own games with an independent MCTS tree, and the resulting training positions are aggregated in a shared queue. A significant portion of development time was spent fully decoupling the GPU from the self-play process.
Initially, the MCTS process called inference directly on the GPU every time it needed the network's prior and value for a leaf. With several processes issuing many small, separate requests, the GPU spent much of its time context switching between them, which led to low throughput. To address this, I created a separate inference queue and a dedicated inference worker process that batches and evaluates requests. This fully decoupled the GPU from self-play: only two processes use the GPU, the inference worker and the main training process.
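The request/response routing can be illustrated without a GPU. In this hypothetical sketch, threads stand in for the self-play processes and `fake_model` stands in for the network: each worker puts `(worker_id, batch, request_id)` on a shared queue, the server blocks for one request, drains whatever else is already waiting (up to a maximum batch size), evaluates everything in one call, and slices the outputs back to each worker's dedicated response queue by offset. That is the same pattern `inference_worker` uses, minus its short wait deadline.

```python
import queue
import threading
import uuid
from typing import Dict, List


def fake_model(mega_batch: List[float]) -> List[float]:
    """Stand-in for one batched network forward pass."""
    return [2.0 * x for x in mega_batch]


def inference_server(request_q: queue.Queue, response_qs: List[queue.Queue],
                     n_requests: int, max_batch: int = 512) -> None:
    served = 0
    while served < n_requests:
        pending = [request_q.get()]          # block until at least one request arrives
        total = len(pending[0][1])
        while total < max_batch:             # then drain whatever is already waiting
            try:
                item = request_q.get_nowait()
            except queue.Empty:
                break
            pending.append(item)
            total += len(item[1])
        outputs = fake_model([x for _, batch, _ in pending for x in batch])
        offset = 0
        for worker_id, batch, request_id in pending:  # route each slice to its owner
            response_qs[worker_id].put((request_id, outputs[offset:offset + len(batch)]))
            offset += len(batch)
        served += len(pending)


def worker(worker_id: int, request_q: queue.Queue, response_qs: List[queue.Queue],
           results: Dict[int, List[float]]) -> None:
    batch = [float(worker_id), worker_id + 0.5]
    request_id = str(uuid.uuid4())
    request_q.put((worker_id, batch, request_id))
    got_id, out = response_qs[worker_id].get()  # each worker blocks on its own queue
    assert got_id == request_id
    results[worker_id] = out


n_workers = 4
request_q: queue.Queue = queue.Queue()
response_qs = [queue.Queue() for _ in range(n_workers)]
results: Dict[int, List[float]] = {}
threads = [threading.Thread(target=inference_server, args=(request_q, response_qs, n_workers))]
threads += [threading.Thread(target=worker, args=(i, request_q, response_qs, results))
            for i in range(n_workers)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(sorted(results.items()))  # [(0, [0.0, 1.0]), (1, [2.0, 3.0]), (2, [4.0, 5.0]), (3, [6.0, 7.0])]
```

Because every response carries the `request_id` and goes to a per-worker queue, no worker ever has to filter out someone else's results.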
2. Value Loss Collapse
Value loss collapse occurred particularly when most games were drawn, causing the model to blindly predict a draw for all games. This led to an incredibly low value loss from the start, preventing the value head from learning effectively during training.
To combat this, I made numerous adjustments to both the model architecture and the training process. I tried temperature scaling, changed the value head to a 3-class (loss/draw/win) head instead of a single scalar output, trained it against smoothed soft targets rather than one-hot labels, and weighted decisive games (win/loss) 3x higher than drawn games in the loss. I also added Dirichlet noise to the root priors to encourage exploration and reduce the share of drawn games.
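The post doesn't show the 3-class head's `logits_to_expected_value`. One natural definition, assuming the [loss, draw, win] class order used by `soft_targets_map` in `train_batch`, is E[v] = P(win) - P(loss), which maps the head back onto the [-1, 1] scale that MCTS backs up. The plain-Python sketch below (hypothetical function names) shows that conversion alongside the soft-target, decisive-game-weighted cross-entropy from `train_batch`.

```python
import math
from typing import List, Sequence

# Class order assumed to be [loss, draw, win], matching soft_targets_map in train_batch.
SOFT_TARGETS = {
    1.0: [0.05, 0.10, 0.85],   # win
    0.0: [0.15, 0.70, 0.15],   # draw
    -1.0: [0.85, 0.10, 0.05],  # loss
}


def log_softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def logits_to_expected_value(logits: Sequence[float]) -> float:
    """E[v] = P(win) * 1 + P(draw) * 0 + P(loss) * (-1)."""
    p_loss, _, p_win = (math.exp(lp) for lp in log_softmax(logits))
    return p_win - p_loss


def weighted_value_loss(batch_logits: Sequence[Sequence[float]], outcomes: Sequence[float],
                        decisive_weight: float = 3.0) -> float:
    """Soft cross-entropy with decisive games weighted higher (same scheme as train_batch)."""
    losses, weights = [], []
    for logits, z in zip(batch_logits, outcomes):
        target = SOFT_TARGETS[z]
        losses.append(-sum(t * lp for t, lp in zip(target, log_softmax(logits))))
        weights.append(decisive_weight if z != 0.0 else 1.0)
    # mean(w * l) / mean(w) == sum(w * l) / sum(w)
    return sum(w * l for w, l in zip(weights, losses)) / sum(weights)


print(round(logits_to_expected_value([0.0, 0.0, 0.0]), 6))  # 0.0: loss and win equally likely
print(logits_to_expected_value([-5.0, 0.0, 5.0]) > 0.99)    # True: confident win
```

Because the soft targets never put all mass on "draw", a head that predicts only draws still pays a loss on every decisive game, which is what keeps the value gradients alive.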
Results and Insights
Surprisingly, coding up the model was relatively simple and straightforward. Contrary to popular belief, there was no magic: the hardest part of the project was the engineering, namely (1) increasing GPU throughput and utilization and (2) addressing value loss collapse.
In the first stages of the project, I adopted more of an AlphaGo approach, where the model was trained on human games. This was a more stable approach, especially because it avoided the value loss collapse issue. However, the pure self-play approach quickly caught up in performance. This may be because the human games spanned a wide range of ratings; training on higher-quality games, or on Stockfish games, would likely give better distillation. Still, it was surprising to see how fast self-play closed the gap.