@@ -32,6 +32,8 @@ prompt:
❌ `grads.astype()` when grads is a dict - Only works on mx.array
❌ Any JAX/PyTorch tree utilities - MLX doesn't have these
❌ `mlx.utils.tree_*` functions - These don't exist
+ ❌ `mx.value_and_grad(fn, has_aux=True)` - has_aux parameter does NOT exist in MLX
+ ❌ `mx.value_and_grad(fn, **kwargs)` - No keyword arguments supported except argnums/argnames
❌ Assuming `mx.eval()` always returns arrays - Can return None
❌ Modulo operations without checking for zero divisors
❌ Assuming trainer attributes exist without checking
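The only keyword arguments that do exist are `argnums`/`argnames`. A minimal sketch of the `argnums` form on a plain scalar-valued function (the function and values below are illustrative only, not from this repo):

```python
import mlx.core as mx

def f(x, y):
    # Scalar-valued function of two arrays
    return (x * y).sum()

x = mx.array([1.0, 2.0, 3.0])
y = mx.array([4.0, 5.0, 6.0])

# Gradient with respect to the first positional argument only
value, dx = mx.value_and_grad(f, argnums=0)(x, y)

# Gradients with respect to both positional arguments
value, grads = mx.value_and_grad(f, argnums=[0, 1])(x, y)
dx, dy = grads  # one gradient per requested argument
```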
@@ -68,6 +70,35 @@ prompt:
return grads
```

+ ✅ **Value and Grad Operations:**
+ ```python
+ # CORRECT: Simple value_and_grad usage
+ loss_value, grads = mx.value_and_grad(loss_fn)(model)
+
+ # CORRECT: If you need multiple return values from loss_fn, handle them separately
+ def loss_fn(model):
+     logits = model(inputs)
+     # Reduce to a scalar: value_and_grad needs a scalar-valued function
+     loss = nn.losses.cross_entropy(logits, targets).mean()
+     # Return only the loss (not a tuple with aux data)
+     return loss
+
+ loss_value, grads = mx.value_and_grad(loss_fn)(model)
+
+ # WRONG: mx.value_and_grad(loss_fn, has_aux=True)(model)  # has_aux not supported
+ # WRONG: (loss, aux), grads = mx.value_and_grad(loss_fn, has_aux=True)(model)
+
+ # CORRECT: If you need auxiliary data, compute it separately
+ def loss_fn(model):
+     logits = model(inputs)
+     loss = nn.losses.cross_entropy(logits, targets).mean()
+     return loss
+
+ loss_value, grads = mx.value_and_grad(loss_fn)(model)
+ # Compute auxiliary data separately if needed
+ logits = model(inputs)  # Recompute for aux data
+ accuracy = compute_accuracy(logits, targets)
+ ```
+
✅ **Memory Management:**
```python
# Use mx.eval() to materialize computations
@@ -150,6 +181,8 @@ prompt:
8. ✓ Check object attributes exist before accessing
9. ✓ Handle None and empty arrays gracefully
10. ✓ Use safe fallbacks for all operations (see the sketch after this checklist)
+ 11. ✓ mx.value_and_grad() used without has_aux parameter
+ 12. ✓ Loss functions return single scalar values, not tuples
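A minimal sketch of what checks 8-10 can look like in practice (the `trainer`, `loss`, and `log_interval` names are placeholders, not from this codebase):

```python
# Defensive access with safe fallbacks (illustrative names)
step = getattr(trainer, "step", 0)                      # attribute may be missing
interval = max(getattr(trainer, "log_interval", 1), 1)  # avoid modulo by zero

if loss is not None and loss.size > 0:                  # handle None / empty arrays
    loss_val = float(loss.mean())
else:
    loss_val = 0.0                                       # safe fallback

if step % interval == 0:
    print(f"step {step}: loss {loss_val:.4f}")
```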
**PRIMARY GOAL: Discover memory-efficient patterns that enable faster, lower-memory fine-tuning on Mac hardware**
@@ -204,6 +237,24 @@ prompt:
actual_memory = process.memory_info().rss / 1024 / 1024
```

+ ❌ **value_and_grad() incompatible function arguments**
+ ```python
+ # WRONG: Using the JAX-style has_aux parameter
+ (scaled_loss_val, unscaled_loss_val), grads = mx.value_and_grad(loss_fn, has_aux=True)(model)
+
+ # RIGHT: MLX only supports the simple value_and_grad form
+ loss_value, grads = mx.value_and_grad(loss_fn)(model)
+
+ # If you need a scaled loss, handle it in the loss function itself:
+ def loss_fn(model):
+     logits = model(inputs)
+     # Reduce to a scalar: value_and_grad needs a scalar-valued function
+     loss = nn.losses.cross_entropy(logits, targets).mean()
+     # Scale inside the function if needed
+     return loss / max(total_accumulation_steps, 1)
+
+ loss_value, grads = mx.value_and_grad(loss_fn)(model)
+ ```
+
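If the unscaled loss is still needed (e.g. for logging), it can be recovered outside the transformed function; a small sketch reusing `total_accumulation_steps` from the example above:

```python
loss_value, grads = mx.value_and_grad(loss_fn)(model)
# Undo the scaling for logging only; the gradients keep the scaled loss
unscaled_loss = loss_value * max(total_accumulation_steps, 1)
```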
❌ **'NoneType' object is not subscriptable**
```python
# WRONG: loss_value = mx.eval(loss)[0]  # mx.eval() might return None