-
Notifications
You must be signed in to change notification settings - Fork 1
/
KANunittest.py
65 lines (55 loc) · 2.79 KB
/
KANunittest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# TODO: An unresolved issue remains in these tests; a fix is in progress (contributions welcome).
import unittest
import tensorflow as tf
import numpy as np
from tensorflow.keras import Model
from KANtf import KANLinear, KAN, extend_grid_tf, B_batch_tf
class TestKANLinear(unittest.TestCase):
def test_initialization(self):
"""Test if KANLinear layer initializes with correct shapes and configurable parameters."""
in_features = 10
out_features = 5
grid_size = 5
spline_order = 3
layer = KANLinear(in_features, out_features, grid_size, spline_order)
self.assertEqual(layer.in_features, in_features)
self.assertEqual(layer.out_features, out_features)
self.assertEqual(layer.grid_size, grid_size)
self.assertEqual(layer.spline_order, spline_order)
self.assertEqual(layer.base_weight.shape, (in_features, out_features))
self.assertEqual(layer.spline_weight.shape, (in_features, out_features, grid_size + spline_order - 1))
def test_forward_pass(self):
"""Test the forward pass computation of KANLinear."""
layer = KANLinear(10, 5)
# Ensure the grid is correctly initialized and used
self.assertTrue(len(layer.grid.shape) == 1, "Grid should be one-dimensional")
input_tensor = tf.random.normal([10, 10]) # batch size of 10, 10 features
output = layer(input_tensor)
self.assertEqual(output.shape, (10, 5))
class TestBSplineFunctions(unittest.TestCase):
def test_extend_grid(self):
"""Test if the grid is extended correctly on both sides."""
grid = tf.constant([0.0, 1.0, 2.0]) # ensure it's 1D as expected
extended_grid = extend_grid_tf(grid, 1)
expected_output = [-1.0, 0.0, 1.0, 2.0, 3.0]
np.testing.assert_array_almost_equal(extended_grid.numpy(), expected_output)
def test_b_spline_basis(self):
"""Test B-spline basis computation for known inputs and grid."""
x = tf.constant([[0.5], [1.5], [2.5]]) # Points to evaluate the spline
grid = tf.constant([0.0, 1.0, 2.0, 3.0]) # Correct one-dimensional grid
# Ensure grid stays one-dimensional
b_spline_values = B_batch_tf(x, grid, k=2)
expected_shape = (3, 4) # or whatever your expected shape is
self.assertEqual(b_spline_values.shape, expected_shape)
class TestKANModel(unittest.TestCase):
def test_model_construction(self):
"""Test the construction of the KAN model."""
layers_config = [
{'in_features': 10, 'out_features': 5},
{'in_features': 5, 'out_features': 3}
]
model = KAN(layers_configurations=layers_config)
self.assertIsInstance(model, tf.keras.models.Sequential)
self.assertEqual(len(model.layers), 2)
if __name__ == "__main__":
    # Run every TestCase defined in this module when the file is executed
    # directly as a script.
    unittest.main()