# MLX LoRA Fine-tuning Optimization Configuration
# Target: Real LoRA fine-tuning efficiency improvements while maintaining convergence

- max_iterations: 50 # More iterations for breakthrough discoveries
+ max_iterations: 50
checkpoint_interval: 5
log_level: "INFO"

@@ -12,187 +12,229 @@ llm:
secondary_model: "gemini-2.5-pro-preview-06-05"
secondary_model_weight: 0.3
api_base: "https://generativelanguage.googleapis.com/v1beta/openai/"
- temperature: 0.9 # Higher creativity for breakthrough optimizations
+ temperature: 0.9
top_p: 0.95
max_tokens: 32000
timeout: 600

# Detailed prompt for LoRA optimization
prompt:
system_message: |
- You are optimizing MLX LoRA fine-tuning implementations to achieve the same training loss
- as standard LoRA but with improved memory efficiency and/or training speed.
+ You are optimizing MLX LoRA fine-tuning kernels to achieve the same training loss
+ as standard MLX-LM LoRA training but with improved memory efficiency and/or training speed.

# 🎯 GOAL: Efficient LoRA Fine-tuning with Maintained Convergence
- Your target is to achieve the SAME training loss as baseline LoRA implementations
+ Your target is to achieve the SAME training loss as baseline MLX-LM implementations
while providing 10%+ improvements in memory usage and/or training speed.

- # 🔧 KEY OPTIMIZATION OPPORTUNITIES
+ # 📋 CURRENT IMPLEMENTATION STRUCTURE

- **1. LoRA Weight Pre-computation** ⭐ HIGH SUCCESS PROBABILITY
+ The code has an `evolved_lora_kernels()` function that returns a dictionary with these kernels:
```python
- # Standard: 3 separate matrix multiplications per forward pass
- base_out = x @ base_weight.T
- lora_a_out = x @ lora_a.T
- lora_b_out = lora_a_out @ lora_b.T
- result = base_out + scale * lora_b_out
-
- # Target: Pre-compute combined weights when beneficial
- if not self.training:  # During inference
-     fused_weight = base_weight + scale * (lora_b @ lora_a)
-     result = x @ fused_weight.T
+ return {
+     "optimized_lora_linear_class": OptimizedLoRALinear,
+     "optimized_lora_matmul": optimized_lora_matmul,
+     "optimized_lora_forward_pass": optimized_lora_forward_pass,
+     "optimized_gradient_computation": optimized_gradient_computation,
+     "optimized_parameter_update": optimized_parameter_update,
+     "memory_efficient_loss_computation": memory_efficient_loss_computation,
+ }
```

- **2. Memory-Efficient Gradient Computation**
- ```python
- # Standard: Separate gradient computations
- grad_base = grad_output @ x.T
- grad_lora_b = grad_output @ lora_a_out.T
- grad_lora_a = lora_b.T @ grad_output @ x.T
- # Target: Fused gradient computation to reduce memory allocations
- # Reuse intermediate tensors, optimize memory access patterns
- ```
+ These kernels get injected via `patch_model_with_kernels()` and used during training.
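
+ For orientation, injection looks roughly like the sketch below (illustrative only;
+ the harness's actual `patch_model_with_kernels()` may differ, and `from_linear`
+ just mirrors the MLX-LM `LoRALinear.from_linear` convention):
+ ```python
+ import mlx.nn as nn
+
+ def patch_model_with_kernels(model, kernels):
+     # Hypothetical sketch: swap attention projections for the evolved
+     # LoRA layer class; the real harness may patch more entry points.
+     lora_cls = kernels["optimized_lora_linear_class"]
+     for _name, module in model.named_modules():
+         for attr in ("q_proj", "v_proj"):
+             child = getattr(module, attr, None)
+             if isinstance(child, nn.Linear):
+                 setattr(module, attr, lora_cls.from_linear(child))
+     return model
+ ```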

+ # 🔧 KEY OPTIMIZATION TARGETS IN EVOLVE-BLOCK

- **3. Training Loop Optimization**
+ **1. OptimizedLoRALinear Class** ⭐ HIGH IMPACT
```python
- # Standard: Separate forward, loss, backward, update steps
- logits = model(inputs)
- loss = loss_fn(logits, targets)
- grads = compute_gradients(loss)
- optimizer.update(model, grads)
-
- # Target: Reduce kernel launches and memory overhead
- # Optimize for LoRA-specific gradient patterns
+ class OptimizedLoRALinear(nn.Module):
+     def __call__(self, x):
+         base_out = self.base_layer(x)
+         # CURRENT: Standard LoRA computation
+         lora_out = mx.matmul(mx.matmul(x, self.lora_a.T), self.lora_b.T)
+         return base_out + self.scale * lora_out
+
+ # EVOLUTION TARGETS:
+ # - Fuse base + LoRA computation
+ # - Pre-compute weights during inference
+ # - Optimize memory access patterns
+ # - Use mx.compile for hot paths
```
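
+ One concrete direction for "pre-compute weights during inference" is merging the
+ low-rank update into the base weight, roughly as sketched below (illustrative only;
+ shapes assume `lora_a: (r, d_in)` and `lora_b: (d_out, r)` as in the snippet above):
+ ```python
+ import mlx.core as mx
+
+ def merge_lora_weight(base_weight, lora_a, lora_b, scale):
+     # Sketch: valid only while the LoRA matrices stay frozen (e.g. during
+     # evaluation); training must keep the low-rank path so gradients
+     # still flow to lora_a and lora_b.
+     return base_weight + scale * mx.matmul(lora_b, lora_a)
+
+ # y = mx.matmul(x, merged.T) then replaces three matmuls with one.
+ ```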

- **4. Multi-Layer LoRA Batch Processing**
+ **2. optimized_lora_matmul Function** ⚡ SPEED TARGET
```python
- # Standard: Apply LoRA to layers one by one
- for layer in layers:
-     layer.q_proj = LoRALinear.from_linear(layer.q_proj)
-     layer.v_proj = LoRALinear.from_linear(layer.v_proj)
-
- # Target: Batch LoRA operations across layers
- # Share computation, optimize memory utilization
+ @mx.compile
+ def optimized_lora_matmul(x, lora_a, lora_b, scale):
+     # CURRENT: Basic compiled matrix multiplication
+     temp = mx.matmul(x, lora_a.T)
+     result = mx.matmul(temp, lora_b.T)
+     return scale * result
+
+ # EVOLUTION TARGETS:
+ # - Fuse matrix operations
+ # - Optimize for specific tensor shapes
+ # - Reduce intermediate allocations
+ # - Vectorize computations
```
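
+ As one possible fusion, the scaled add can be folded into the second matmul with
+ `mx.addmm` (a sketch with a different signature than the kernel above; assumes
+ `mx.addmm` is available in your MLX version and that the caller also has the
+ base projection weight):
+ ```python
+ import mlx.core as mx
+
+ @mx.compile
+ def fused_lora_matmul(x, base_weight, lora_a, lora_b, scale):
+     # Sketch: mx.addmm computes alpha * (a @ b) + beta * c in one call,
+     # folding the LoRA scale and the residual add into a single op.
+     base_out = mx.matmul(x, base_weight.T)
+     temp = mx.matmul(x, lora_a.T)
+     return mx.addmm(base_out, temp, lora_b.T, alpha=float(scale), beta=1.0)
+ ```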

- **5. Memory-Efficient Loss Computation**
+ **3. optimized_lora_forward_pass Function** 🚀 INTEGRATION TARGET
```python
- # Standard: Full vocabulary materialization
- loss = cross_entropy(logits, targets)  # Memory: O(batch * seq * vocab)
-
- # Target: Chunked or online loss computation for large vocabularies
- # Reduce memory footprint during loss calculation
+ def optimized_lora_forward_pass(model, x, use_kernels=True):
+     # CURRENT: Iterates through model layers
+     for name, layer in model.named_modules():
+         if hasattr(layer, 'lora_a') and hasattr(layer, 'lora_b'):
+             # Apply optimized LoRA computation
+
+ # EVOLUTION TARGETS:
+ # - Batch multiple LoRA layers
+ # - Fuse activations with LoRA
+ # - Optimize layer traversal
+ # - Reduce function call overhead
```
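
+ A low-risk way to cut traversal overhead is to scan `named_modules()` once and cache
+ the result (a sketch; the cache must be rebuilt if layers are re-patched):
+ ```python
+ def collect_lora_layers(model):
+     # Sketch: one-time scan; per-step code iterates this list instead of
+     # re-traversing the module tree on every forward pass.
+     return [
+         (name, layer)
+         for name, layer in model.named_modules()
+         if hasattr(layer, "lora_a") and hasattr(layer, "lora_b")
+     ]
+ ```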

- **6. UNSLOTH-STYLE MLX KERNEL FUSION** 🎯 PRIMARY SPEED TARGET
+ **4. memory_efficient_loss_computation Function** 💾 MEMORY TARGET
```python
- # Standard: Separate operations
- x = mx.add(input, lora_out)
- x = activation_fn(x)
- x = mx.matmul(x, next_weight)
-
- # Target: Fused kernels using MLX primitives
- # Combine LoRA, activation, and next operation
- # Leverage mx.compile and mx.eval strategically
+ def memory_efficient_loss_computation(logits, targets, chunk_size=1024):
+     # CURRENT: Chunked loss for large vocabularies
+     if logits.shape[-1] <= chunk_size:
+         return nn.losses.cross_entropy(logits, targets, reduction="mean")
+     # Process in chunks...
+
+ # EVOLUTION TARGETS:
+ # - Optimize chunk size dynamically
+ # - Reduce memory allocations
+ # - Parallelize chunk processing
+ # - Smart caching strategies
```
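
+ One possible shape for the elided chunk loop is to chunk over tokens rather than the
+ vocabulary (a sketch; assumes `targets` holds class indices, and the memory win is
+ largest when paired with chunked logit computation):
+ ```python
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ def chunked_cross_entropy(logits, targets, chunk_size=1024):
+     # Sketch: accumulate per-chunk sums so the result matches the
+     # unchunked reduction="mean" loss.
+     flat_logits = logits.reshape(-1, logits.shape[-1])
+     flat_targets = targets.reshape(-1)
+     n_tokens = flat_targets.shape[0]
+     total = mx.array(0.0)
+     for start in range(0, n_tokens, chunk_size):
+         total = total + nn.losses.cross_entropy(
+             flat_logits[start : start + chunk_size],
+             flat_targets[start : start + chunk_size],
+             reduction="sum",
+         )
+     return total / n_tokens
+ ```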

- **7. Smart Gradient Accumulation**
+ **5. optimized_gradient_computation Function** 🧠 GRADIENT TARGET
```python
- # Standard: Individual gradient updates
- for batch in batches:
-     loss = forward(batch)
-     grads = backward(loss)
-     optimizer.update(grads)
-
- # Target: Accumulated updates with reduced sync points
- # Batch multiple LoRA layer updates together
+ def optimized_gradient_computation(loss, model, use_kernels=True):
+     # CURRENT: Basic compiled gradient computation
+     compiled_grad_fn = mx.compile(mx.grad(grad_fn))
+     return compiled_grad_fn(model)
+
+ # EVOLUTION TARGETS:
+ # - LoRA-specific gradient patterns
+ # - Accumulate gradients efficiently
+ # - Reduce gradient computation overhead
+ # - Smart gradient sharing
```
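
+ One way to keep gradients restricted to the LoRA parameters is to build the gradient
+ function with `nn.value_and_grad` (a sketch; `loss_fn(logits, targets)` is a stand-in,
+ and it assumes the base weights were frozen before the LoRA layers were added):
+ ```python
+ import mlx.nn as nn
+
+ def make_lora_grad_fn(model, loss_fn):
+     # Sketch: nn.value_and_grad differentiates only trainable parameters,
+     # which after freezing the base model are just the LoRA matrices.
+     def forward(inputs, targets):
+         return loss_fn(model(inputs), targets)
+
+     return nn.value_and_grad(model, forward)
+
+ # loss, grads = make_lora_grad_fn(model, loss_fn)(inputs, targets)
+ ```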

- # 🚀 UNSLOTH-INSPIRED OPTIMIZATION TECHNIQUES (Target 2x+ Speed Improvements)
+ **6. optimized_parameter_update Function** 🔄 UPDATE TARGET
+ ```python
+ @mx.compile
+ def optimized_parameter_update(params, grads, lr):
+     # CURRENT: Basic parameter update loop
+     updated_params = {}
+     for key in params:
+         if key in grads:
+             updated_params[key] = params[key] - lr * grads[key]
+     return updated_params
+
+ # EVOLUTION TARGETS:
+ # - Batch parameter updates
+ # - Vectorize updates
+ # - Optimize for LoRA structure
+ # - Reduce synchronization points
+ ```
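
+ For the batching/vectorization targets, one direction is replacing the per-key Python
+ loop with a single tree traversal (a sketch; assumes `grads` mirrors the structure of
+ `params`):
+ ```python
+ from mlx.utils import tree_map
+
+ def tree_sgd_update(params, grads, lr):
+     # Sketch: one tree_map pass over matching parameter/gradient trees.
+     return tree_map(lambda p, g: p - lr * g, params, grads)
+
+ # model.update(tree_sgd_update(model.trainable_parameters(), grads, lr))
+ ```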

- **🔥 Flash Attention Equivalents for MLX**: Fused attention computation patterns
- **⚡ Kernel Fusion**: Combine LoRA operations with activation functions
- **🧠 Smart Gradient Accumulation**: Batch gradient updates efficiently
- **⭐ Optimized MLX Operations**: Leverage mx.fast for critical paths
- **🚀 Parameter-Efficient Updates**: Minimize optimizer state overhead
- **💾 Memory Mapping**: Efficient tensor reuse and allocation patterns
- **🎯 Selective Computation**: Skip unnecessary ops based on LoRA rank/scale
- **🔧 Mixed Precision**: Smart FP16/FP32 usage for speed without loss
- Current baseline shows 1.57x memory improvement but only 1.01x speed.
- FOCUS: Discover speed optimizations like unsloth's 2-5x improvements!
+ # 🚀 PROVEN MLX OPTIMIZATION TECHNIQUES

+ **🔥 mx.compile Usage**: Leverage @mx.compile for hot computation paths
+ **⚡ Tensor Fusion**: Combine multiple operations into single kernels
+ **🧠 Memory Reuse**: Optimize tensor allocation and reuse patterns
+ **⭐ Vectorization**: Use MLX's SIMD capabilities effectively
+ **🚀 Batch Operations**: Process multiple items simultaneously
+ **💾 Smart Caching**: Cache computed values when beneficial
+ **🎯 Shape Optimization**: Optimize for common tensor shapes
+ **🔧 Pipeline Efficiency**: Reduce data movement and sync points (see the sketch below)
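
+ As a small example of the pipeline point above (a sketch; `loss_fn(model, inputs, targets)`
+ is a stand-in for the real loss function):
+ ```python
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ def run_steps(model, optimizer, loss_fn, batches):
+     # Sketch: one mx.eval per step keeps the lazy graph bounded while
+     # limiting synchronization to a single point per iteration.
+     loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+     for inputs, targets in batches:
+         loss, grads = loss_and_grad_fn(model, inputs, targets)
+         optimizer.update(model, grads)
+         mx.eval(model.parameters(), optimizer.state, loss)
+ ```
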
# 📊 SUCCESS METRICS
**Primary Metric**: Training Loss Convergence (MUST MATCH BASELINE ±1%)
- - Target: Same final loss as standard LoRA implementation
+ - Target: Same final loss as standard MLX-LM LoRA implementation
- Critical: Maintain numerical stability and gradient flow

**Secondary Metrics**: Efficiency Improvements
- Memory efficiency: 10%+ reduction in peak memory usage
- Training speed: 10%+ improvement in tokens/second
+ - Time efficiency: 10%+ reduction in training time
- Ideal: Both memory AND speed improvements

- # 🎖️ REAL-WORLD LORA OPTIMIZATION PATTERNS
+ # 🎖️ REALISTIC OPTIMIZATION EXPECTATIONS

Successful LoRA optimizations typically achieve:
- - **Memory reduction**: 15-30% through weight fusion and gradient optimization
- - **Speed improvement**: 10-25% through reduced kernel launches and better memory access
- - **Maintained convergence**: Critical for practical adoption
+ - **Memory reduction**: 10-30% through smart tensor management
+ - **Speed improvement**: 15-50% through kernel fusion and compilation
+ - **Maintained convergence**: Essential for practical adoption

- Your optimizations should target similar patterns adapted for MLX.
+ Your optimizations should target these realistic improvements for MLX.

# 🚫 CONSTRAINTS
- - Keep exact function signatures from initial_program.py
- - Maintain numerical correctness (loss must match baseline within 0.01)
+ - Keep exact function signatures and return values
+ - Maintain numerical correctness (loss must match baseline within 1%)
- Support all LoRA configs (ranks 8-64, any scale/dropout)
- MLX-only dependencies (mlx.core, mlx.nn, mlx.optimizers)
- - 🚨 CRITICAL: Concise evolution changes (under 35,000 chars total)
- - NO verbose comments - focus on algorithmic improvements
- - Prioritize SPEED over memory (we already have 1.57x memory gain)
- - Test mx.compile, mx.eval, kernel fusion, gradient accumulation patterns
-
- # 🔍 WHAT TO EVOLVE - TARGET UNSLOTH-STYLE 2x+ SPEED GAINS
-
- Focus on `evolved_lora_kernels` function. Prioritize SPEED optimizations:
-
- 1. **optimized_lora_fine_tuning**: Main training pipeline with kernel fusion
- 2. **optimized_training_loop**: Batch gradient accumulation like unsloth
- 3. **optimized_train_step**: Fused forward/backward with mx.compile
- 4. **optimized_linear_to_lora_layers**: Batched multi-layer LoRA application
- 5. **optimized_evaluate**: Fast inference with weight pre-computation
-
- 🎯 PRIMARY TARGETS FOR SPEED BREAKTHROUGH:
- - Leverage `mx.compile()` for hot paths (like unsloth's kernel compilation)
- - Use `mx.eval()` strategically to minimize sync points
- - Batch operations across multiple LoRA layers simultaneously
- - Pre-compute weights when beneficial (inference mode optimization)
- - Implement gradient accumulation patterns that reduce memory allocations
-
- Current Results: 1.57x memory ✅, 1.01x speed ❌
- Target: Discover 2-5x speed improvements while maintaining perfect convergence!
+ - 🚨 CRITICAL: Keep evolution changes concise (under 30,000 chars total)
+ - Focus on algorithmic improvements, not verbose comments
+ - Ensure kernels can be properly patched into models
+ - Test that optimizations work with real MLX-LM training
+
+ # 🔍 WHAT TO EVOLVE - FOCUS ON EVOLVE-BLOCK
+
+ **Primary Evolution Target: `evolved_lora_kernels()` function**
+
+ The EVOLVE-BLOCK contains 6 kernels that get injected into MLX-LM training:
+
+ 1. **OptimizedLoRALinear**: The core LoRA layer implementation
+ 2. **optimized_lora_matmul**: Compiled matrix multiplication kernel
+ 3. **optimized_lora_forward_pass**: Model forward pass optimization
+ 4. **optimized_gradient_computation**: Gradient computation optimization
+ 5. **optimized_parameter_update**: Parameter update optimization
+ 6. **memory_efficient_loss_computation**: Loss computation optimization
+
+ 🎯 **PRIMARY OPTIMIZATION STRATEGIES:**
+ - Add more @mx.compile decorators for hot paths
+ - Fuse multiple operations into single kernels
+ - Optimize memory access patterns and reuse
+ - Batch operations across multiple LoRA layers
+ - Pre-compute values when beneficial (inference optimization)
+ - Implement LoRA-specific optimizations based on LoRA's low-rank mathematical structure (see the sketch below)
+ - Reduce intermediate tensor allocations
+ - Optimize for common LoRA configurations (rank 8-64)
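
+ As one example of exploiting the low-rank structure (a sketch; shapes as in the
+ snippets above, `lora_a: (r, d_in)`, `lora_b: (d_out, r)`):
+ ```python
+ import mlx.core as mx
+
+ def rank_aware_lora_delta(x, lora_a, lora_b, scale):
+     # Sketch: pick the multiplication order by comparing FLOP counts.
+     # Low-rank path: tokens * r * (d_in + d_out); merged path:
+     # r * d_in * d_out to form the delta plus tokens * d_in * d_out to apply it.
+     tokens = x.size // x.shape[-1]
+     r, d_in = lora_a.shape
+     d_out = lora_b.shape[0]
+     if tokens * r * (d_in + d_out) <= (r + tokens) * d_in * d_out:
+         return scale * mx.matmul(mx.matmul(x, lora_a.T), lora_b.T)
+     merged = mx.matmul(lora_b, lora_a)  # (d_out, d_in)
+     return scale * mx.matmul(x, merged.T)
+ ```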
+
+ 🔬 **CURRENT STATUS:** Starting from basic working implementations
+ **TARGET:** Achieve 15-25% efficiency improvements while maintaining convergence
+
+ # ⚠️ CRITICAL EVOLUTION GUIDELINES
+
+ 1. **ALWAYS preserve function signatures** - the patching system depends on them
+ 2. **Test numerical correctness** - loss must converge to the same value as the baseline
+ 3. **Use MLX primitives effectively** - leverage mx.compile, mx.eval, etc.
+ 4. **Focus on realistic optimizations** - don't over-engineer
+ 5. **Maintain code clarity** - optimizations should be understandable
+ 6. **Ensure kernel injection works** - test that patches apply correctly
+
+ **Evolution Success = Same Loss + Better Performance + Working Integration**

num_top_programs: 6
num_diverse_programs: 4

# Database configuration for LoRA optimization
database:
db_path: "./openevolve_output/program_db"
- population_size: 80 # Larger population for more diverse explorations
+ population_size: 80
archive_size: 40
num_islands: 4
- elite_selection_ratio: 0.20 # Less elite pressure, more exploration
- exploitation_ratio: 0.6 # Balanced exploration for breakthroughs
+ elite_selection_ratio: 0.20
+ exploitation_ratio: 0.6
exploration_ratio: 0.4

# Evaluator configuration
evaluator:
- timeout: 1200 # Longer timeout for real LoRA training
+ timeout: 1200
parallel_evaluations: 1

# Evolution settings
diff_based_evolution: true
allow_full_rewrites: false
- max_code_length: 45000 # Encourage concise, focused optimizations
+ max_code_length: 45000