This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit aa12545

Resolve pytype tests
1 parent 9dc3536 commit aa12545

File tree

5 files changed: +42 additions, -42 deletions

experimental/__init__.py

Whitespace-only changes.

experimental/features.py

Lines changed: 29 additions & 34 deletions
@@ -1,21 +1,15 @@
+from typing import Optional, Callable
 from jax import random
 from jax import numpy as np
 from jax.numpy.linalg import cholesky
-
 import jax.example_libraries.stax as ostax
-import neural_tangents
-from neural_tangents import stax
 
-from pkg_resources import parse_version
-if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'):
-  from neural_tangents._src.utils import utils, dataclasses
-  from neural_tangents._src.stax.linear import _pool_kernel, Padding
-  from neural_tangents._src.stax.linear import _Pooling as Pooling
-else:
-  from neural_tangents.utils import utils, dataclasses
-  from neural_tangents.stax import _pool_kernel, Padding, Pooling
+from neural_tangents import stax
+from neural_tangents._src.utils import dataclasses
+from neural_tangents._src.stax.linear import _pool_kernel, Padding
+from neural_tangents._src.stax.linear import _Pooling as Pooling
 
-from sketching import TensorSRHT2, PolyTensorSRHT
+from experimental.sketching import TensorSRHT2
 """ Implementation for NTK Sketching and Random Features """
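Note on the import hunk: pytype cannot give `utils` and `dataclasses` a single origin when they are bound in both branches of a pkg_resources version gate, so the commit pins the imports to the post-0.5.0 private layout and drops the gate entirely. A minimal sketch of the removed pattern, rewritten with importlib so the fallback is explicit (the helper name is hypothetical; the module paths are the two layouts the gate chose between):

import importlib

def _load_dataclasses():  # hypothetical helper, illustration only
  # A version-gated import gives the name two possible origins;
  # static checkers such as pytype flag that ambiguity.
  try:
    return importlib.import_module('neural_tangents._src.utils.dataclasses')
  except ImportError:  # pre-0.5.0 layout
    return importlib.import_module('neural_tangents.utils.dataclasses')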

@@ -50,11 +44,11 @@ def kappa1(x):
 
 @dataclasses.dataclass
 class Features:
-  nngp_feat: np.ndarray
-  ntk_feat: np.ndarray
+  nngp_feat: Optional[np.ndarray] = None
+  ntk_feat: Optional[np.ndarray] = None
 
-  batch_axis: int = dataclasses.field(pytree_node=False)
-  channel_axis: int = dataclasses.field(pytree_node=False)
+  batch_axis: int = 0
+  channel_axis: int = -1
 
   replace = ...  # type: Callable[..., 'Features']
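With plain defaults in place of `dataclasses.field(pytree_node=False)`, pytype accepts keyword construction of `Features`, and the `# type:` comment gives the decorator-generated `replace` method a signature it can check. A usage sketch (array shapes are made up for illustration):

from jax import numpy as np
from experimental.features import Features

feat = Features(nngp_feat=np.zeros((6, 4)), ntk_feat=np.zeros((6, 4)))
# `replace` is synthesized by the neural_tangents dataclass decorator;
# it returns a copy with the given fields swapped out.
feat = feat.replace(ntk_feat=np.zeros((6, 8)))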

@@ -72,7 +66,7 @@ def _inputs_to_features(x: np.ndarray,
   return Features(nngp_feat=nngp_feat,
                   ntk_feat=ntk_feat,
                   batch_axis=batch_axis,
-                  channel_axis=channel_axis)
+                  channel_axis=channel_axis)  # pytype:disable=wrong-keyword-args
 
 
 # Modified the serial process of feature map blocks.
@@ -95,7 +89,7 @@ def feature_fn(k, inputs, **kwargs):
 
 def DenseFeatures(out_dim: int,
                   W_std: float = 1.,
-                  b_std: float = None,
+                  b_std: float = 1.,
                   parameterization: str = 'ntk',
                   batch_axis: int = 0,
                   channel_axis: int = -1):
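The `b_std` default changes because `None` is not a valid value for a parameter annotated `float`, which pytype reports. Note that the new default of `1.` may change behavior if `None` previously meant "no bias"; an alternative that keeps the old default would be to widen the annotation instead (a signature sketch only, with trailing parameters elided, not what this commit does):

from typing import Optional

def DenseFeatures(out_dim: int,
                  W_std: float = 1.,
                  b_std: Optional[float] = None,  # None can still mean "no bias"
                  parameterization: str = 'ntk'):
  ...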
@@ -114,7 +108,7 @@ def kernel_fn(f: Features, input, **kwargs):
     nngp_feat *= W_std
     ntk_feat *= W_std
 
-    if ntk_feat.ndim == 0: # check if ntk_feat is empty
+    if ntk_feat.ndim == 0:  # check if ntk_feat is empty
       ntk_feat = nngp_feat
     else:
       ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis)
@@ -153,20 +147,21 @@ def init_fn(rng, input_shape):
       ts2 = TensorSRHT2(rng=rng3,
                         input_dim1=ntk_feat_shape[-1],
                         input_dim2=feature_dim0,
-                        sketch_dim=sketch_dim).init_sketches()
+                        sketch_dim=sketch_dim).init_sketches()  # pytype:disable=wrong-keyword-args
       return (new_nngp_feat_shape, new_ntk_feat_shape), (W0, W1, ts2)
 
     elif method == 'ps':
-      rng1, rng2, rng3 = random.split(rng, 3)
-      # PolySketch algorithm for arc-cosine kernel of order 0.
-      ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0,
-                           poly_degree0)
-      # PolySketch algorithm for arc-cosine kernel of order 1.
-      ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1,
-                           poly_degree1)
-      # TensorSRHT of degree 2 for approximating tensor product.
-      ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim)
-      return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2)
+      # rng1, rng2, rng3 = random.split(rng, 3)
+      # # PolySketch algorithm for arc-cosine kernel of order 0.
+      # ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0,
+      #                      poly_degree0)
+      # # PolySketch algorithm for arc-cosine kernel of order 1.
+      # ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1,
+      #                      poly_degree1)
+      # # TensorSRHT of degree 2 for approximating tensor product.
+      # ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim)
+      # return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2)
+      raise NotImplementedError
 
     elif method == 'exact':
       # The exact feature map computation is for debug.
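Two things happen in this hunk: the live `TensorSRHT2` construction gains a `# pytype:disable=wrong-keyword-args` (its keyword parameters come from the custom dataclass decorator, which pytype cannot see), and the PolySketch branch is commented out and stubbed with `raise NotImplementedError`, since `PolyTensorSRHT` is no longer imported. For reference, a usage sketch of the branch that stays checked (the dimensions are assumptions for illustration):

from jax import random
from experimental.sketching import TensorSRHT2

rng = random.PRNGKey(0)
ts2 = TensorSRHT2(rng=rng, input_dim1=128, input_dim2=64,
                  sketch_dim=256).init_sketches()  # pytype:disable=wrong-keyword-args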
@@ -199,9 +194,9 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features:
                                 kappa0_feat).reshape(input_shape + (-1,))
 
     elif method == 'ps':
-      ps0: PolyTensorSRHT = input[0]
-      ps1: PolyTensorSRHT = input[1]
-      ts2: TensorSRHT2 = input[2]
+      # ps0: PolyTensorSRHT = input[0]
+      # ps1: PolyTensorSRHT = input[1]
+      # ts2: TensorSRHT2 = input[2]
       raise NotImplementedError
 
     elif method == 'exact':  # Exact feature extraction via Cholesky decomposition.
@@ -258,7 +253,7 @@ def feature_fn(f, input, **kwargs):
 
     nngp_feat = conv2d_feat(nngp_feat, filter_size) / filter_size * W_std
 
-    if ntk_feat.ndim == 0: # check if ntk_feat is empty
+    if ntk_feat.ndim == 0:  # check if ntk_feat is empty
       ntk_feat = nngp_feat
     else:
       ntk_feat = conv2d_feat(ntk_feat, filter_size) / filter_size * W_std

experimental/sketching.py

Lines changed: 7 additions & 5 deletions
@@ -1,7 +1,7 @@
 from jax import random
 from jax import numpy as np
-from neural_tangents._src.utils import utils, dataclasses
-from neural_tangents._src.utils.typing import Optional
+from neural_tangents._src.utils import dataclasses
+from typing import Optional, Callable
 
 
 # TensorSRHT of degree 2. This version allows different input vectors.
@@ -20,9 +20,9 @@ class TensorSRHT2:
   rand_inds1: Optional[np.ndarray] = None
   rand_inds2: Optional[np.ndarray] = None
 
-  replace = ...
+  replace = ...  # type: Callable[..., 'TensorSRHT2']
 
-  def init_sketches(self):
+  def init_sketches(self) -> 'TensorSRHT2':
     rng1, rng2, rng3, rng4 = random.split(self.rng, 4)
     rand_signs1 = random.choice(rng1, 2, shape=(self.input_dim1,)) * 2 - 1
     rand_signs2 = random.choice(rng2, 2, shape=(self.input_dim2,)) * 2 - 1
@@ -53,7 +53,8 @@ def tensorsrht(x1, x2, rand_inds, rand_signs):
     return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft)
 
 
-# TensorSRHT of degree p. This operates the same input vectors.
+# pytype: disable=attribute-error
+# TODO: Improve faster TensorSRHT.
 class PolyTensorSRHT:
 
   def __init__(self, rng, input_dim, sketch_dim, coeffs):
@@ -133,3 +134,4 @@ def sketch(self, x):
       p = p // 2
       U[j] = V[log_degree - 1][0, :, :].clone()
     return U
+# pytype: enable=attribute-error
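The disable/enable pair scopes the suppression: attribute errors inside `PolyTensorSRHT` are ignored, and checking resumes after the enable line. A minimal self-contained demo of the scoped form (the class and attribute are hypothetical):

# pytype: disable=attribute-error
class LazyBox:
  def get(self):
    # `value` is attached from outside at runtime; without the directive
    # pytype reports attribute-error on this line.
    return self.value
# pytype: enable=attribute-error

box = LazyBox()
box.value = 7
assert box.get() == 7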

experimental/test_fc_ntk.py

Lines changed: 3 additions & 1 deletion
@@ -2,11 +2,13 @@
 from jax import random
 from jax.config import config
 from jax import jit
+import sys
+sys.path.append("./")
 
 config.update("jax_enable_x64", True)
 from neural_tangents import stax
 
-from features import _inputs_to_features, DenseFeatures, ReluFeatures, serial
+from experimental.features import _inputs_to_features, DenseFeatures, ReluFeatures, serial
 
 seed = 1
 n, d = 6, 4
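The `sys.path.append("./")` assumes the test is launched from the repository root (e.g. `python experimental/test_fc_ntk.py`), so the new `experimental.features` import path resolves without installing the package. A sketch of the import being exercised (assuming the axis arguments of `_inputs_to_features` default to the dataclass values shown earlier):

import sys
sys.path.append("./")  # run from the repository root

from jax import random
from experimental.features import _inputs_to_features

x = random.normal(random.PRNGKey(1), (6, 4))  # matches n, d = 6, 4 above
feat = _inputs_to_features(x)  # wraps raw inputs into a Features pytree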

experimental/test_myrtle_network.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 import os
-
 os.environ['CUDA_VISIBLE_DEVICES'] = ''
+import sys
+sys.path.append("./")
 import functools
 from numpy.linalg import norm
 from jax.config import config
@@ -12,7 +13,7 @@
 from jax import random
 
 from neural_tangents import stax
-from features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features
+from experimental.features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features
 
 layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
 width = 1
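Unrelated to pytype but preserved by the hunk above: the test pins `CUDA_VISIBLE_DEVICES` to the empty string before anything imports jax, hiding all GPUs so the Myrtle network test runs on CPU. A minimal demonstration of why the ordering matters:

import os
# Must be set before the first jax import; jax reads it when initializing
# its backends, and an empty string means no CUDA devices are visible.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax
print(jax.devices())  # only CPU devices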
