This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit d1a9266

Add NTK Random Features and Sketching codes
1 parent c64f307 commit d1a9266

File tree

5 files changed (+13 lines, -7 lines)


experimental/README.md

Lines changed: 3 additions & 4 deletions
@@ -121,16 +121,15 @@ assert out_feat.ntk_feat.shape == (3, 3)

`features.ConvFeatures` is similar to `features.DenseFeatures` in that it updates the NTK feature of the next layer by concatenating the NNGP and NTK features of the previous one, but it additionally requires kernel pooling operations. Precisely, [[4]](#4) showed that computing the NNGP/NTK kernel matrices requires the trace of a submatrix of size `stride_size`. This can be viewed as convolution with an identity matrix of size `stride_size`. On the feature side, the same operation can be done by concatenating shifted features, so the resulting feature dimension becomes `stride_size` times larger. Moreover, since image datasets are 2-D, the kernel pooling is applied along two axes, hence the output feature has shape `N x H x W x (d * s**2)`, where `s` is the stride size and `d` is the input feature dimension.
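As a rough illustration of the shifted-feature concatenation for a 1-D input (a sketch, not the repository's actual implementation — `concat_shifted` and its zero-padding convention are our assumptions):

```python
import numpy as np

def concat_shifted(feat, s=3):
    # feat: (N, P, d) features over N samples and P spatial positions.
    # Stack s spatially shifted, zero-padded copies along the feature
    # axis, so the feature dimension grows by a factor of s.  Inner
    # products of the result then sum the kernel over a length-s window,
    # mirroring the trace-of-submatrix (identity-convolution) view above.
    n, p, d = feat.shape
    shifted_copies = []
    for offset in range(-(s // 2), s // 2 + 1):
        shifted = np.zeros_like(feat)
        if offset >= 0:
            shifted[:, offset:, :] = feat[:, :p - offset, :]
        else:
            shifted[:, :offset, :] = feat[:, -offset:, :]
        shifted_copies.append(shifted)
    return np.concatenate(shifted_copies, axis=-1)

feat = np.ones((2, 5, 4))
out = concat_shifted(feat, s=3)
assert out.shape == (2, 5, 12)  # feature dimension grows d -> d * s
```

For 2-D images the same concatenation is applied along both spatial axes, which is where the `d * s**2` factor above comes from.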

-To be updated.
-

## [`features.AvgPoolFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L269)

-To be updated.
+`features.AvgPoolFeatures` applies average pooling to both the NNGP and NTK features. It calls the [`_pool_kernel`](https://github.com/google/neural-tangents/blob/dd7eabb718c9e3c6640c47ca2379d93db6194214/neural_tangents/_src/stax/linear.py#L3143) function in [Neural Tangents](https://github.com/google/neural-tangents) as a subroutine.
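A minimal NumPy sketch of non-overlapping average pooling over the spatial axes of an `(N, H, W, D)` feature tensor (the helper name and the divisibility assumption are ours, not the repository's):

```python
import numpy as np

def avg_pool_features(feat, window=2):
    # Average over non-overlapping (window x window) patches of the two
    # spatial axes; assumes H and W are divisible by `window`.
    n, h, w, d = feat.shape
    feat = feat.reshape(n, h // window, window, w // window, window, d)
    return feat.mean(axis=(2, 4))

feat = np.arange(2 * 4 * 4 * 3, dtype=float).reshape(2, 4, 4, 3)
pooled = avg_pool_features(feat, window=2)
assert pooled.shape == (2, 2, 2, 3)  # spatial dims shrink by `window`
```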

## [`features.FlattenFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L304)

-To be updated.
+`features.FlattenFeatures` flattens the features into 2-D tensors. As with the [`Flatten`](https://github.com/google/neural-tangents/blob/dd7eabb718c9e3c6640c47ca2379d93db6194214/neural_tangents/_src/stax/linear.py#L1641) module in [Neural Tangents](https://github.com/google/neural-tangents), the flattened features are rescaled by the square root of the number of flattened entries. For example, if `nngp_feat` has shape `N x H x W x C`, it returns an `N x HWC` matrix whose entries are divided by `(H*W*C)**0.5`.
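The rescaled flattening can be sketched in a couple of NumPy lines (the variable names are illustrative):

```python
import numpy as np

N, H, W, C = 3, 4, 4, 2
nngp_feat = np.ones((N, H, W, C))
# Flatten the spatial and channel axes, then rescale by sqrt(H*W*C) so
# the flattened kernel averages (rather than sums) over the H*W*C entries.
flat = nngp_feat.reshape(N, -1) / np.sqrt(H * W * C)
assert flat.shape == (3, 32)
# For all-ones features, every entry of the resulting kernel is exactly 1:
assert np.allclose(flat @ flat.T, 1.0)
```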
+

## References
#### [1] [Scaling Neural Tangent Kernels via Sketching and Random Features](https://arxiv.org/pdf/2106.07880.pdf)

experimental/features.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
from neural_tangents.stax import _pool_kernel, Padding, Pooling
from sketching import TensorSRHT2, PolyTensorSRHT
-"""Implementation for NTK Sketching and Random Features"""
+""" Implementation for NTK Sketching and Random Features """

def _prod(tuple_):

experimental/sketching.py

Lines changed: 7 additions & 0 deletions
@@ -46,6 +46,13 @@ def sketch(self, x1, x2):
    return np.concatenate((out.real, out.imag), 1)


+# Function implementation of TensorSRHT of degree 2 (duplicated)
+def tensorsrht(x1, x2, rand_inds, rand_signs):
+  x1fft = np.fft.fftn(x1 * rand_signs[0, :], axes=(-1,))[:, rand_inds[0, :]]
+  x2fft = np.fft.fftn(x2 * rand_signs[1, :], axes=(-1,))[:, rand_inds[1, :]]
+  return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft)
+
+
# TensorSRHT of degree p. This operates on the same input vectors.
class PolyTensorSRHT:
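The added `tensorsrht` function can be exercised standalone in plain NumPy (the sizes and random-state setup below are ours). Each input is sign-flipped, transformed with an FFT along the feature axis, and subsampled at `m` random frequencies; the elementwise product of the two complex sketches, scaled by `1/sqrt(m)`, forms the degree-2 tensor sketch:

```python
import numpy as np

def tensorsrht(x1, x2, rand_inds, rand_signs):
    # Same logic as the function added in the diff above, in plain NumPy.
    x1fft = np.fft.fftn(x1 * rand_signs[0, :], axes=(-1,))[:, rand_inds[0, :]]
    x2fft = np.fft.fftn(x2 * rand_signs[1, :], axes=(-1,))[:, rand_inds[1, :]]
    return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft)

rng = np.random.default_rng(0)
n, d, m = 4, 8, 16
x1 = rng.normal(size=(n, d))
x2 = rng.normal(size=(n, d))
rand_inds = rng.integers(0, d, size=(2, m))        # sampled frequencies
rand_signs = rng.choice([-1.0, 1.0], size=(2, d))  # Rademacher sign flips
out = tensorsrht(x1, x2, rand_inds, rand_signs)
assert out.shape == (n, m) and np.iscomplexobj(out)
```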

experimental/test_fc_ntk.py

Lines changed: 1 addition & 2 deletions
@@ -9,8 +9,7 @@
from features import _inputs_to_features, DenseFeatures, ReluFeatures, serial

seed = 1
-n = 6
-d = 4
+n, d = 6, 4

key1, key2 = random.split(random.PRNGKey(seed))
x1 = random.normal(key1, (n, d))

experimental/test_myrtle_network.py

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ def MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), b_std=0., **relu_args):
init_nngp_feat_shape = x.shape
init_ntk_feat_shape = (-1, 0)
init_feat_shape = (init_nngp_feat_shape, init_ntk_feat_shape)
+
inputs_shape, feat_fn_inputs = init_fn(key2, init_feat_shape)

f0 = _inputs_to_features(x)
