
Commit 93336c1

Update config.yaml
1 parent b99dfea commit 93336c1

File tree

1 file changed: +39 −0 lines changed


examples/mlx_finetuning_optimization/config.yaml

Lines changed: 39 additions & 0 deletions
@@ -32,6 +32,45 @@ prompt:
❌ `grads.astype()` when grads is a dict - Only works on mx.array
❌ Any JAX/PyTorch tree utilities - MLX doesn't have these
❌ `mlx.utils.tree_*` functions - These don't exist
❌ `model.update_parameters()` - MLX models don't have this method
❌ `float(loss_tuple)` - Loss might be tuple, extract properly
❌ `batch[:, :-1]` on 1D arrays - Check array dimensions first
❌ Assuming tensor shapes without verification
**CRITICAL MLX VALUE AND SHAPE HANDLING:**

🚨 **Loss Value Extraction:**
```python
# WRONG: float(loss_value) when loss_value might be a tuple
# CORRECT: handle each case explicitly
if isinstance(loss_value, tuple):
    loss_scalar = float(loss_value[0])  # scalar loss is the first element
elif isinstance(loss_value, mx.array):
    mx.eval(loss_value)                 # evaluate first; mx.eval() can return None
    loss_scalar = loss_value.item()     # then extract the Python scalar
else:
    loss_scalar = float(loss_value)
```
53+
54+
🚨 **Array Indexing Safety:**
```python
# WRONG: batch[:, :-1] without checking dimensions
# CORRECT: check the shape before indexing
if batch.ndim >= 2:
    inputs = batch[:, :-1]
    targets = batch[:, 1:]
else:
    # Handle the 1D case (or reshape to 2D first)
    inputs = batch[:-1]
    targets = batch[1:]
```
66+
67+
🚨 **Model Parameter Updates:**
```python
# WRONG: model.update_parameters(new_params)
# CORRECT: use optimizer.update()
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
```
❌ `mx.value_and_grad(fn, has_aux=True)` - has_aux parameter does NOT exist in MLX
❌ `mx.value_and_grad(fn, **kwargs)` - No keyword arguments supported except argnums/argnames
❌ Assuming `mx.eval()` always returns arrays - Can return None
