Skip to content

Commit 1b70252

Browse files
committed
remove labml_helpers dependency: replace Module with nn.Module
1 parent 90e21b5 commit 1b70252

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+209
-301
lines changed

labml_nn/activations/fta/experiment.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from labml import experiment
2727
from labml.configs import option
28-
from labml_helpers.module import Module
2928
from labml_nn.activations.fta import FTA
3029
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
3130
from labml_nn.transformers import MultiHeadAttention, TransformerLayer
@@ -65,7 +64,7 @@ def forward(self, x: torch.Tensor):
6564
return self.layer2(x)
6665

6766

68-
class AutoregressiveTransformer(Module):
67+
class AutoregressiveTransformer(nn.Module):
6968
"""
7069
## Auto-Regressive model
7170

labml_nn/activations/swish.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import torch
22
from torch import nn
33

4-
from labml_helpers.module import Module
54

65

7-
class Swish(Module):
6+
class Swish(nn.Module):
87
def __init__(self):
98
super().__init__()
109
self.sigmoid = nn.Sigmoid()

labml_nn/adaptive_computation/ponder_net/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,9 @@
6565
import torch
6666
from torch import nn
6767

68-
from labml_helpers.module import Module
6968

7069

71-
class ParityPonderGRU(Module):
70+
class ParityPonderGRU(nn.Module):
7271
"""
7372
## PonderNet with GRU for Parity Task
7473
@@ -177,7 +176,7 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te
177176
return torch.stack(p), torch.stack(y), p_m, y_m
178177

179178

180-
class ReconstructionLoss(Module):
179+
class ReconstructionLoss(nn.Module):
181180
"""
182181
## Reconstruction loss
183182
@@ -213,7 +212,7 @@ def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
213212
return total_loss
214213

215214

216-
class RegularizationLoss(Module):
215+
class RegularizationLoss(nn.Module):
217216
"""
218217
## Regularization loss
219218

labml_nn/capsule_networks/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,8 @@
3333
import torch.nn.functional as F
3434
import torch.utils.data
3535

36-
from labml_helpers.module import Module
3736

38-
39-
class Squash(Module):
37+
class Squash(nn.Module):
4038
"""
4139
## Squash
4240
@@ -70,7 +68,7 @@ def forward(self, s: torch.Tensor):
7068
return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))
7169

7270

73-
class Router(Module):
71+
class Router(nn.Module):
7472
"""
7573
## Routing Algorithm
7674
@@ -133,7 +131,7 @@ def forward(self, u: torch.Tensor):
133131
return v
134132

135133

136-
class MarginLoss(Module):
134+
class MarginLoss(nn.Module):
137135
"""
138136
## Margin loss for class existence
139137
@@ -153,6 +151,7 @@ class MarginLoss(Module):
153151
The $\lambda$ down-weighting is used to stop the length of all capsules from
154152
falling during the initial phase of training.
155153
"""
154+
156155
def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
157156
super().__init__()
158157

labml_nn/capsule_networks/mnist.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@
2121
from labml.configs import option
2222
from labml_helpers.datasets.mnist import MNISTConfigs
2323
from labml_helpers.metrics.accuracy import AccuracyDirect
24-
from labml_helpers.module import Module
2524
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
2625
from labml_nn.capsule_networks import Squash, Router, MarginLoss
2726

2827

29-
class MNISTCapsuleNetworkModel(Module):
28+
class MNISTCapsuleNetworkModel(nn.Module):
3029
"""
3130
## Model for classifying MNIST digits
3231
"""

labml_nn/conv_mixer/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,10 @@
3636
import torch
3737
from torch import nn
3838

39-
from labml_helpers.module import Module
4039
from labml_nn.utils import clone_module_list
4140

4241

43-
class ConvMixerLayer(Module):
42+
class ConvMixerLayer(nn.Module):
4443
"""
4544
<a id="ConvMixerLayer"></a>
4645
@@ -96,7 +95,7 @@ def forward(self, x: torch.Tensor):
9695
return x
9796

9897

99-
class PatchEmbeddings(Module):
98+
class PatchEmbeddings(nn.Module):
10099
"""
101100
<a id="PatchEmbeddings"></a>
102101
@@ -136,7 +135,7 @@ def forward(self, x: torch.Tensor):
136135
return x
137136

138137

139-
class ClassificationHead(Module):
138+
class ClassificationHead(nn.Module):
140139
"""
141140
<a id="ClassificationHead"></a>
142141
@@ -169,7 +168,7 @@ def forward(self, x: torch.Tensor):
169168
return x
170169

171170

172-
class ConvMixer(Module):
171+
class ConvMixer(nn.Module):
173172
"""
174173
## ConvMixer
175174

labml_nn/diffusion/ddpm/unet.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@
2727
import torch
2828
from torch import nn
2929

30-
from labml_helpers.module import Module
3130

32-
33-
class Swish(Module):
31+
class Swish(nn.Module):
3432
"""
3533
### Swish activation function
3634
@@ -83,7 +81,7 @@ def forward(self, t: torch.Tensor):
8381
return emb
8482

8583

86-
class ResidualBlock(Module):
84+
class ResidualBlock(nn.Module):
8785
"""
8886
### Residual block
8987
@@ -140,7 +138,7 @@ def forward(self, x: torch.Tensor, t: torch.Tensor):
140138
return h + self.shortcut(x)
141139

142140

143-
class AttentionBlock(Module):
141+
class AttentionBlock(nn.Module):
144142
"""
145143
### Attention block
146144
@@ -208,7 +206,7 @@ def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
208206
return res
209207

210208

211-
class DownBlock(Module):
209+
class DownBlock(nn.Module):
212210
"""
213211
### Down block
214212
@@ -229,7 +227,7 @@ def forward(self, x: torch.Tensor, t: torch.Tensor):
229227
return x
230228

231229

232-
class UpBlock(Module):
230+
class UpBlock(nn.Module):
233231
"""
234232
### Up block
235233
@@ -252,7 +250,7 @@ def forward(self, x: torch.Tensor, t: torch.Tensor):
252250
return x
253251

254252

255-
class MiddleBlock(Module):
253+
class MiddleBlock(nn.Module):
256254
"""
257255
### Middle block
258256
@@ -305,7 +303,7 @@ def forward(self, x: torch.Tensor, t: torch.Tensor):
305303
return self.conv(x)
306304

307305

308-
class UNet(Module):
306+
class UNet(nn.Module):
309307
"""
310308
## U-Net
311309
"""

labml_nn/experiments/cifar10.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from labml import lab
1515
from labml.configs import option
1616
from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
17-
from labml_helpers.module import Module
1817
from labml_nn.experiments.mnist import MNISTConfigs
1918

2019

@@ -67,7 +66,7 @@ def cifar10_valid_no_augment():
6766
]))
6867

6968

70-
class CIFAR10VGGModel(Module):
69+
class CIFAR10VGGModel(nn.Module):
7170
"""
7271
### VGG model for CIFAR-10 classification
7372
"""

labml_nn/experiments/mnist.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import torch.nn as nn
1212
import torch.utils.data
13-
from labml_helpers.module import Module
1413

1514
from labml import tracker
1615
from labml.configs import option
@@ -34,7 +33,7 @@ class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
3433
device: torch.device = DeviceConfigs()
3534

3635
# Classification model
37-
model: Module
36+
model: nn.Module
3837
# Number of epochs to train for
3938
epochs: int = 10
4039

labml_nn/experiments/nlp_autoregression.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
2121
from labml_helpers.device import DeviceConfigs
2222
from labml_helpers.metrics.accuracy import Accuracy
23-
from labml_helpers.module import Module
2423
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
2524
from labml_nn.optimizers.configs import OptimizerConfigs
2625

2726

28-
class CrossEntropyLoss(Module):
27+
class CrossEntropyLoss(nn.Module):
2928
"""
3029
### Cross entropy loss
3130
"""
@@ -54,7 +53,7 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
5453
device: torch.device = DeviceConfigs()
5554

5655
# Autoregressive model
57-
model: Module
56+
model: nn.Module
5857
# Text dataset
5958
text: TextDataset
6059
# Batch size

0 commit comments

Comments
 (0)