Skip to content

Commit ba79095

Browse files
Merge pull request #57 from Jeyhun1/sampling
Adaptive sampling
2 parents 2e89486 + bcfa93d commit ba79095

17 files changed

+1066
-137
lines changed

PINNFramework/Adaptive_Sampler.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import torch
2+
import numpy as np
3+
from .Sampler import Sampler
4+
5+
6+
class AdaptiveSampler(Sampler):
7+
def __init__(self, n_seed, model, pde, device = torch.device("cuda")):
8+
"""
9+
Constructor of the AdaptiveSampler class
10+
11+
Args:
12+
n_seed (int): the number of seed points for adaptive sampling.
13+
model: is the model which is trained to represent the underlying PDE.
14+
pde (function): function that represents residual of the PDE.
15+
device (torch.device): "cuda" or "cpu".
16+
"""
17+
self.n_seed = n_seed
18+
self.model = model
19+
self.pde = pde
20+
self.device = device
21+
super(AdaptiveSampler, self).__init__()
22+
23+
def sample(self, lb, ub, n):
24+
"""
25+
Generate a tuple of 'n' sampled points in [lb,ub] and corresponding weights.
26+
27+
Args:
28+
lb (numpy.ndarray): lower bound of the domain.
29+
ub (numpy.ndarray): upper bound of the domain.
30+
n (int): the number of sampled points.
31+
"""
32+
33+
torch.manual_seed(42)
34+
np.random.seed(42)
35+
36+
lb = lb.reshape(1,-1)
37+
ub = ub.reshape(1,-1)
38+
39+
dimension = lb.shape[1]
40+
xs = np.random.uniform(lb, ub, size=(self.n_seed, dimension))
41+
42+
# collocation points
43+
xf = np.random.uniform(lb, ub, size=(n, dimension))
44+
45+
# make the points into tensors
46+
xf = torch.tensor(xf).float().to(self.device)
47+
xs = torch.tensor(xs).float().to(self.device)
48+
49+
# prediction with seed points
50+
xs.requires_grad = True
51+
prediction_seed = self.model(xs)
52+
53+
# pde residual with seed points
54+
loss_seed = self.pde(xs, prediction_seed)
55+
losses_xf = torch.zeros_like(xf)
56+
57+
# Compute the 2-norm distance between seed points and collocation points
58+
dist = torch.cdist(xf, xs, p=2)
59+
60+
# obtain the smallest element of the given tensor
61+
knn = dist.topk(1, largest=False)
62+
63+
# assign the seed loss to the loss of the closest collocation points
64+
losses_xf = loss_seed[knn.indices[:, 0]]
65+
66+
# apply softmax function
67+
q_model = torch.softmax(losses_xf, dim=0)
68+
69+
# obtain 'n' indices sampled from the multinomial distribution
70+
indicies_new = torch.multinomial(q_model[:, 0], n, replacement=True)
71+
72+
# collocation points and corresponding weights
73+
xf = xf[indicies_new]
74+
weight = q_model[indicies_new].detach()
75+
weight = torch.mean(weight, 1, True)
76+
77+
return xf, weight

PINNFramework/Geometry.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch
2+
import numpy as np
3+
from torch.utils.data import Dataset
4+
5+
class Geometry(Dataset):
6+
def __init__(self, lb, ub, n_points, batch_size, sampler):
7+
"""
8+
Constructor of the Geometry class
9+
10+
Args:
11+
lb (numpy.ndarray): lower bound of the domain.
12+
ub (numpy.ndarray): upper bound of the domain.
13+
n_points (int): the number of sampled points.
14+
batch_size (int): batch size
15+
sampler: instance of the Sampler class.
16+
"""
17+
self.lb = lb
18+
self.ub = ub
19+
self.n_points = n_points
20+
self.batch_size = batch_size
21+
self.sampler = sampler
22+
23+
24+
def __getitem__(self, idx):
25+
raise NotImplementedError("Subclasses should implement '__getitem__' method")
26+
27+
def __len__(self):
28+
raise NotImplementedError("Subclasses should implement '__len__' method")
29+

PINNFramework/HPMLoss.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
from .PDELoss import PDELoss
22

33
class HPMLoss(PDELoss):
4-
def __init__(self, dataset, name, hpm_input, hpm_model, norm='L2', weight=1.):
4+
def __init__(self, geometry, name, hpm_input, hpm_model, norm='L2', weight=1.):
55
"""
66
Constructor of the HPM loss
7+
78
Args:
8-
dataset (torch.utils.Dataset): dataset that provides the residual points
9+
geometry: instance of the geometry class that defines the domain
910
hpm_input(function): function that calculates the needed input for the HPM model. The hpm_input function
1011
should return a list of tensors, where the last entry is the time_derivative
1112
hpm_model (torch.nn.Module): model for the HPM, represents the underlying PDE
1213
norm: Norm used for calculation PDE loss
1314
weight: Weighting for the loss term
1415
"""
15-
super(HPMLoss, self).__init__(dataset, None, name, norm='L2', weight=1.)
16+
super(HPMLoss, self).__init__(geometry, None, name, norm='L2', weight=1.)
1617
self.hpm_input = hpm_input
1718
self.hpm_model = hpm_model
1819

PINNFramework/LHS_Sampler.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import torch
2+
import numpy as np
3+
from pyDOE import lhs
4+
from .Sampler import Sampler
5+
6+
7+
class LHSSampler(Sampler):
8+
def __init__(self):
9+
"""
10+
Constructor of the LHSSampler class
11+
"""
12+
super(LHSSampler, self).__init__()
13+
14+
def sample(self, lb, ub, n):
15+
"""Generate 'n' number of sample points in [lb,ub]
16+
17+
Args:
18+
lb (numpy.ndarray): lower bound of the domain.
19+
ub (numpy.ndarray): upper bound of the domain.
20+
n (int): the number of sampled points.
21+
"""
22+
23+
torch.manual_seed(42)
24+
np.random.seed(42)
25+
26+
lb = lb.reshape(1,-1)
27+
ub = ub.reshape(1,-1)
28+
29+
dimension = lb.shape[1]
30+
xf = lb + (ub - lb) * lhs(dimension, n)
31+
return torch.tensor(xf).float()

PINNFramework/ND_Cube.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
import torch
3+
from .Geometry import Geometry
4+
5+
class NDCube(Geometry):
6+
def __init__(self, lb, ub, n_points, batch_size, sampler):
7+
"""
8+
Constructor of the NDCube class
9+
10+
Args:
11+
lb (numpy.ndarray): lower bound of the domain.
12+
ub (numpy.ndarray): upper bound of the domain.
13+
n_points (int): the number of sampled points.
14+
batch_size (int): batch size
15+
sampler: instance of the Sampler class.
16+
"""
17+
super(NDCube, self).__init__(lb, ub, n_points, batch_size, sampler)
18+
19+
20+
def __getitem__(self, idx):
21+
"""
22+
Returns data at given index
23+
Args:
24+
idx (int)
25+
"""
26+
self.x = self.sampler.sample(self.lb,self.ub, self.batch_size)
27+
28+
if type(self.x) is tuple:
29+
x, w = self.x
30+
return torch.cat((x, w), 1)
31+
else:
32+
return self.x
33+
34+
def __len__(self):
35+
"""Length of the dataset"""
36+
return self.n_points // self.batch_size

PINNFramework/PDELoss.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,22 @@
22
from torch import Tensor as Tensor
33
from torch.nn import Module as Module
44
from .LossTerm import LossTerm
5+
from .Adaptive_Sampler import AdaptiveSampler
56

67

78
class PDELoss(LossTerm):
8-
def __init__(self, dataset, pde, name, norm='L2', weight=1.):
9+
def __init__(self, geometry, pde, name, norm='L2', weight=1.):
910
"""
1011
Constructor of the PDE Loss
1112
1213
Args:
13-
dataset (torch.utils.Dataset): dataset that provides the residual points
14+
geometry: instance of the geometry class that defines the domain
1415
pde (function): function that represents residual of the PDE
1516
norm: Norm used for calculation PDE loss
1617
weight: Weighting for the loss term
1718
"""
18-
super(PDELoss, self).__init__(dataset, name, norm, weight)
19-
self.dataset = dataset
19+
super(PDELoss, self).__init__(geometry, name, norm, weight)
20+
self.geometry = geometry
2021
self.pde = pde
2122

2223
def __call__(self, x: Tensor, model: Module, **kwargs):
@@ -26,8 +27,17 @@ def __call__(self, x: Tensor, model: Module, **kwargs):
2627
x: residual points
2728
model: model that predicts the solution of the PDE
2829
"""
30+
31+
if isinstance(self.geometry.sampler, AdaptiveSampler):
32+
w = x[:,-1:]
33+
x = x[:,:-1]
34+
2935
x.requires_grad = True # setting requires grad to true in order to calculate
3036
u = model.forward(x)
3137
pde_residual = self.pde(x, u, **kwargs)
32-
zeros = torch.zeros(pde_residual.shape, device=pde_residual.device)
33-
return self.norm(pde_residual, zeros)
38+
39+
if isinstance(self.geometry.sampler, AdaptiveSampler):
40+
return 1 / self.geometry.batch_size * torch.mean(1 / w * pde_residual ** 2)
41+
else:
42+
zeros = torch.zeros(pde_residual.shape, device=pde_residual.device)
43+
return self.norm(pde_residual, zeros)

PINNFramework/PINN.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .PDELoss import PDELoss
1111
from .JoinedDataset import JoinedDataset
1212
from .HPMLoss import HPMLoss
13+
from .Adaptive_Sampler import AdaptiveSampler
1314
from torch.autograd import grad as grad
1415
from PINNFramework.callbacks import CallbackList
1516

@@ -96,11 +97,14 @@ def __init__(self, model: torch.nn.Module, input_dimension: int, output_dimensio
9697
self.is_hpm = False
9798
else:
9899
raise TypeError("PDE loss has to be an instance of a PDE Loss class")
99-
100+
100101
if isinstance(pde_loss, HPMLoss):
101102
self.is_hpm = True
102103
if self.use_gpu:
103104
self.pde_loss.hpm_model.cuda()
105+
106+
if isinstance(pde_loss.geometry.sampler, AdaptiveSampler):
107+
self.pde_loss.geometry.sampler.device = torch.device("cuda" if self.use_gpu else "cpu")
104108

105109
if isinstance(initial_condition, InitialCondition):
106110
self.initial_condition = initial_condition
@@ -110,12 +114,12 @@ def __init__(self, model: torch.nn.Module, input_dimension: int, output_dimensio
110114
if not len(initial_condition.dataset):
111115
raise ValueError("Initial condition dataset is empty")
112116

113-
if not len(pde_loss.dataset):
114-
raise ValueError("PDE dataset is empty")
115-
117+
if not len(pde_loss.geometry):
118+
raise ValueError("Geometry is empty")
119+
116120
joined_datasets = {
117121
initial_condition.name: initial_condition.dataset,
118-
pde_loss.name: pde_loss.dataset
122+
pde_loss.name: pde_loss.geometry
119123
}
120124
if self.rank == 0:
121125
self.loss_log[initial_condition.name] = float(0.0) # adding initial condition to the loss_log
@@ -318,14 +322,14 @@ def pinn_loss(self, training_data, track_gradient=False, annealing=False):
318322
# unpack training data
319323
# ============== PDE LOSS ============== "
320324
if type(training_data[self.pde_loss.name]) is not list:
321-
pde_loss = self.pde_loss(training_data[self.pde_loss.name][0].type(self.dtype), self.model)
325+
pde_loss = self.pde_loss(training_data[self.pde_loss.name][0].type(self.dtype), self.model)
322326
if annealing or track_gradient:
323327
self.loss_gradients_storage[self.pde_loss.name] = self.loss_gradients(pde_loss)
324328
pinn_loss = pinn_loss + self.pde_loss.weight * pde_loss
325329
if self.rank == 0:
326330
self.loss_log[self.pde_loss.name] = pde_loss + self.loss_log[self.pde_loss.name]
327331
else:
328-
raise ValueError("Training Data for PDE data is a single tensor consists of residual points ")
332+
raise ValueError("Training Data for PDE data is either a single tensor consisting of residual points or a concatenation of residual points and corresponding weights ")
329333

330334
# ============== INITIAL CONDITION ============== "
331335
if type(training_data[self.initial_condition.name]) is list:
@@ -673,4 +677,4 @@ def take_snapshot(model, file_path, device, n_points):
673677
# convert all tensors to numpy arrays and save as VTK data
674678
grid = [x.numpy(), y.numpy(), z.numpy()]
675679
output = output.view(n_points).to('cpu').numpy()
676-
imageToVTK(file_path, grid, pointData={"model output": output})
680+
imageToVTK(file_path, grid, pointData={"model output": output})

PINNFramework/Random_Sampler.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import torch
2+
import numpy as np
3+
from .Sampler import Sampler
4+
5+
6+
class RandomSampler(Sampler):
7+
def __init__(self):
8+
"""
9+
Constructor of the RandomSampler (pseudo random sampler) class
10+
"""
11+
super(RandomSampler, self).__init__()
12+
13+
def sample(self, lb, ub, n):
14+
"""Generate 'n' number of sample points in [lb,ub]
15+
16+
Args:
17+
lb (numpy.ndarray): lower bound of the domain.
18+
ub (numpy.ndarray): upper bound of the domain.
19+
n (int): the number of sampled points.
20+
"""
21+
22+
torch.manual_seed(42)
23+
np.random.seed(42)
24+
25+
lb = lb.reshape(1,-1)
26+
ub = ub.reshape(1,-1)
27+
28+
dimension = lb.shape[1]
29+
xf = np.random.uniform(lb,ub,size=(n, dimension))
30+
return torch.tensor(xf).float()

PINNFramework/Sampler.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import torch
2+
import numpy as np
3+
from abc import ABC, abstractmethod
4+
5+
class Sampler(ABC):
6+
def __init__(self):
7+
"""
8+
Constructor of the Sampler class
9+
"""
10+
11+
@abstractmethod
12+
def sample(self, lb, ub, n):
13+
"""Generate 'n' number of sample points in [lb,ub]
14+
15+
Args:
16+
lb (numpy.ndarray): lower bound of the domain.
17+
ub (numpy.ndarray): upper bound of the domain.
18+
n (int): the number of sampled points.
19+
"""
20+

PINNFramework/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from .WandB_Logger import WandbLogger
1212
from .TensorBoard_Logger import TensorBoardLogger
1313
from .PINN import PINN
14+
from .Random_Sampler import RandomSampler
15+
from .LHS_Sampler import LHSSampler
16+
from .Adaptive_Sampler import AdaptiveSampler
17+
from .ND_Cube import NDCube
1418

1519
import PINNFramework.models
1620
import PINNFramework.callbacks
@@ -24,6 +28,10 @@
2428
'NeumannBC',
2529
'TimeDerivativeBC',
2630
'PDELoss',
31+
'RandomSampler',
32+
'LHSSampler',
33+
'AdaptiveSampler',
34+
'NDCube'
2735
'HPMLoss',
2836
'PINN',
2937
'models',

0 commit comments

Comments
 (0)