
Commit 9c32c43 (parent 6294b31)
Commit message: "f"

File tree: 3 files changed, +280 −216 lines
Lines changed: 157 additions & 115 deletions
@@ -1,7 +1,7 @@
 # MLX LoRA Fine-tuning Optimization Configuration
 # Target: Real LoRA fine-tuning efficiency improvements while maintaining convergence

-max_iterations: 50  # More iterations for breakthrough discoveries
+max_iterations: 50
 checkpoint_interval: 5
 log_level: "INFO"

@@ -12,187 +12,229 @@ llm:
   secondary_model: "gemini-2.5-pro-preview-06-05"
   secondary_model_weight: 0.3
   api_base: "https://generativelanguage.googleapis.com/v1beta/openai/"
-  temperature: 0.9  # Higher creativity for breakthrough optimizations
+  temperature: 0.9
   top_p: 0.95
   max_tokens: 32000
   timeout: 600

 # Detailed prompt for LoRA optimization
 prompt:
   system_message: |
-    You are optimizing MLX LoRA fine-tuning implementations to achieve the same training loss
-    as standard LoRA but with improved memory efficiency and/or training speed.
+    You are optimizing MLX LoRA fine-tuning kernels to achieve the same training loss
+    as standard MLX-LM but with improved memory efficiency and/or training speed.

     # 🎯 GOAL: Efficient LoRA Fine-tuning with Maintained Convergence
-    Your target is to achieve the SAME training loss as baseline LoRA implementations
+    Your target is to achieve the SAME training loss as baseline MLX-LM implementations
     while providing 10%+ improvements in memory usage and/or training speed.

-    # 🔧 KEY OPTIMIZATION OPPORTUNITIES
+    # 📋 CURRENT IMPLEMENTATION STRUCTURE

-    **1. LoRA Weight Pre-computation** ⭐ HIGH SUCCESS PROBABILITY
+    The code has an `evolved_lora_kernels()` function that returns a dictionary with these kernels:
     ```python
-    # Standard: 3 separate matrix multiplications per forward pass
-    base_out = x @ base_weight.T
-    lora_a_out = x @ lora_a.T
-    lora_b_out = lora_a_out @ lora_b.T
-    result = base_out + scale * lora_b_out
-
-    # Target: Pre-compute combined weights when beneficial
-    if not self.training:  # During inference
-        fused_weight = base_weight + scale * (lora_b @ lora_a)
-        result = x @ fused_weight.T
+    return {
+        "optimized_lora_linear_class": OptimizedLoRALinear,
+        "optimized_lora_matmul": optimized_lora_matmul,
+        "optimized_lora_forward_pass": optimized_lora_forward_pass,
+        "optimized_gradient_computation": optimized_gradient_computation,
+        "optimized_parameter_update": optimized_parameter_update,
+        "memory_efficient_loss_computation": memory_efficient_loss_computation,
+    }
     ```

-    **2. Memory-Efficient Gradient Computation**
-    ```python
-    # Standard: Separate gradient computations
-    grad_base = grad_output @ x.T
-    grad_lora_b = grad_output @ lora_a_out.T
-    grad_lora_a = lora_b.T @ grad_output @ x.T
+    These kernels get injected via `patch_model_with_kernels()` and used during training.

-    # Target: Fused gradient computation to reduce memory allocations
-    # Reuse intermediate tensors, optimize memory access patterns
-    ```
+    # 🔧 KEY OPTIMIZATION TARGETS IN EVOLVE-BLOCK

-    **3. Training Loop Optimization**
+    **1. OptimizedLoRALinear Class** ⭐ HIGH IMPACT
     ```python
-    # Standard: Separate forward, loss, backward, update steps
-    logits = model(inputs)
-    loss = loss_fn(logits, targets)
-    grads = compute_gradients(loss)
-    optimizer.update(model, grads)
-
-    # Target: Reduce kernel launches and memory overhead
-    # Optimize for LoRA-specific gradient patterns
+    class OptimizedLoRALinear(nn.Module):
+        def __call__(self, x):
+            base_out = self.base_layer(x)
+            # CURRENT: Standard LoRA computation
+            lora_out = mx.matmul(mx.matmul(x, self.lora_a.T), self.lora_b.T)
+            return base_out + self.scale * lora_out
+
+    # EVOLUTION TARGETS:
+    # - Fuse base + LoRA computation
+    # - Pre-compute weights during inference
+    # - Optimize memory access patterns
+    # - Use mx.compile for hot paths
     ```

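For reference, a minimal runnable sketch of the "pre-compute weights during inference" target named above. It is illustrative only: the `FusedLoRALinear` name, shapes, and initialization are assumptions, not code from this commit.

```python
import mlx.core as mx
import mlx.nn as nn


class FusedLoRALinear(nn.Module):
    """Hypothetical LoRA layer that can fold the low-rank update into the base weight."""

    def __init__(self, in_dims: int, out_dims: int, rank: int = 8, scale: float = 1.0):
        super().__init__()
        self.base = nn.Linear(in_dims, out_dims, bias=False)
        self.lora_a = 0.01 * mx.random.normal((rank, in_dims))  # A: (rank, in)
        self.lora_b = mx.zeros((out_dims, rank))                 # B: (out, rank)
        self.scale = scale
        self.merged = None

    def merge(self):
        # Inference-time fusion: W' = W + scale * (B @ A), computed once.
        self.merged = self.base.weight + self.scale * mx.matmul(self.lora_b, self.lora_a)

    def __call__(self, x):
        if self.merged is not None:
            return mx.matmul(x, self.merged.T)  # single matmul after merging
        low_rank = mx.matmul(mx.matmul(x, self.lora_a.T), self.lora_b.T)
        return self.base(x) + self.scale * low_rank
```

Merging only helps when the adapter weights are frozen (evaluation); during training the unmerged path keeps gradients flowing through `lora_a` and `lora_b`.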
-    **4. Multi-Layer LoRA Batch Processing**
+    **2. optimized_lora_matmul Function** ⚡ SPEED TARGET
     ```python
-    # Standard: Apply LoRA to layers one by one
-    for layer in layers:
-        layer.q_proj = LoRALinear.from_linear(layer.q_proj)
-        layer.v_proj = LoRALinear.from_linear(layer.v_proj)
-
-    # Target: Batch LoRA operations across layers
-    # Share computation, optimize memory utilization
+    @mx.compile
+    def optimized_lora_matmul(x, lora_a, lora_b, scale):
+        # CURRENT: Basic compiled matrix multiplication
+        temp = mx.matmul(x, lora_a.T)
+        result = mx.matmul(temp, lora_b.T)
+        return scale * result
+
+    # EVOLUTION TARGETS:
+    # - Fuse matrix operations
+    # - Optimize for specific tensor shapes
+    # - Reduce intermediate allocations
+    # - Vectorize computations
     ```

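A hedged sketch of the "fuse matrix operations" direction: compiling the base and low-rank products into one graph. The function and argument names are placeholders, not the commit's code.

```python
import mlx.core as mx


@mx.compile
def fused_base_lora_matmul(x, base_weight, lora_a, lora_b, scale):
    # Keep the intermediate at rank width: (x @ A.T) @ B.T, then add the base path.
    low_rank = mx.matmul(mx.matmul(x, lora_a.T), lora_b.T)
    return mx.matmul(x, base_weight.T) + scale * low_rank
```

`mx.compile` traces per input shape, so the benefit shows up when the same shapes repeat across training steps.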
-    **5. Memory-Efficient Loss Computation**
+    **3. optimized_lora_forward_pass Function** 🚀 INTEGRATION TARGET
     ```python
-    # Standard: Full vocabulary materialization
-    loss = cross_entropy(logits, targets)  # Memory: O(batch * seq * vocab)
-
-    # Target: Chunked or online loss computation for large vocabularies
-    # Reduce memory footprint during loss calculation
+    def optimized_lora_forward_pass(model, x, use_kernels=True):
+        # CURRENT: Iterates through model layers
+        for name, layer in model.named_modules():
+            if hasattr(layer, 'lora_a') and hasattr(layer, 'lora_b'):
+                # Apply optimized LoRA computation
+
+    # EVOLUTION TARGETS:
+    # - Batch multiple LoRA layers
+    # - Fuse activations with LoRA
+    # - Optimize layer traversal
+    # - Reduce function call overhead
     ```

-    **6. UNSLOTH-STYLE MLX KERNEL FUSION** 🎯 PRIMARY SPEED TARGET
+    **4. memory_efficient_loss_computation Function** 💾 MEMORY TARGET
     ```python
-    # Standard: Separate operations
-    x = mx.add(input, lora_out)
-    x = activation_fn(x)
-    x = mx.matmul(x, next_weight)
-
-    # Target: Fused kernels using MLX primitives
-    # Combine LoRA, activation, and next operation
-    # Leverage mx.compile and mx.eval strategically
+    def memory_efficient_loss_computation(logits, targets, chunk_size=1024):
+        # CURRENT: Chunked loss for large vocabularies
+        if logits.shape[-1] <= chunk_size:
+            return nn.losses.cross_entropy(logits, targets, reduction="mean")
+        # Process in chunks...
+
+    # EVOLUTION TARGETS:
+    # - Optimize chunk size dynamically
+    # - Reduce memory allocations
+    # - Parallelize chunk processing
+    # - Smart caching strategies
     ```

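One plausible way to fill in the elided "process in chunks" branch is to chunk the flattened token dimension rather than the vocabulary. The sketch below is an assumption about that strategy, not the commit's implementation; the function name is hypothetical.

```python
import mlx.core as mx
import mlx.nn as nn


def chunked_cross_entropy(logits, targets, chunk_size: int = 1024):
    # Flatten (batch, seq, vocab) -> (tokens, vocab) and sum per-chunk losses.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_targets = targets.reshape(-1)
    num_tokens = flat_targets.shape[0]
    total = mx.array(0.0)
    for start in range(0, num_tokens, chunk_size):
        stop = min(start + chunk_size, num_tokens)
        total = total + nn.losses.cross_entropy(
            flat_logits[start:stop], flat_targets[start:stop], reduction="sum"
        )
    return total / num_tokens
```

Because MLX evaluates lazily, the actual memory saving depends on when the graph is evaluated; periodic `mx.eval` outside any gradient transformation is one way to keep peak usage bounded.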
-    **7. Smart Gradient Accumulation**
+    **5. optimized_gradient_computation Function** 🧠 GRADIENT TARGET
     ```python
-    # Standard: Individual gradient updates
-    for batch in batches:
-        loss = forward(batch)
-        grads = backward(loss)
-        optimizer.update(grads)
-
-    # Target: Accumulated updates with reduced sync points
-    # Batch multiple LoRA layer updates together
+    def optimized_gradient_computation(loss, model, use_kernels=True):
+        # CURRENT: Basic compiled gradient computation
+        compiled_grad_fn = mx.compile(mx.grad(grad_fn))
+        return compiled_grad_fn(model)
+
+    # EVOLUTION TARGETS:
+    # - LoRA-specific gradient patterns
+    # - Accumulate gradients efficiently
+    # - Reduce gradient computation overhead
+    # - Smart gradient sharing
     ```

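For comparison, the documented MLX pattern for a compiled loss-and-gradient step looks roughly like the sketch below; `model`, `optimizer`, and the batch arguments are placeholders, and this is not the commit's `optimized_gradient_computation`.

```python
from functools import partial

import mlx.core as mx
import mlx.nn as nn


def make_train_step(model, optimizer):
    def loss_fn(mdl, inputs, targets):
        return nn.losses.cross_entropy(mdl(inputs), targets, reduction="mean")

    # Builds a function returning (loss, grads) w.r.t. the model's trainable params.
    loss_and_grad = nn.value_and_grad(model, loss_fn)

    # Pass mutable state to mx.compile so in-place updates are tracked.
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(inputs, targets):
        loss, grads = loss_and_grad(model, inputs, targets)
        optimizer.update(model, grads)
        return loss

    return step
```

A caller typically finishes each step with `mx.eval(loss)` (or `mx.eval(model.state)`) to force the lazy update; if the model uses dropout, `mx.random.state` should also be included in `state`.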
-    # 🚀 UNSLOTH-INSPIRED OPTIMIZATION TECHNIQUES (Target 2x+ Speed Improvements)
+    **6. optimized_parameter_update Function** 🔄 UPDATE TARGET
+    ```python
+    @mx.compile
+    def optimized_parameter_update(params, grads, lr):
+        # CURRENT: Basic parameter update loop
+        for key in params:
+            if key in grads:
+                updated_params[key] = params[key] - lr * grads[key]
+
+    # EVOLUTION TARGETS:
+    # - Batch parameter updates
+    # - Vectorize updates
+    # - Optimize for LoRA structure
+    # - Reduce synchronization points
+    ```

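As a rough illustration of the "batch/vectorize parameter updates" target, a single tree traversal can replace the per-key Python loop. The helper name is hypothetical, and the structures of `params` and `grads` are assumed to match.

```python
from mlx.utils import tree_map


def sgd_tree_update(params, grads, lr):
    # One traversal over all leaves; leaves without a gradient stay unchanged.
    return tree_map(
        lambda p, g: p if g is None else p - lr * g,
        params,
        grads,
    )
```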
-    **🔥 Flash Attention Equivalents for MLX**: Fused attention computation patterns
-    **⚡ Kernel Fusion**: Combine LoRA operations with activation functions
-    **🧠 Smart Gradient Accumulation**: Batch gradient updates efficiently
-    **⭐ Optimized MLX Operations**: Leverage mx.fast for critical paths
-    **🚀 Parameter-Efficient Updates**: Minimize optimizer state overhead
-    **💾 Memory Mapping**: Efficient tensor reuse and allocation patterns
-    **🎯 Selective Computation**: Skip unnecessary ops based on LoRA rank/scale
-    **🔧 Mixed Precision**: Smart FP16/FP32 usage for speed without loss
+    # 🚀 PROVEN MLX OPTIMIZATION TECHNIQUES

-    Current baseline shows 1.57x memory improvement but only 1.01x speed.
-    FOCUS: Discover speed optimizations like unsloth's 2-5x improvements!
+    **🔥 mx.compile Usage**: Leverage @mx.compile for hot computation paths
+    **⚡ Tensor Fusion**: Combine multiple operations into single kernels
+    **🧠 Memory Reuse**: Optimize tensor allocation and reuse patterns
+    **⭐ Vectorization**: Use MLX's SIMD capabilities effectively
+    **🚀 Batch Operations**: Process multiple items simultaneously
+    **💾 Smart Caching**: Cache computed values when beneficial
+    **🎯 Shape Optimization**: Optimize for common tensor shapes
+    **🔧 Pipeline Efficiency**: Reduce data movement and sync points

     # 📊 SUCCESS METRICS

     **Primary Metric**: Training Loss Convergence (MUST MATCH BASELINE ±1%)
-    - Target: Same final loss as standard LoRA implementation
+    - Target: Same final loss as standard MLX-LM LoRA implementation
     - Critical: Maintain numerical stability and gradient flow

     **Secondary Metrics**: Efficiency Improvements
     - Memory efficiency: 10%+ reduction in peak memory usage
     - Training speed: 10%+ improvement in tokens/second
+    - Time efficiency: 10%+ reduction in training time
     - Ideal: Both memory AND speed improvements

-    # 🎖️ REAL-WORLD LORA OPTIMIZATION PATTERNS
+    # 🎖️ REALISTIC OPTIMIZATION EXPECTATIONS

     Successful LoRA optimizations typically achieve:
-    - **Memory reduction**: 15-30% through weight fusion and gradient optimization
-    - **Speed improvement**: 10-25% through reduced kernel launches and better memory access
-    - **Maintained convergence**: Critical for practical adoption
+    - **Memory reduction**: 10-30% through smart tensor management
+    - **Speed improvement**: 15-50% through kernel fusion and compilation
+    - **Maintained convergence**: Essential for practical adoption

-    Your optimizations should target similar patterns adapted for MLX.
+    Your optimizations should target these realistic improvements for MLX.

     # 🚫 CONSTRAINTS
-    - Keep exact function signatures from initial_program.py
-    - Maintain numerical correctness (loss must match baseline within 0.01)
+    - Keep exact function signatures and return values
+    - Maintain numerical correctness (loss must match baseline within 1%)
     - Support all LoRA configs (ranks 8-64, any scale/dropout)
     - MLX-only dependencies (mx.core, mx.nn, mx.optimizers)
-    - 🚨 CRITICAL: Concise evolution changes (under 35,000 chars total)
-    - NO verbose comments - focus on algorithmic improvements
-    - Prioritize SPEED over memory (we already have 1.57x memory gain)
-    - Test mx.compile, mx.eval, kernel fusion, gradient accumulation patterns
-
-    # 🔍 WHAT TO EVOLVE - TARGET UNSLOTH-STYLE 2x+ SPEED GAINS
-
-    Focus on `evolved_lora_kernels` function. Prioritize SPEED optimizations:
-
-    1. **optimized_lora_fine_tuning**: Main training pipeline with kernel fusion
-    2. **optimized_training_loop**: Batch gradient accumulation like unsloth
-    3. **optimized_train_step**: Fused forward/backward with mx.compile
-    4. **optimized_linear_to_lora_layers**: Batched multi-layer LoRA application
-    5. **optimized_evaluate**: Fast inference with weight pre-computation
-
-    🎯 PRIMARY TARGETS FOR SPEED BREAKTHROUGH:
-    - Leverage `mx.compile()` for hot paths (like unsloth's kernel compilation)
-    - Use `mx.eval()` strategically to minimize sync points
-    - Batch operations across multiple LoRA layers simultaneously
-    - Pre-compute weights when beneficial (inference mode optimization)
-    - Implement gradient accumulation patterns that reduce memory allocations
-
-    Current Results: 1.57x memory ✅, 1.01x speed ❌
-    Target: Discover 2-5x speed improvements while maintaining perfect convergence!
+    - 🚨 CRITICAL: Concise evolution changes (under 30,000 chars total)
+    - Focus on algorithmic improvements, not verbose comments
+    - Ensure kernels can be properly patched into models
+    - Test optimizations work with real MLX-LM training
+
+    # 🔍 WHAT TO EVOLVE - FOCUS ON EVOLVE-BLOCK
+
+    **Primary Evolution Target: `evolved_lora_kernels()` function**
+
+    The EVOLVE-BLOCK contains 6 kernels that get injected into MLX-LM training:
+
+    1. **OptimizedLoRALinear**: The core LoRA layer implementation
+    2. **optimized_lora_matmul**: Compiled matrix multiplication kernel
+    3. **optimized_lora_forward_pass**: Model forward pass optimization
+    4. **optimized_gradient_computation**: Gradient computation optimization
+    5. **optimized_parameter_update**: Parameter update optimization
+    6. **memory_efficient_loss_computation**: Loss computation optimization
+
+    🎯 **PRIMARY OPTIMIZATION STRATEGIES:**
+    - Add more @mx.compile decorators for hot paths
+    - Fuse multiple operations into single kernels
+    - Optimize memory access patterns and reuse
+    - Batch operations across multiple LoRA layers
+    - Pre-compute values when beneficial (inference optimization)
+    - Implement LoRA-specific optimizations based on mathematical properties
+    - Reduce intermediate tensor allocations
+    - Optimize for common LoRA configurations (rank 8-64)
+
+    🔬 **CURRENT STATUS:** Starting from basic working implementations
+    **TARGET:** Achieve 15-25% efficiency improvements while maintaining convergence
+
+    # ⚠️ CRITICAL EVOLUTION GUIDELINES
+
+    1. **ALWAYS preserve function signatures** - the patching system depends on them
+    2. **Test numerical correctness** - loss must converge to same value as baseline
+    3. **Use MLX primitives effectively** - leverage mx.compile, mx.eval, etc.
+    4. **Focus on realistic optimizations** - don't over-engineer
+    5. **Maintain code clarity** - optimizations should be understandable
+    6. **Ensure kernel injection works** - test that patches apply correctly
+
+    **Evolution Success = Same Loss + Better Performance + Working Integration**

   num_top_programs: 6
   num_diverse_programs: 4

 # Database configuration for LoRA optimization
 database:
   db_path: "./openevolve_output/program_db"
-  population_size: 80  # Larger population for more diverse explorations
+  population_size: 80
   archive_size: 40
   num_islands: 4
-  elite_selection_ratio: 0.20  # Less elite pressure, more exploration
-  exploitation_ratio: 0.6  # Balanced exploration for breakthroughs
+  elite_selection_ratio: 0.20
+  exploitation_ratio: 0.6
   exploration_ratio: 0.4

 # Evaluator configuration
 evaluator:
-  timeout: 1200  # Longer timeout for real LoRA training
+  timeout: 1200
   parallel_evaluations: 1

 # Evolution settings
 diff_based_evolution: true
 allow_full_rewrites: false
-max_code_length: 45000  # Encourage concise, focused optimizations
+max_code_length: 45000