Simplify dots with 1 #638

ricardoV94 opened this issue Feb 8, 2024 · 3 comments · May be fixed by #810

Comments

@ricardoV94
Member

ricardoV94 commented Feb 8, 2024

Description

We have a local_0_dot_x rewrite that removes useless dots with zeroed inputs. We don't seem to have anything for dots with ones, as reported in #637 (comment)

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

x = pt.col('x')
f = x @ [[1.]]
with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([x], f, mode=get_default_mode().excluding("BlasOpt"))

pytensor.dprint(fn)
dot [id A] 0
 ├─ x [id B]
 └─ [[1.]] [id C]

I excluded BlasOpt just to have a simpler graph; with those rewrites enabled the dot is still not rewritten away, it is just replaced by the more complex Blas Op.
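As a quick numpy sketch (not PyTensor) of why this dot is useless: for a column input, multiplying by the 1x1 ones matrix `[[1.]]` returns the input unchanged, so a rewrite could drop the dot entirely.

```python
import numpy as np

# For a column vector x of shape (n, 1), dotting with the 1x1 ones
# matrix [[1.]] is an identity operation.
x_val = np.array([[2.0], [3.0], [5.0]])
result = x_val @ np.array([[1.0]])
assert np.array_equal(result, x_val)  # the dot is a no-op here
```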

@register_canonicalize
@register_stabilize
@node_rewriter([Dot])
def local_0_dot_x(fgraph, node):
    if not isinstance(node.op, Dot):
        return False

    x = node.inputs[0]
    y = node.inputs[1]
    replace = False
    try:
        if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0:
            replace = True
    except NotScalarConstantError:
        pass
    try:
        if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0:
            replace = True
    except NotScalarConstantError:
        pass

    if replace:
        constant_zero = constant(0, dtype=node.outputs[0].type.dtype)
        if x.ndim == 2 and y.ndim == 2:
            constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0]))
            return [alloc(constant_zero, x.shape[0], y.shape[1])]
        elif x.ndim == 1 and y.ndim == 2:
            constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
            return [alloc(constant_zero, y.shape[1])]
        elif x.ndim == 2 and y.ndim == 1:
            constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0]))
            return [alloc(constant_zero, x.shape[0])]
        elif x.ndim == 1 and y.ndim == 1:
            constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
            return [constant_zero]
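For context, the rule that local_0_dot_x encodes is easy to state in numpy terms: a dot in which one operand is the constant 0 is just an all-zeros array with the dot's output shape (the `assert_op`/`eq` calls above guard the inner-dimension match at runtime).

```python
import numpy as np

# A dot with a zero operand is equivalent to allocating zeros with the
# output shape, no matmul needed.
x = np.arange(6.0).reshape(2, 3)
out = x @ np.zeros((3, 4))
assert np.array_equal(out, np.zeros((2, 4)))
```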

@Dhruvanshu-Joshi
Member

Looks like an interesting issue. We'd just have to replace the 0 check with a 1 check in local_0_dot_x, right?
Here's what I have in mind:

@register_canonicalize
@register_stabilize
@node_rewriter([Dot])
def local_1_dot_x(fgraph, node):
    if not isinstance(node.op, Dot):
        return False

    x = node.inputs[0]
    y = node.inputs[1]
    replace = False
    try:
        if get_underlying_scalar_constant_value(x, only_process_constants=True) == 1:
            replace = True
            var = y
    except NotScalarConstantError:
        pass

    try:
        if get_underlying_scalar_constant_value(y, only_process_constants=True) == 1:
            replace = True
            var = x
    except NotScalarConstantError:
        pass

    if replace:
        constant_value = constant(
            get_underlying_scalar_constant_value(var, only_process_constants=True),
            dtype=node.outputs[0].type.dtype,
        )
        if x.ndim == 2 and y.ndim == 2:
            constant_value = assert_op(constant_value, eq(x.shape[1], y.shape[0]))
            return [alloc(constant_value, x.shape[0], y.shape[1])]
        elif x.ndim == 1 and y.ndim == 2:
            constant_value = assert_op(constant_value, eq(x.shape[0], y.shape[0]))
            return [alloc(constant_value, y.shape[1])]
        elif x.ndim == 2 and y.ndim == 1:
            constant_value = assert_op(constant_value, eq(x.shape[1], y.shape[0]))
            return [alloc(constant_value, x.shape[0])]
        elif x.ndim == 1 and y.ndim == 1:
            constant_value = assert_op(constant_value, eq(x.shape[0], y.shape[0]))
            return [constant_value]

However, I think using the constant value might be wrong here. Would I have to substitute the entire var instead? If so, is this the correct way forward?

var = assert_op(var, eq(...))
alloc(var, shape)

@ricardoV94
Member Author

No, the rule is slightly different for ones: it amounts to summing the left matrix. You also have to reason about broadcasting.

I suggest playing with numpy to get a feel for what it should do.
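Here is one hedged numpy sketch of the rule being described (my reading of it, not the final rewrite): dotting with a matrix of ones sums the left operand's columns and then broadcasts that column of sums across the output columns.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3))

# (A @ ones)[i, j] = sum_k A[i, k] for every j: a row sum broadcast
# across the output columns.
ones = np.ones((3, 2))
expected = np.broadcast_to(A.sum(axis=1, keepdims=True), (4, 2))
assert np.allclose(A @ ones, expected)

# 1-D case: dot with a vector of ones is a plain row sum.
assert np.allclose(A @ np.ones(3), A.sum(axis=1))
```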

@Dhruvanshu-Joshi
Member

Okay.
Just so I get it correctly: for a given graph, say

Sub [id A]
 ├─ dot [id B]
 │  ├─ dot [id C]
 │  │  ├─ Transpose{axes=[1, 0]} [id D] 'A.T'
 │  │  │  └─ A [id E]
 │  │  └─ Neg [id F]
 │  │     └─ x [id G]
 │  └─ [[1.]] [id H]
 └─ dot [id I]
    ├─ A [id E]
    └─ dot [id J]
       ├─ x [id G]
       └─ [[1.]] [id H]

we want the output of the rewrite to be:

Sub [id A]
 ├─ dot [id B]
 │  ├─ Transpose{axes=[1, 0]} [id C] 'A.T'
 │  │  └─ A [id D]
 │  └─ Neg [id E]
 │     └─ x [id F]
 └─ dot [id G]
    ├─ A [id D]
    └─ x [id F]

Is this correct? And if yes, how do summing the left matrix and broadcasting come into the picture here?

@Dhruvanshu-Joshi Dhruvanshu-Joshi linked a pull request Jun 7, 2024 that will close this issue