|
50 | 50 | from pytensor.scalar import Add, Exp, Log, Mul, Reciprocal
|
51 | 51 | from pytensor.scan.op import Scan
|
52 | 52 | from pytensor.tensor.exceptions import NotScalarConstantError
|
53 |
| -from pytensor.tensor.math import add, exp, log, mul, reciprocal, true_div |
| 53 | +from pytensor.tensor.math import add, exp, log, mul, neg, reciprocal, sub, true_div |
54 | 54 | from pytensor.tensor.rewriting.basic import (
|
55 | 55 | register_specialize,
|
56 | 56 | register_stabilize,
|
@@ -384,6 +384,46 @@ def measurable_div_to_reciprocal_product(fgraph, node):
|
384 | 384 | return [at.mul(numerator, at.reciprocal(denominator))]
|
385 | 385 |
|
386 | 386 |
|
| 387 | +@node_rewriter([neg]) |
| 388 | +def measurable_neg_to_product(fgraph, node): |
| 389 | + """Convert negation of `MeasurableVariable`s to product with `-1`.""" |
| 390 | + |
| 391 | + inp = node.inputs[0] |
| 392 | + if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)): |
| 393 | + return None |
| 394 | + |
| 395 | + rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) |
| 396 | + if rv_map_feature is None: |
| 397 | + return None # pragma: no cover |
| 398 | + |
| 399 | + # Only apply this rewrite if the variable is unvalued |
| 400 | + if inp in rv_map_feature.rv_values: |
| 401 | + return None # pragma: no cover |
| 402 | + |
| 403 | + return [at.mul(inp, -1.0)] |
| 404 | + |
| 405 | + |
| 406 | +@node_rewriter([sub]) |
| 407 | +def measurable_sub_to_neg(fgraph, node): |
| 408 | + """Convert subtraction involving `MeasurableVariable`s to addition with neg""" |
| 409 | + measurable_vars = [ |
| 410 | + var for var in node.inputs if (var.owner and isinstance(var.owner.op, MeasurableVariable)) |
| 411 | + ] |
| 412 | + if not measurable_vars: |
| 413 | + return None # pragma: no cover |
| 414 | + |
| 415 | + rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) |
| 416 | + if rv_map_feature is None: |
| 417 | + return None # pragma: no cover |
| 418 | + |
| 419 | + # Only apply this rewrite if there is one unvalued MeasurableVariable involved |
| 420 | + if all(measurable_var in rv_map_feature.rv_values for measurable_var in measurable_vars): |
| 421 | + return None # pragma: no cover |
| 422 | + |
| 423 | + minuend, subtrahend = node.inputs |
| 424 | + return [at.add(minuend, at.neg(subtrahend))] |
| 425 | + |
| 426 | + |
387 | 427 | @node_rewriter([exp, log, add, mul, reciprocal])
|
388 | 428 | def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
|
389 | 429 | """Find measurable transformations from Elemwise operators."""
|
@@ -475,6 +515,19 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
|
475 | 515 | "transform",
|
476 | 516 | )
|
477 | 517 |
|
| 518 | +measurable_ir_rewrites_db.register( |
| 519 | + "measurable_neg_to_product", |
| 520 | + measurable_neg_to_product, |
| 521 | + "basic", |
| 522 | + "transform", |
| 523 | +) |
| 524 | + |
| 525 | +measurable_ir_rewrites_db.register( |
| 526 | + "measurable_sub_to_neg", |
| 527 | + measurable_sub_to_neg, |
| 528 | + "basic", |
| 529 | + "transform", |
| 530 | +) |
478 | 531 |
|
479 | 532 | measurable_ir_rewrites_db.register(
|
480 | 533 | "find_measurable_transforms",
|
|
0 commit comments