@@ -32,6 +32,45 @@ prompt:
❌ `grads.astype()` when grads is a dict - Only works on mx.array (see the `tree_map` sketch after this list)
❌ Any JAX/PyTorch tree utilities - MLX doesn't have these
❌ Invented `mlx.utils.tree_*` functions - `mlx.utils` provides `tree_map`, `tree_flatten`, and `tree_unflatten`; don't assume others
+ ❌ `model.update_parameters()` - MLX models don't have this method; use `optimizer.update()` (see below)
+ ❌ `float(loss_value)` when the loss might be a tuple - Extract the scalar first element (see below)
+ ❌ `batch[:, :-1]` on 1D arrays - Check array dimensions first
+ ❌ Assuming tensor shapes without verification
+
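+ A minimal sketch of the dict-safe way to cast a whole gradient tree, assuming `grads` is the nested dict returned by `mx.value_and_grad`; `tree_map` is the `mlx.utils` helper for mapping over such trees:
+ ```python
+ import mlx.core as mx
+ from mlx.utils import tree_map
+
+ # grads is a nested dict whose leaves are mx.array; astype() exists only on
+ # the leaves, so map it over the tree instead of calling grads.astype().
+ grads_fp16 = tree_map(lambda g: g.astype(mx.float16), grads)
+ ```
+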
+ **CRITICAL MLX VALUE AND SHAPE HANDLING:**
+
+ 🚨 **Loss Value Extraction:**
+ ```python
+ # WRONG: float(loss_value) when loss_value might be a tuple
+ # WRONG: float(mx.eval(loss_value)) - mx.eval() returns None, not an array
+ # CORRECT: Handle MLX loss properly
+ if isinstance(loss_value, tuple):
+     loss_scalar = float(loss_value[0])  # the scalar loss is the first element
+ elif isinstance(loss_value, mx.array):
+     loss_scalar = loss_value.item()  # .item() evaluates and converts to a Python scalar
+ else:
+     loss_scalar = float(loss_value)
+ ```
+
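+ For context, a sketch of where a tuple loss comes from: MLX has no `has_aux`, but `value_and_grad` accepts a function returning a tuple whose first element is the scalar loss and differentiates with respect to that element (`model`, `inputs`, and `targets` are placeholders):
+ ```python
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ def loss_fn(model, inputs, targets):
+     logits = model(inputs)
+     loss = nn.losses.cross_entropy(logits, targets, reduction="mean")
+     return loss, logits  # auxiliary output; the gradient is taken on `loss`
+
+ loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+ (loss, logits), grads = loss_and_grad_fn(model, inputs, targets)  # value is the tuple
+ ```
+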
+ 🚨 **Array Indexing Safety:**
+ ```python
+ # WRONG: batch[:, :-1] without checking dimensions
+ # CORRECT: Check shape before indexing
+ if batch.ndim >= 2:
+     inputs = batch[:, :-1]
+     targets = batch[:, 1:]
+ else:
+     # Handle 1D case or reshape
+     inputs = batch[:-1]
+     targets = batch[1:]
+ ```
+
+ 🚨 **Model Parameter Updates:**
+ ```python
+ # WRONG: model.update_parameters(new_params)
+ # CORRECT: Use optimizer.update()
+ optimizer.update(model, grads)
+ mx.eval(model.parameters(), optimizer.state)
+ ```
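+
+ Putting the three rules together - a minimal sketch of one training step, assuming `model`, `optimizer`, and `batch` already exist (all names here are placeholders, not a fixed API):
+ ```python
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ def loss_fn(model, inputs, targets):
+     return nn.losses.cross_entropy(model(inputs), targets, reduction="mean")
+
+ def train_step(model, optimizer, batch):
+     # Shape-safe next-token split (see "Array Indexing Safety" above)
+     if batch.ndim >= 2:
+         inputs, targets = batch[:, :-1], batch[:, 1:]
+     else:
+         inputs, targets = batch[:-1], batch[1:]
+
+     loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+     loss, grads = loss_and_grad_fn(model, inputs, targets)
+
+     optimizer.update(model, grads)                # apply gradients via the optimizer
+     mx.eval(model.parameters(), optimizer.state)  # force the lazy computation
+     return loss.item()                            # Python scalar for logging
+ ```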
❌ `mx.value_and_grad(fn, has_aux=True)` - has_aux parameter does NOT exist in MLX
❌ `mx.value_and_grad(fn, **kwargs)` - No keyword arguments supported except argnums/argnames
❌ Assuming `mx.eval()` returns arrays - It evaluates its arguments in place and returns None
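+
+ A two-line illustration of that last point, assuming `loss` is a lazy `mx.array`:
+ ```python
+ result = mx.eval(loss)  # evaluates `loss` in place; `result` is None
+ value = loss.item()     # correct: materialize the array, then convert to a Python scalar
+ ```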