Skip to content

Commit 808136d

Browse files
brandonwillardtwiecki
authored andcommitted
Move shape-related Ops and functions to theano.tensor.shape
1 parent b192515 commit 808136d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1466
-1454
lines changed

doc/extending/ctype.txt

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -565,17 +565,19 @@ default, it will recompile the c code for each process.
565565
Shape and Shape_i
566566
=================
567567

568-
We have 2 generic Ops, Shape and Shape_i, that return the shape of any
569-
Theano Variable that has a shape attribute (Shape_i returns only one of
568+
We have 2 generic `Op`s, `Shape` and `Shape_i`, that return the shape of any
569+
Theano `Variable` that has a shape attribute (`Shape_i` returns only one of
570570
the elements of the shape).
571571

572572

573573
.. code-block:: python
574574

575-
theano.compile.ops.register_shape_c_code(YOUR_TYPE_CLASS, THE_C_CODE, version=())
576-
theano.compile.ops.register_shape_i_c_code(YOUR_TYPE_CLASS, THE_C_CODE, CHECK_INPUT, version=())
575+
from theano.theano.shape import register_shape_c_code, register_shape_i_c_code
577576

578-
The C code works as the ViewOp. Shape_i has the additional ``i`` parameter
577+
register_shape_c_code(YOUR_TYPE_CLASS, THE_C_CODE, version=())
578+
register_shape_i_c_code(YOUR_TYPE_CLASS, THE_C_CODE, CHECK_INPUT, version=())
579+
580+
The C code works as the `ViewOp`. `Shape_i` has the additional ``i`` parameter
579581
that you can use with ``%(i)s``.
580582

581583
In your CHECK_INPUT, you must check that the input has enough dimensions to

doc/library/gradient.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ List of Implemented R op
3636
See the :ref:`gradient tutorial <tutcomputinggrads>` for the R op documentation.
3737

3838
list of ops that support R-op:
39-
* with test [Most is tensor/tests/test_rop.py]
39+
* with test [Most is tests/tensor/test_rop.py]
4040
* SpecifyShape
4141
* MaxAndArgmax
4242
* Subtensor
@@ -52,7 +52,7 @@ list of ops that support R-op:
5252
* Reshape
5353
* Flatten
5454
* DimShuffle
55-
* Scan [In scan/tests/test_scan.test_rop]
55+
* Scan [In tests/scan/test_basic.test_rop]
5656

5757
* without test
5858
* Split

doc/tutorial/shape_info.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ upgrade. Here is the current state of what can be done:
118118

119119
>>> import theano
120120
>>> x = theano.tensor.matrix()
121-
>>> x_specify_shape = theano.compile.ops.specify_shape(x, (2, 2))
121+
>>> x_specify_shape = theano.tensor.specify_shape(x, (2, 2))
122122
>>> f = theano.function([x], (x_specify_shape ** 2).shape)
123123
>>> theano.printing.debugprint(f) # doctest: +NORMALIZE_WHITESPACE
124124
DeepCopyOp [id A] '' 0

tests/compile/test_ops.py

Lines changed: 2 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,12 @@
11
import pickle
22

33
import numpy as np
4-
import pytest
54

6-
import theano
75
from tests import unittest_tools as utt
86
from theano import function
9-
from theano.compile.ops import Rebroadcast, SpecifyShape, as_op, shape, shape_i
7+
from theano.compile.ops import Rebroadcast, as_op
108
from theano.configdefaults import config
11-
from theano.graph.fg import FunctionGraph
12-
from theano.tensor.opt import ShapeFeature
13-
from theano.tensor.subtensor import Subtensor
14-
from theano.tensor.type import (
15-
TensorType,
16-
dmatrix,
17-
dtensor4,
18-
dvector,
19-
ivector,
20-
matrix,
21-
tensor3,
22-
vector,
23-
)
24-
from theano.tensor.type_other import NoneConst
25-
from theano.typed_list import make_list
9+
from theano.tensor.type import TensorType, dmatrix, dtensor4, dvector
2610

2711

2812
@as_op([dmatrix, dmatrix], dmatrix)
@@ -98,111 +82,6 @@ def test_pickle(self):
9882
assert m2.owner.op == m.owner.op
9983

10084

101-
def test_shape_i_hash():
102-
assert isinstance(theano.tensor.opt.Shape_i(np.int64(1)).__hash__(), int)
103-
104-
105-
class TestSpecifyShape(utt.InferShapeTester):
106-
mode = None
107-
input_type = TensorType
108-
109-
def shortDescription(self):
110-
return None
111-
112-
def test_bad_shape(self):
113-
# Test that at run time we raise an exception when the shape
114-
# is not the one specified
115-
specify_shape = SpecifyShape()
116-
117-
x = vector()
118-
xval = np.random.rand(2).astype(config.floatX)
119-
f = theano.function([x], specify_shape(x, [2]), mode=self.mode)
120-
f(xval)
121-
xval = np.random.rand(3).astype(config.floatX)
122-
with pytest.raises(AssertionError):
123-
f(xval)
124-
125-
assert isinstance(
126-
[n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0]
127-
.inputs[0]
128-
.type,
129-
self.input_type,
130-
)
131-
132-
x = matrix()
133-
xval = np.random.rand(2, 3).astype(config.floatX)
134-
f = theano.function([x], specify_shape(x, [2, 3]), mode=self.mode)
135-
assert isinstance(
136-
[n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0]
137-
.inputs[0]
138-
.type,
139-
self.input_type,
140-
)
141-
f(xval)
142-
for shape_ in [(1, 3), (2, 2), (5, 5)]:
143-
xval = np.random.rand(*shape_).astype(config.floatX)
144-
with pytest.raises(AssertionError):
145-
f(xval)
146-
147-
def test_bad_number_of_shape(self):
148-
# Test that the number of dimensions provided is good
149-
specify_shape = SpecifyShape()
150-
151-
x = vector()
152-
shape_vec = ivector()
153-
xval = np.random.rand(2).astype(config.floatX)
154-
with pytest.raises(AssertionError):
155-
specify_shape(x, [])
156-
with pytest.raises(AssertionError):
157-
specify_shape(x, [2, 2])
158-
159-
f = theano.function([x, shape_vec], specify_shape(x, shape_vec), mode=self.mode)
160-
assert isinstance(
161-
[n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0]
162-
.inputs[0]
163-
.type,
164-
self.input_type,
165-
)
166-
with pytest.raises(AssertionError):
167-
f(xval, [])
168-
with pytest.raises(AssertionError):
169-
f(xval, [2, 2])
170-
171-
x = matrix()
172-
xval = np.random.rand(2, 3).astype(config.floatX)
173-
for shape_ in [(), (1,), (2, 3, 4)]:
174-
with pytest.raises(AssertionError):
175-
specify_shape(x, shape_)
176-
f = theano.function(
177-
[x, shape_vec], specify_shape(x, shape_vec), mode=self.mode
178-
)
179-
assert isinstance(
180-
[
181-
n
182-
for n in f.maker.fgraph.toposort()
183-
if isinstance(n.op, SpecifyShape)
184-
][0]
185-
.inputs[0]
186-
.type,
187-
self.input_type,
188-
)
189-
with pytest.raises(AssertionError):
190-
f(xval, shape_)
191-
192-
def test_infer_shape(self):
193-
rng = np.random.RandomState(3453)
194-
adtens4 = dtensor4()
195-
aivec = ivector()
196-
aivec_val = [3, 4, 2, 5]
197-
adtens4_val = rng.rand(*aivec_val)
198-
self._compile_and_check(
199-
[adtens4, aivec],
200-
[SpecifyShape()(adtens4, aivec)],
201-
[adtens4_val, aivec_val],
202-
SpecifyShape,
203-
)
204-
205-
20685
class TestRebroadcast(utt.InferShapeTester):
20786
def test_rebroadcast(self):
20887
rng = np.random.RandomState(3453)
@@ -227,28 +106,3 @@ def test_rebroadcast(self):
227106
[adtens4_bro_val],
228107
Rebroadcast,
229108
)
230-
231-
232-
@config.change_flags(compute_test_value="raise")
233-
def test_nonstandard_shapes():
234-
a = tensor3(config.floatX)
235-
a.tag.test_value = np.random.random((2, 3, 4)).astype(config.floatX)
236-
b = tensor3(config.floatX)
237-
b.tag.test_value = np.random.random((2, 3, 4)).astype(config.floatX)
238-
239-
tl = make_list([a, b])
240-
tl_shape = shape(tl)
241-
assert np.array_equal(tl_shape.get_test_value(), (2, 2, 3, 4))
242-
243-
# There's no `FunctionGraph`, so it should return a `Subtensor`
244-
tl_shape_i = shape_i(tl, 0)
245-
assert isinstance(tl_shape_i.owner.op, Subtensor)
246-
assert tl_shape_i.get_test_value() == 2
247-
248-
tl_fg = FunctionGraph([a, b], [tl], features=[ShapeFeature()])
249-
tl_shape_i = shape_i(tl, 0, fgraph=tl_fg)
250-
assert not isinstance(tl_shape_i.owner.op, Subtensor)
251-
assert tl_shape_i.get_test_value() == 2
252-
253-
none_shape = shape(NoneConst)
254-
assert np.array_equal(none_shape.get_test_value(), [])

tests/gpuarray/test_basic_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
TestReshape,
1515
)
1616
from tests.tensor.utils import rand, safe_make_node
17-
from theano.compile.ops import Shape_i
1817
from theano.gpuarray.basic_ops import (
1918
GpuAlloc,
2019
GpuAllocEmpty,
@@ -36,6 +35,7 @@
3635
from theano.gpuarray.type import GpuArrayType, get_context, gpuarray_shared_constructor
3736
from theano.tensor.basic import Alloc, Split, alloc
3837
from theano.tensor.opt import MakeVector
38+
from theano.tensor.shape import Shape, Shape_i
3939
from theano.tensor.type import TensorType, fmatrix, iscalar, lscalar, matrix
4040

4141

@@ -352,7 +352,7 @@ def test_shape():
352352
topo = f.maker.fgraph.toposort()
353353
assert np.all(f(v) == (3, 4, 5))
354354
assert len(topo) == 1
355-
assert isinstance(topo[0].op, theano.compile.ops.Shape)
355+
assert isinstance(topo[0].op, Shape)
356356

357357

358358
def test_gpu_contiguous():

tests/gpuarray/test_dnn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
)
4545
from theano.tensor.nnet.corr import CorrMM
4646
from theano.tensor.nnet.corr3d import Corr3dMM
47+
from theano.tensor.shape import reshape
4748
from theano.tensor.signal.pool import (
4849
AveragePoolGrad,
4950
MaxPoolGrad,
@@ -2813,7 +2814,7 @@ def test_dnn_spatialtf():
28132814

28142815
def spatialtf_cpu(inp, theta, scale_height, scale_width, border_mode="nearest"):
28152816
num_batch, num_channels, height, width = inp.shape
2816-
theta = tt.reshape(theta, (-1, 2, 3))
2817+
theta = reshape(theta, (-1, 2, 3))
28172818

28182819
# grid of (x_t, y_t, 1), eq (1) in ref [1]
28192820
out_height = tt.cast(tt.ceil(height * scale_height), "int64")
@@ -2832,7 +2833,7 @@ def spatialtf_cpu(inp, theta, scale_height, scale_width, border_mode="nearest"):
28322833
input_dim, x_s_flat, y_s_flat, out_height, out_width, border_mode
28332834
)
28342835

2835-
output = tt.reshape(
2836+
output = reshape(
28362837
input_transformed, (num_batch, out_height, out_width, num_channels)
28372838
)
28382839
output = output.dimshuffle(0, 3, 1, 2) # dimshuffle to conv format

tests/gpuarray/test_opt.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,16 @@
5555

5656

5757
def _check_stack_trace(thing):
58+
from theano.tensor.shape import Shape, Shape_i
59+
5860
def _ops_to_check(op):
5961
if not isinstance(op, theano.graph.op.Op):
6062
op = op.op # assume it is an apply node
6163
return not isinstance(
6264
op,
6365
(
64-
theano.compile.ops.Shape_i,
65-
theano.compile.ops.Shape,
66+
Shape_i,
67+
Shape,
6668
theano.compile.ops.DeepCopyOp,
6769
theano.tensor.opt.MakeVector,
6870
theano.tensor.subtensor.Subtensor,

tests/gpuarray/test_type.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from theano.compile import DeepCopyOp, Rebroadcast, ViewOp
1616
from theano.configdefaults import config
1717
from theano.gpuarray.type import GpuArrayType, gpuarray_shared_constructor
18+
from theano.tensor.shape import specify_shape
1819
from theano.tensor.type import row
1920

2021

@@ -80,7 +81,7 @@ def test_specify_shape():
8081
for dtype in ["float16", "float32"]:
8182
a = rand_gpuarray(20, dtype=dtype)
8283
g = GpuArrayType(dtype=dtype, broadcastable=(False,))("g")
83-
f = theano.function([g], theano.compile.ops.specify_shape(g, [20]))
84+
f = theano.function([g], specify_shape(g, [20]))
8485
f(a)
8586

8687

tests/link/test_jax.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from theano.graph.fg import FunctionGraph
88
from theano.graph.optdb import Query
99
from theano.link.jax import JAXLinker
10-
from theano.tensor import basic as tt_basic
1110
from theano.tensor import blas as tt_blas
1211
from theano.tensor import elemwise as tt_elemwise
1312
from theano.tensor import extra_ops as tt_extra_ops
@@ -16,6 +15,7 @@
1615
from theano.tensor import opt as tt_opt
1716
from theano.tensor import slinalg as tt_slinalg
1817
from theano.tensor import subtensor as tt_subtensor
18+
from theano.tensor.shape import Shape, Shape_i, SpecifyShape, reshape
1919
from theano.tensor.type import (
2020
dscalar,
2121
dvector,
@@ -139,35 +139,37 @@ def compare_shape_dtype(x, y):
139139
compare_jax_and_py(x_fg, [np.ones(10, dtype=theano.config.floatX)])
140140

141141

142-
def test_jax_compile_ops():
143-
x = theano.compile.ops.DeepCopyOp()(tt.as_tensor_variable(1.1))
144-
x_fg = FunctionGraph([], [x])
145-
146-
compare_jax_and_py(x_fg, [])
147-
142+
def test_jax_shape_ops():
148143
x_np = np.zeros((20, 3))
149-
x = theano.compile.ops.Shape()(tt.as_tensor_variable(x_np))
144+
x = Shape()(tt.as_tensor_variable(x_np))
150145
x_fg = FunctionGraph([], [x])
151146

152147
compare_jax_and_py(x_fg, [], must_be_device_array=False)
153148

154-
x = theano.compile.ops.Shape_i(1)(tt.as_tensor_variable(x_np))
149+
x = Shape_i(1)(tt.as_tensor_variable(x_np))
155150
x_fg = FunctionGraph([], [x])
156151

157152
compare_jax_and_py(x_fg, [], must_be_device_array=False)
158153

159-
x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (20, 3))
154+
x = SpecifyShape()(tt.as_tensor_variable(x_np), (20, 3))
160155
x_fg = FunctionGraph([], [x])
161156

162157
compare_jax_and_py(x_fg, [])
163158

164159
with theano.config.change_flags(compute_test_value="off"):
165-
x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (2, 3))
160+
x = SpecifyShape()(tt.as_tensor_variable(x_np), (2, 3))
166161
x_fg = FunctionGraph([], [x])
167162

168163
with pytest.raises(AssertionError):
169164
compare_jax_and_py(x_fg, [])
170165

166+
167+
def test_jax_compile_ops():
168+
x = theano.compile.ops.DeepCopyOp()(tt.as_tensor_variable(1.1))
169+
x_fg = FunctionGraph([], [x])
170+
171+
compare_jax_and_py(x_fg, [])
172+
171173
x_np = np.zeros((20, 1, 1))
172174
x = theano.compile.ops.Rebroadcast((0, False), (1, True), (2, False))(
173175
tt.as_tensor_variable(x_np)
@@ -650,13 +652,13 @@ def test_jax_MakeVector():
650652

651653
def test_jax_Reshape():
652654
a = vector("a")
653-
x = tt_basic.reshape(a, (2, 2))
655+
x = reshape(a, (2, 2))
654656
x_fg = FunctionGraph([a], [x])
655657
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)])
656658

657659
# Test breaking "omnistaging" changes in JAX.
658660
# See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68
659-
x = tt_basic.reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
661+
x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
660662
x_fg = FunctionGraph([a], [x])
661663
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)])
662664

@@ -665,7 +667,7 @@ def test_jax_Reshape():
665667
def test_jax_Reshape_nonconcrete():
666668
a = vector("a")
667669
b = iscalar("b")
668-
x = tt_basic.reshape(a, (b, b))
670+
x = reshape(a, (b, b))
669671
x_fg = FunctionGraph([a, b], [x])
670672
compare_jax_and_py(
671673
x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX), 2]

0 commit comments

Comments
 (0)