Commit b99dfea

Update config.yaml
1 parent f430c1a commit b99dfea


examples/mlx_finetuning_optimization/config.yaml

Lines changed: 51 additions & 0 deletions
@@ -32,6 +32,8 @@ prompt:
❌ `grads.astype()` when grads is a dict - Only works on mx.array
❌ Any JAX/PyTorch tree utilities - MLX doesn't have these
❌ `mlx.utils.tree_*` functions - These don't exist
❌ `mx.value_and_grad(fn, has_aux=True)` - has_aux parameter does NOT exist in MLX
❌ `mx.value_and_grad(fn, **kwargs)` - No keyword arguments supported except argnums/argnames
❌ Assuming `mx.eval()` always returns arrays - Can return None
❌ Modulo operations without checking for zero divisors
❌ Assuming trainer attributes exist without checking (see the sketch after this list)
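
The last two items above call for a defensive style; a minimal sketch of that pattern, assuming plain Python (the step counter, `eval_interval`, and `Trainer` object are illustrative stand-ins, not names from the config):

```python
# Guard modulo operations against a zero divisor before using them.
step, eval_interval = 10, 0  # illustrative stand-ins
if eval_interval > 0 and step % eval_interval == 0:
    print("run evaluation here")

# Check that attributes exist (or fall back) instead of assuming they do.
class Trainer: ...  # illustrative stand-in
trainer = Trainer()
grad_clip = getattr(trainer, "max_grad_norm", None)
```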
@@ -68,6 +70,35 @@ prompt:
    return grads
```

✅ **Value and Grad Operations:**
```python
# CORRECT: Simple value_and_grad usage
loss_value, grads = mx.value_and_grad(loss_fn)(model)

# CORRECT: If you need multiple return values from loss_fn, handle separately
def loss_fn(model):
    logits = model(inputs)
    loss = nn.losses.cross_entropy(logits, targets)
    # Return only the loss (not a tuple with aux data)
    return loss

loss_value, grads = mx.value_and_grad(loss_fn)(model)

# WRONG: mx.value_and_grad(loss_fn, has_aux=True)(model)  # has_aux not supported
# WRONG: (loss, aux), grads = mx.value_and_grad(loss_fn, has_aux=True)(model)

# CORRECT: If you need auxiliary data, compute it separately
def loss_fn(model):
    logits = model(inputs)
    loss = nn.losses.cross_entropy(logits, targets)
    return loss

loss_value, grads = mx.value_and_grad(loss_fn)(model)
# Compute auxiliary data separately if needed
logits = model(inputs)  # Recompute for aux data
accuracy = compute_accuracy(logits, targets)
```

✅ **Memory Management:**
```python
# Use mx.eval() to materialize computations
@@ -150,6 +181,8 @@ prompt:
8. ✓ Check object attributes exist before accessing
9. ✓ Handle None and empty arrays gracefully
10. ✓ Use safe fallbacks for all operations
11. ✓ Call mx.value_and_grad() without the has_aux parameter
12. ✓ Return a single value from loss functions, not a tuple

**PRIMARY GOAL: Discover memory-efficient patterns that enable faster, lower-memory fine-tuning on Mac hardware**

@@ -204,6 +237,24 @@ prompt:
actual_memory = process.memory_info().rss / 1024 / 1024
```
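
The `process` handle in the snippet above is presumably a psutil process object; a minimal sketch of that measurement, assuming psutil is available (the helper name `current_memory_mb` is illustrative):

```python
import psutil

def current_memory_mb() -> float:
    # Resident set size of this process, converted from bytes to MB.
    return psutil.Process().memory_info().rss / 1024 / 1024

actual_memory = current_memory_mb()
```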

❌ **value_and_grad() incompatible function arguments**
```python
# WRONG: Using JAX-style has_aux parameter
(scaled_loss_val, unscaled_loss_val), grads = mx.value_and_grad(loss_fn, has_aux=True)(model)

# RIGHT: MLX only supports simple value_and_grad
loss_value, grads = mx.value_and_grad(loss_fn)(model)

# If you need scaled loss, handle it in the loss function itself:
def loss_fn(model):
    logits = model(inputs)
    loss = nn.losses.cross_entropy(logits, targets)
    # Scale inside the function if needed
    return loss / max(total_accumulation_steps, 1)

loss_value, grads = mx.value_and_grad(loss_fn)(model)
```
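
Because the loss above is already divided by `total_accumulation_steps`, one way it might be used is a plain accumulation loop that sums gradients without any tree utilities; a minimal sketch (the `add_grads` helper is illustrative, and in practice each iteration would use a different micro-batch):

```python
def add_grads(a, b):
    # Recursively sum two gradient structures (nested dicts/lists of mx.array).
    if isinstance(a, dict):
        return {k: add_grads(a[k], b[k]) for k in a}
    if isinstance(a, (list, tuple)):
        return type(a)(add_grads(x, y) for x, y in zip(a, b))
    return a + b

accumulated = None
for _ in range(total_accumulation_steps):
    loss_value, grads = mx.value_and_grad(loss_fn)(model)
    accumulated = grads if accumulated is None else add_grads(accumulated, grads)

mx.eval(accumulated)  # materialize the summed gradients before the optimizer update
```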
❌ **'NoneType' object is not subscriptable**
```python
# WRONG: loss_value = mx.eval(loss)[0]  # mx.eval() might return None
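# RIGHT (a possible fix): mx.eval() evaluates in place, so read the scalar afterwards
mx.eval(loss)
loss_value = loss.item()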
