By: PrintableKanjiEmblem
Topic: Reference
Overview
Large Language Models (LLMs) are transformer‑based neural networks trained on terabytes of text to learn a contextual distribution over tokens.
The training pipeline is a blend of data engineering, deep‑learning research, and large‑scale distributed systems.
Below is a “full stack” walk‑through aimed at someone who’s built distributed software and written low‑level code.
1. Data Pipeline
Stage | What you do | Why it matters |
---|---|---|
Corpus collection | Web‑scraping, public corpora (Common Crawl, Wikipedia, books, code), proprietary data. | Determines the model’s inductive biases and knowledge. |
Deduplication / filtering | Remove exact or near‑duplicates, profanity filters, harmful content flags. | Prevents data leakage, reduces redundancy, keeps the dataset clean. |
Sharding | Split the corpus into shards (~1 GB each). | Enables parallel ingestion, deterministic random access. |
Tokenization | Byte‑Pair Encoding (BPE), SentencePiece, or WordPiece. | Turns raw text into integer IDs; the choice influences vocabulary size and OOV handling. |
Pre‑tokenization transforms | Lowercasing, whitespace normalisation, special‑token insertion (e.g. beginning‑ and end‑of‑sequence markers). | Standardises the input and helps the model learn boundaries. |
Dataset construction | Convert shards to a dataset that supports streaming: TextDataset, IterableDataset. | Allows reading data without keeping it all in memory. |
Sequence packing | For transformer training, break each shard into fixed‑length token sequences (e.g. 2048 tokens). | Enables efficient batching and padding‑free computation. |
Bucketing / collate | Group similar‑length sequences together. | Minimises padding, improves GPU utilisation. |
Implementation hint: use a torch.utils.data.IterableDataset with a generator that reads the shards sequentially and yields mini‑batches to a collate function that pads to the max length in the batch.
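A minimal sketch of that pattern follows. The shard format (one whitespace‑separated document of token IDs per line) and the file name are illustrative assumptions; because sequences here are packed to a fixed length, the collate step only has to stack them.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

class ShardStream(IterableDataset):
    """Streams fixed-length token sequences from a list of shard files."""
    def __init__(self, shard_paths, seq_len=2048):
        self.shard_paths = shard_paths
        self.seq_len = seq_len

    def __iter__(self):
        buffer = []
        for path in self.shard_paths:
            with open(path) as f:
                for line in f:                      # one pre-tokenised document per line
                    buffer.extend(int(t) for t in line.split())
                    while len(buffer) >= self.seq_len:
                        chunk, buffer = buffer[:self.seq_len], buffer[self.seq_len:]
                        yield torch.tensor(chunk, dtype=torch.long)

def collate(batch):
    # Sequences are already fixed-length, so stacking suffices; with variable
    # lengths you would pad to the longest sequence in the batch here.
    input_ids = torch.stack(batch)
    return {"input_ids": input_ids, "labels": input_ids.clone()}

loader = DataLoader(ShardStream(["shard_000.txt"]),  # placeholder shard path
                    batch_size=8, collate_fn=collate)
```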
2. Model Architecture
Sub‑module | Purpose | Typical design choices |
---|---|---|
Embedding layer | Turns token IDs into dense vectors. | Learned lookup table of shape vocab_size × d_model; usually weight‑tied with the output head. |
Positional encoding | Adds order information. | Learned positional embeddings or sinusoidal embeddings. |
Transformer blocks | Core transformer: self‑attention + MLP + residuals. | 12–96 layers, 768–12288 hidden dim, 12–128 heads. |
LayerNorm | Stabilises training. | Pre‑ or post‑LayerNorm depending on the variant. |
Output head | Predict next‑token distribution. | Linear layer tied to the embedding (weight = embedding.weight). |
Key equations
- Scaled Dot‑Product Attention: $\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\dfrac{QK^{\top}}{\sqrt{d_k}}\right)V$
- Multi‑Head Attention: $\mathrm{MH}(Q,K,V)=\mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_h)\,W^{O}$
- Position‑wise Feed‑Forward: $\mathrm{FFN}(x)=\max(0,\,xW_1+b_1)\,W_2+b_2$
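A literal, single‑head translation of the attention equation into PyTorch (no masking, dropout, or multi‑head reshaping) looks roughly like this sketch:

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: (batch, seq_len, d_k)
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (batch, seq_len, seq_len)
    weights = torch.softmax(scores, dim=-1)            # attention weights
    return weights @ V                                  # (batch, seq_len, d_k)
```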
3. Training Objective & Loss
Objective | Loss | Why it is chosen |
---|---|---|
Autoregressive LM | Cross‑entropy over next‑token | Enables free‑form text generation. |
Masked LM (BERT style) | Cross‑entropy over masked tokens | Enables bidirectional context; used for fine‑tuning. |
For GPT‑style models the loss is computed over every token in the sequence:
$\mathcal{L}=\dfrac{1}{N}\sum_{i=1}^{N}\mathrm{CE}\big(p_\theta(x_{i+1}\mid x_{1:i}),\,\text{one-hot}(x_{i+1})\big)$
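In code this is the usual "shift by one" cross‑entropy; a minimal sketch:

```python
import torch
import torch.nn.functional as F

def next_token_loss(logits, input_ids):
    # logits: (batch, seq_len, vocab); input_ids: (batch, seq_len)
    shift_logits = logits[:, :-1, :]    # predictions for positions 1..N-1
    shift_labels = input_ids[:, 1:]     # targets are the *next* tokens
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )
```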
Regularisation tricks
Trick | Effect |
---|---|
Label smoothing (e.g., 0.1) | Reduces over‑confidence, improves calibration. |
Weight decay (AdamW) | Encourages smaller weights, improves generalisation. |
Dropout | 0.1–0.2 on attention & FFN; mitigates overfitting. |
Stochastic Depth (DropPath) | Randomly skips entire layers during training. |
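Two of these tricks map directly onto PyTorch one‑liners (assuming PyTorch ≥ 1.10 for the label_smoothing argument):

```python
import torch.nn as nn

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # label smoothing
dropout = nn.Dropout(p=0.1)                           # applied inside attention / FFN blocks
```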
4. Optimisation
Component | Detail | Why it matters |
---|---|---|
Optimizer | AdamW (Adam with decoupled weight decay). | De‑facto standard for transformers; decoupling weight decay from the gradient update improves generalisation. |
Learning‑rate schedule | Cosine decay with linear warm‑up (e.g., 3k steps). | Stabilises early training, avoids catastrophic updates. |
Batch‑size | Effective batch = gradient_accumulation_steps × local_batch × num_devices. | Larger batches improve the gradient estimate but need more memory. |
Gradient clipping | Clip by norm (e.g., 1.0). | Prevents exploding gradients in long sequences. |
Mixed‑precision | FP16 with loss scaling, or BF16 (TPUs and recent GPUs). | Cuts memory, speeds up training and inference, keeps the loss stable. |
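Putting the optimizer and schedule together, here is a minimal sketch using a plain LambdaLR; the placeholder model and the step counts are illustrative assumptions:

```python
import math
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the transformer

optimizer = torch.optim.AdamW(model.parameters(), lr=2.5e-4, weight_decay=0.1)

warmup_steps, total_steps = 3_000, 300_000

def lr_lambda(step):
    if step < warmup_steps:                               # linear warm-up
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))     # cosine decay to 0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Per optimisation step:
#   torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#   optimizer.step(); scheduler.step(); optimizer.zero_grad()
```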
Example hyper‑parameters for a 175B GPT‑3‑style run:
Param | Value |
---|---|
Hidden size (n_embd) | 12288 |
Layers (n_layer) | 96 |
Attention heads (n_head) | 96 |
Context length (tokens) | 2048 |
Vocabulary size | 50257 |
Per‑device batch | 8 × 2048 tokens |
Gradient‑accumulation steps | 16 |
Peak learning rate | 0.00025 |
Weight decay | 0.1 |
| 250k |
| 300k |
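As a sanity check on this configuration, a back‑of‑the‑envelope parameter count (ignoring biases, layer norms, and positional embeddings) lands close to 175 B:

```python
# Each transformer block holds roughly 12 * d_model^2 parameters:
# 4*d^2 for the Q/K/V/output projections and 8*d^2 for the 4x-wide FFN.
d_model, n_layers, vocab = 12288, 96, 50257

per_block = 12 * d_model ** 2
embedding = vocab * d_model
total = n_layers * per_block + embedding
print(f"~{total / 1e9:.0f}B parameters")  # ~175B
```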
5. Distributed Training & Hardware
5.1 Parallelism Strategies
Strategy | What it splits | Typical use‑case |
---|---|---|
Data Parallelism (DDP) | Input batches | Simple, scales up to ~128 GPUs |
Tensor Parallelism | Weight matrices inside a layer | Needed when individual layers no longer fit on one device (e.g., GPT‑3, PaLM) |
Pipeline Parallelism | Stages of the model across devices | Helps when model size > GPU memory but < 1 TB |
Sharded Optimizer State | Optimiser state (AdamW moments, master weights) | Saves GPU memory; optimizer state dominates for Adam‑style optimisers (ZeRO) |
Off‑loading (CPU, NVMe) | Activations, gradients | Allows scaling beyond GPU memory limits |
Popular libraries: DeepSpeed (ZeRO‑3), NVIDIA Megatron‑LM (tensor + pipeline parallelism) and derivatives such as Megatron‑LLaMA, etc.
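For reference, a minimal ZeRO‑3 configuration might look like the sketch below. The field names follow DeepSpeed's JSON config schema, but the values are illustrative rather than tuned; this is one possible content for the deepspeed_config variable used in the training loop in section 6.

```python
# Minimal DeepSpeed ZeRO-3 config sketch: fp16, gradient accumulation,
# gradient clipping, and optional CPU off-loading of optimizer state.
deepspeed_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 16,
    "gradient_clipping": 1.0,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},  # spill optimizer state to CPU RAM
    },
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 2.5e-4, "weight_decay": 0.1},
    },
}
```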
5.2 Hardware Landscape
Hardware | Memory per device | Peak throughput | Notes |
---|---|---|---|
NVIDIA A100 80 GB | 80 GB | ~312 TFLOP/s FP16 | Standard for 70–175 B‑parameter models |
NVIDIA H100 80 GB | 80 GB | ~1000 TFLOP/s FP16 (~2000 FP8) | Higher bandwidth, supports FP8 training |
Google TPU v4 | 32 GB | ~275 TFLOP/s BF16 | Native BF16, scales to multi‑thousand‑chip pods |
AMD Instinct MI250X | 128 GB | ~383 TFLOP/s FP16 | Good for pipeline parallelism |
Memory budgeting:
total_mem ≈ weights + gradients + optimizer state + activations
With mixed‑precision Adam this comes to roughly 16 bytes per parameter before activations (2 B fp16 weights + 2 B fp16 gradients + 12 B fp32 master weights and Adam moments). A typical rule of thumb is to keep activations around 40 % of memory; activation checkpointing (aka recomputation) trades compute for memory.
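As a quick sanity check, here is the arithmetic for a 175 B‑parameter model under the 16‑bytes‑per‑parameter assumption above (a rough estimate that ignores activations and framework overhead):

```python
# Rough memory budget for mixed-precision Adam training of a 175B model:
# 2 B fp16 weights + 2 B fp16 gradients + 12 B fp32 master weights / Adam moments.
params = 175e9
bytes_per_param = 2 + 2 + 12
state_bytes = params * bytes_per_param
print(f"weights + grads + optimizer ≈ {state_bytes / 2**40:.1f} TiB")  # ≈ 2.5 TiB
# Activations come on top of this; ZeRO-3 shards these ~2.5 TiB across devices.
```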
6. Training Workflow (Pseudo‑code)
```python
# Simplified training loop using DeepSpeed ZeRO-3
import deepspeed
import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(
    vocab_size=50257,
    n_embd=12288,
    n_layer=96,
    n_head=96,
    bos_token_id=50256,
    eos_token_id=50256,
)
model = GPT2LMHeadModel(config)

# deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler)
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=deepspeed_args,               # parsed launcher/CLI arguments
    model=model,
    model_parameters=model.parameters(),
    config=deepspeed_config,
)

data_loader = build_dataset(...)

for epoch in range(num_epochs):
    for step, batch in enumerate(data_loader):
        inputs = batch['input_ids'].to(model_engine.device)
        labels = batch['labels'].to(model_engine.device)
        outputs = model_engine(inputs, labels=labels)
        loss = outputs.loss            # DeepSpeed scales the loss for gradient accumulation
        model_engine.backward(loss)
        model_engine.step()            # optimizer step + zero_grad, respecting accumulation
        if step % 100 == 0:
            print(f'Epoch {epoch} Step {step} Loss {loss.item():.4f}')
```
Key points:
- DeepSpeed ZeRO‑3 automatically shards optimizer state, gradients, and parameters across devices.
- Gradient accumulation lets you simulate a huge batch on a few GPUs.
- Mixed precision (fp16/bf16) is enabled through the DeepSpeed config rather than being on by default.
7. Post‑Training
7.1 Fine‑tuning & Adaptation
Approach | When to use | Typical pipeline |
---|---|---|
LoRA (Low‑Rank Adaptation) | Adapting to a new domain with few parameters | Freeze the base model, add rank‑r adapters in each attention/feed‑forward layer; train only the adapters (see the sketch after this table). |
Prefix Tuning | Prompt‑engineering for few‑shot tasks | Add learnable prefix tokens to the attention queries. |
P-tuning | Fine‑tune the prompt embeddings | Insert a small learnable token sequence. |
Full‑model fine‑tune | Large domain shift or supervised task | Train all weights; requires more compute. |
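A minimal LoRA sketch using the Hugging Face peft library; target_modules=["c_attn"] assumes a GPT‑2‑style model and would change for other architectures:

```python
from peft import LoraConfig, get_peft_model
from transformers import GPT2LMHeadModel

base = GPT2LMHeadModel.from_pretrained("gpt2")

lora_cfg = LoraConfig(
    r=8,                        # adapter rank
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["c_attn"],  # fused QKV projection in GPT-2
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()  # only the adapter weights are trainable
```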
7.2 RLHF (Reinforcement Learning from Human Feedback)
- Reward model – trained on human preference rankings, typically with a pairwise ranking loss (see the sketch after this list).
- Proximal Policy Optimization – fine‑tune language model to maximise reward.
- Safety constraints – policy‑based or rejection sampling.
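A minimal sketch of that pairwise reward‑model loss, assuming the chosen and rejected responses are scored by the same reward model:

```python
import torch
import torch.nn.functional as F

def reward_ranking_loss(r_chosen, r_rejected):
    # r_chosen, r_rejected: (batch,) scalar rewards for preferred / rejected responses
    return -F.logsigmoid(r_chosen - r_rejected).mean()
```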
8. Evaluation & Metrics
Metric | What it tells you |
---|---|
Perplexity | Exponentiated average negative log‑likelihood on held‑out data. Lower is better. |
BLEU / ROUGE | N‑gram overlap for summarisation / translation tasks. |
Zero‑shot benchmarks (SuperGLUE, MMLU, etc.) | General reasoning ability. |
Calibration error | Probability estimates vs. empirical correctness. |
Bias & toxicity tests | Detecting harmful outputs. |
Tip: For large models, perplexity alone is insufficient; use a suite of downstream tasks and human evaluation.
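Since perplexity comes up constantly, here is its one‑line relationship to the training loss, assuming a list of per‑token cross‑entropy values in nats:

```python
import math

def perplexity(per_token_nll):
    mean_nll = sum(per_token_nll) / len(per_token_nll)
    return math.exp(mean_nll)

print(perplexity([2.1, 1.9, 2.3]))  # ≈ 8.2
```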
9. Common Pitfalls & Gotchas
Pitfall | Fix |
---|---|
Memory fragmentation | Use contiguous tensors, avoid inplace ops that break autograd graph. |
Gradient explosion on long contexts | Gradient clipping by norm; adaptive optimisers such as AdamW also help. |
Over‑fitting to short‑text datasets | Mix in longer documents, use longer context windows during training. |
Data leakage (public model leaks training data) | Deduplication, privacy‑preserving filtering. |
Inadequate hyper‑parameter tuning | Automate with Optuna or Ray Tune; monitor learning‑rate dynamics. |
10. Quick Reference: Tooling
Tool | Purpose | Language |
---|---|---|
TensorFlow / PyTorch | Core DL frameworks | Python |
DeepSpeed / ZeRO | Model & optimizer sharding | Python |
Megatron‑LM | Large‑scale transformer training | Python |
JAX / Flax | TPU‑native training | Python |
NVMe‑Offload | Gradient/activation off‑load | C/C++ |
Dask / Ray | Distributed data prep | Python |
MLflow / Weights & Biases | Experiment tracking | Python |
Bottom line
Training an LLM is essentially training a gigantic, multi‑layer transformer on a massive, clean, tokenised corpus, using distributed, memory‑efficient parallelism, and a carefully tuned optimisation loop. Once you master the data pipeline and the distributed system, the rest is largely a matter of hyper‑parameter sweeps and scaling experiments.
Happy training! 🚀