-
Notifications
You must be signed in to change notification settings - Fork 132
Rewriting the kron function using JAX implementation #684
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
8875fd0
9eb50a6
be52b59
52ed8c3
1611f9a
5aef097
5299b75
3e0c79c
a7902fd
e3b1aaa
3ea1cdb
7fa03cf
5104494
35996a5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
det, | ||
eig, | ||
eigh, | ||
kron, | ||
lstsq, | ||
matrix_dot, | ||
matrix_inverse, | ||
|
@@ -580,3 +581,41 @@ def test_eval(self): | |
t_binv1 = tf_b1(self.b1) | ||
assert _allclose(t_binv, n_binv) | ||
assert _allclose(t_binv1, n_binv1) | ||
|
||
|
||
class TestKron(utt.InferShapeTester): | ||
rng = np.random.default_rng(43) | ||
|
||
def setup_method(self): | ||
self.op = kron | ||
super().setup_method() | ||
|
||
def test_perform(self): | ||
for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here? Use pytest.mark.parametrize? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hey! could you briefly explain how There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a decorator that passes keyword arguments into a test function. For example:
This will run 3 tests, one for each value of You can pass multiple parameters like this:
This will make 3 tests, all of which will pass. Or you can make the full cartesian product between parameters by stacking decorators:
This will make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thank you this is really helpful. safe to say these are just fancy for loops? 😄 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but with better outputs when you run pytest. It splits each parameterization into its own test, and so you can know exactly which combination of parameters is failing. If you have a loop, it just tells you pass/fail, not at which step of the loop. Also it means you get more green dots when you run pytest, which is obviously extremely important |
||
x = tensor(dtype="floatX", shape=(None,) * len(shp0)) | ||
a = np.asarray(self.rng.random(shp0)).astype(config.floatX) | ||
for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]: | ||
if len(shp0) + len(shp1) == 2: | ||
continue | ||
y = tensor(dtype="floatX", shape=(None,) * len(shp1)) | ||
f = function([x, y], kron(x, y)) | ||
b = self.rng.random(shp1).astype(config.floatX) | ||
out = f(a, b) | ||
# Using the np.kron to compare outputs | ||
np_val = np.kron(a, b) | ||
np.testing.assert_allclose(out, np_val) | ||
|
||
def test_kron_commutes_with_inv(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since no code is reused, why not parametrize? |
||
for shp0, shp1 in zip( | ||
[(2, 3), (2, 3), (2, 4, 3)], [(6, 7), (4, 3, 5), (4, 3, 5)] | ||
): | ||
if len(shp0) == 3 or len(shp1) == 3: | ||
continue | ||
tanish1729 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
x = tensor(dtype="floatX", shape=(None,) * len(shp0)) | ||
a = np.asarray(self.rng.random(shp0)).astype(config.floatX) | ||
y = tensor(dtype="floatX", shape=(None,) * len(shp1)) | ||
b = self.rng.random(shp1).astype(config.floatX) | ||
lhs_f = function([x, y], pinv(kron(x, y))) | ||
rhs_f = function([x, y], kron(pinv(x), pinv(y))) | ||
atol = 1e-4 if config.floatX == "float32" else 1e-12 | ||
np.testing.assert_allclose(lhs_f(a, b), rhs_f(a, b), atol=atol) |
Uh oh!
There was an error while loading. Please reload this page.