1
+ from typing import Optional , Callable
1
2
from jax import random
2
3
from jax import numpy as np
3
4
from jax .numpy .linalg import cholesky
4
-
5
5
import jax .example_libraries .stax as ostax
6
- import neural_tangents
7
- from neural_tangents import stax
8
6
9
- from pkg_resources import parse_version
10
- if parse_version (neural_tangents .__version__ ) >= parse_version ('0.5.0' ):
11
- from neural_tangents ._src .utils import utils , dataclasses
12
- from neural_tangents ._src .stax .linear import _pool_kernel , Padding
13
- from neural_tangents ._src .stax .linear import _Pooling as Pooling
14
- else :
15
- from neural_tangents .utils import utils , dataclasses
16
- from neural_tangents .stax import _pool_kernel , Padding , Pooling
7
+ from neural_tangents import stax
8
+ from neural_tangents ._src .utils import dataclasses
9
+ from neural_tangents ._src .stax .linear import _pool_kernel , Padding
10
+ from neural_tangents ._src .stax .linear import _Pooling as Pooling
17
11
18
- from sketching import TensorSRHT2 , PolyTensorSRHT
12
+ from experimental . sketching import TensorSRHT2
19
13
""" Implementation for NTK Sketching and Random Features """
20
14
21
15
@@ -50,11 +44,11 @@ def kappa1(x):
50
44
51
45
@dataclasses .dataclass
52
46
class Features :
53
- nngp_feat : np .ndarray
54
- ntk_feat : np .ndarray
47
+ nngp_feat : Optional [ np .ndarray ] = None
48
+ ntk_feat : Optional [ np .ndarray ] = None
55
49
56
- batch_axis : int = dataclasses . field ( pytree_node = False )
57
- channel_axis : int = dataclasses . field ( pytree_node = False )
50
+ batch_axis : int = 0
51
+ channel_axis : int = - 1
58
52
59
53
replace = ... # type: Callable[..., 'Features']
60
54
@@ -72,7 +66,7 @@ def _inputs_to_features(x: np.ndarray,
72
66
return Features (nngp_feat = nngp_feat ,
73
67
ntk_feat = ntk_feat ,
74
68
batch_axis = batch_axis ,
75
- channel_axis = channel_axis )
69
+ channel_axis = channel_axis ) # pytype:disable=wrong-keyword-args
76
70
77
71
78
72
# Modified the serial process of feature map blocks.
@@ -95,7 +89,7 @@ def feature_fn(k, inputs, **kwargs):
95
89
96
90
def DenseFeatures (out_dim : int ,
97
91
W_std : float = 1. ,
98
- b_std : float = None ,
92
+ b_std : float = 1. ,
99
93
parameterization : str = 'ntk' ,
100
94
batch_axis : int = 0 ,
101
95
channel_axis : int = - 1 ):
@@ -114,7 +108,7 @@ def kernel_fn(f: Features, input, **kwargs):
114
108
nngp_feat *= W_std
115
109
ntk_feat *= W_std
116
110
117
- if ntk_feat .ndim == 0 : # check if ntk_feat is empty
111
+ if ntk_feat .ndim == 0 : # check if ntk_feat is empty
118
112
ntk_feat = nngp_feat
119
113
else :
120
114
ntk_feat = np .concatenate ((ntk_feat , nngp_feat ), axis = channel_axis )
@@ -153,20 +147,21 @@ def init_fn(rng, input_shape):
153
147
ts2 = TensorSRHT2 (rng = rng3 ,
154
148
input_dim1 = ntk_feat_shape [- 1 ],
155
149
input_dim2 = feature_dim0 ,
156
- sketch_dim = sketch_dim ).init_sketches ()
150
+ sketch_dim = sketch_dim ).init_sketches () # pytype:disable=wrong-keyword-args
157
151
return (new_nngp_feat_shape , new_ntk_feat_shape ), (W0 , W1 , ts2 )
158
152
159
153
elif method == 'ps' :
160
- rng1 , rng2 , rng3 = random .split (rng , 3 )
161
- # PolySketch algorithm for arc-cosine kernel of order 0.
162
- ps0 = PolyTensorSRHT (rng1 , nngp_feat_shape [- 1 ], poly_sketch_dim0 ,
163
- poly_degree0 )
164
- # PolySketch algorithm for arc-cosine kernel of order 1.
165
- ps1 = PolyTensorSRHT (rng2 , nngp_feat_shape [- 1 ], poly_sketch_dim1 ,
166
- poly_degree1 )
167
- # TensorSRHT of degree 2 for approximating tensor product.
168
- ts2 = TensorSRHT2 (rng3 , ntk_feat_shape [- 1 ], feature_dim0 , sketch_dim )
169
- return (new_nngp_feat_shape , new_ntk_feat_shape ), (ps0 , ps1 , ts2 )
154
+ # rng1, rng2, rng3 = random.split(rng, 3)
155
+ # # PolySketch algorithm for arc-cosine kernel of order 0.
156
+ # ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0,
157
+ # poly_degree0)
158
+ # # PolySketch algorithm for arc-cosine kernel of order 1.
159
+ # ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1,
160
+ # poly_degree1)
161
+ # # TensorSRHT of degree 2 for approximating tensor product.
162
+ # ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim)
163
+ # return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2)
164
+ raise NotImplementedError
170
165
171
166
elif method == 'exact' :
172
167
# The exact feature map computation is intended for debugging.
@@ -199,9 +194,9 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features:
199
194
kappa0_feat ).reshape (input_shape + (- 1 ,))
200
195
201
196
elif method == 'ps' :
202
- ps0 : PolyTensorSRHT = input [0 ]
203
- ps1 : PolyTensorSRHT = input [1 ]
204
- ts2 : TensorSRHT2 = input [2 ]
197
+ # ps0: PolyTensorSRHT = input[0]
198
+ # ps1: PolyTensorSRHT = input[1]
199
+ # ts2: TensorSRHT2 = input[2]
205
200
raise NotImplementedError
206
201
207
202
elif method == 'exact' : # Exact feature extraction via Cholesky decomposition.
@@ -258,7 +253,7 @@ def feature_fn(f, input, **kwargs):
258
253
259
254
nngp_feat = conv2d_feat (nngp_feat , filter_size ) / filter_size * W_std
260
255
261
- if ntk_feat .ndim == 0 : # check if ntk_feat is empty
256
+ if ntk_feat .ndim == 0 : # check if ntk_feat is empty
262
257
ntk_feat = nngp_feat
263
258
else :
264
259
ntk_feat = conv2d_feat (ntk_feat , filter_size ) / filter_size * W_std
0 commit comments