-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbucket_data_helper.py
31 lines (24 loc) · 980 Bytes
/
bucket_data_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
class bucket_data:
    """Group bucketed (source, target) data into token-budgeted mini-batches.

    ``data`` maps a bucket key to a ``(source, target)`` pair whose two
    elements are indexed in parallel. The key is summed to estimate the
    number of tokens per example, so it must be an iterable of numbers
    (presumably the bucket's source/target lengths -- TODO confirm against
    caller).
    """

    def __init__(self, data, batch_token=16000):
        """Store the bucketed data and the per-batch token budget.

        Args:
            data: dict mapping bucket key -> (source, target). Kept by
                reference; ``get_dataset`` does not modify it.
            batch_token: approximate token budget per batch. A bucket's
                batch size is ``batch_token // sum(key)`` (at least 1).
        """
        self.data = data
        self.batch_token = batch_token

    def get_dataset(self, bucket_shuffle=False, dataset_shuffle=False):
        """Split every bucket into batches and return them as one flat list.

        Args:
            bucket_shuffle: if true, shuffle the examples within each bucket
                before batching. Source and target share one permutation, so
                pairs stay aligned.
            dataset_shuffle: if true, shuffle the order of the resulting
                batches (the contents of each batch are unchanged).

        Returns:
            A list of ``[batch_source, batch_target]`` pairs. Each bucket
            contributes ``ceil(len(source) / batch_size)`` batches; the last
            batch of a bucket may be smaller.
        """
        data_list = []
        for key in self.data:
            # Floor at 1 so a bucket whose examples exceed the token budget
            # (or a zero-sum key) yields single-example batches instead of
            # a ZeroDivisionError.
            batch_size = max(1, self.batch_token // max(1, sum(key)))

            bucket = self.data[key]
            source, target = bucket[0], bucket[1]
            if bucket_shuffle:
                # Assumes source/target support integer-array indexing
                # (e.g. numpy arrays). The shuffled copies are local, so the
                # caller's dict is left untouched.
                indices = np.arange(len(source))
                np.random.shuffle(indices)
                source, target = source[indices], target[indices]

            for start in range(0, len(source), batch_size):
                stop = start + batch_size
                data_list.append([source[start:stop], target[start:stop]])

        if dataset_shuffle:
            np.random.shuffle(data_list)
        return data_list