diff --git a/pytensor/sandbox/__init__.py b/pytensor/sandbox/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/pytensor/sandbox/linalg/__init__.py b/pytensor/sandbox/linalg/__init__.py deleted file mode 100644 index e4428ca21f..0000000000 --- a/pytensor/sandbox/linalg/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from pytensor.sandbox.linalg.ops import spectral_radius_bound diff --git a/pytensor/sandbox/minimal.py b/pytensor/sandbox/minimal.py deleted file mode 100644 index c0236e6cc7..0000000000 --- a/pytensor/sandbox/minimal.py +++ /dev/null @@ -1,46 +0,0 @@ -import numpy as np - -from pytensor.graph.basic import Apply -from pytensor.graph.op import Op -from pytensor.tensor.type import lscalar - - -class Minimal(Op): - # TODO : need description for class - - # if the Op has any attributes, consider using them in the eq function. - # If two Apply nodes have the same inputs and the ops compare equal... - # then they will be MERGED so they had better have computed the same thing! - - __props__ = () - - def __init__(self): - # If you put things here, think about whether they change the outputs - # computed by # self.perform() - # - If they do, then you should take them into consideration in - # __eq__ and __hash__ - # - If they do not, then you should not use them in - # __eq__ and __hash__ - - super().__init__() - - def make_node(self, *args): - # HERE `args` must be PYTENSOR VARIABLES - return Apply(op=self, inputs=args, outputs=[lscalar()]) - - def perform(self, node, inputs, out_): - (output,) = out_ - # HERE `inputs` are PYTHON OBJECTS - - # do what you want here, - # but do not modify any of the arguments [inplace]. - print("perform got %i arguments" % len(inputs)) - - print("Max of input[0] is ", np.max(inputs[0])) - - # return some computed value. - # do not return something that is aliased to one of the inputs. - output[0] = np.asarray(0, dtype="int64") - - -minimal = Minimal() diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 1282cabae5..b276d7339b 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -1,4 +1,4 @@ -"""Ops and optimizations for using BLAS calls +"""Ops for using BLAS calls BLAS = Basic Linear Algebra Subroutines Learn more about BLAS here: @@ -71,60 +71,10 @@ that system. -Optimizations -============= - -The optimization pipeline works something like this: - - 1. identify dot22 from dot - 2. identify gemm from dot22 - 3. identify dot22scalar from dot22 that are not gemm - 4. specialize gemm to gemv where applicable - 5. specialize gemm to ger where applicable - 6. specialize dot22 -> gemv or ger where applicable - -:note: GEMM is the most canonical BLAS signature that we deal with so far, it - would be good to turn most things into GEMM (dot, inner, outer, dot22, - dot22scalar), and then to specialize from gemm to the various other L2 and - L3 operations. - -Identify Dot22 --------------- - -Numpy's dot supports arguments that are of any rank, and we should support that -too (just for compatibility). The BLAS optimizations work with Dot Ops whose -inputs are each either vector or matrix. So the first part of the optimization -pipeline is to transform qualifying Dot Ops to Dot22 Ops. Dot22 Ops may be -transformed further, but they will get implemented by a BLAS call. - -More precisely, Dot nodes whose inputs are all vectors or matrices and whose -inputs both have the same dtype, and whose dtype is float or complex, become -Dot22. This is implemented in `local_dot_to_dot22`. 
- - -Identify Gemm from Dot22 ------------------------- - -This is complicated, done in GemmOptimizer. - -Identify Dot22Scalar from Dot22 -------------------------------- - -Dot22 Ops that remain after the GemmOptimizer is done have not -qualified as GEMM Ops. Still they might be scaled by a factor, in -which case we use Dot22Scalar which is like Gemm, but without the b -and the Z. In the future it would be good to merge this into the -GemmOptimizer. - -Specialize Gemm to Gemv ------------------------ - -If arguments to GEMM are dimshuffled vectors, then we can use GEMV -instead. This optimization is `local_gemm_to_gemv`. +Optimizations associated with these BLAS Ops are in tensor.rewriting.blas """ -import copy import logging import os import time @@ -140,38 +90,20 @@ from typing import Tuple import pytensor.scalar -from pytensor.compile.mode import optdb from pytensor.configdefaults import config from pytensor.graph.basic import Apply, view_roots -from pytensor.graph.features import ReplacementDidNotRemoveError, ReplaceValidate from pytensor.graph.op import Op -from pytensor.graph.rewriting.basic import ( - EquilibriumGraphRewriter, - GraphRewriter, - copy_stack_trace, - in2out, - node_rewriter, -) -from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType -from pytensor.printing import FunctionPrinter, debugprint, pprint +from pytensor.printing import FunctionPrinter, pprint from pytensor.scalar import bool as bool_t from pytensor.tensor import basic as at from pytensor.tensor.blas_headers import blas_header_text, blas_header_version -from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.math import Dot, add, mul, neg, sub -from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift +from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.math import add, mul, neg, sub from pytensor.tensor.shape import specify_broadcastable -from pytensor.tensor.type import ( - DenseTensorType, - TensorType, - integer_dtypes, - tensor, - values_eq_approx_remove_inf_nan, -) +from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor from pytensor.utils import memoize @@ -1512,150 +1444,6 @@ def _gemm_from_node2(fgraph, node): return None, t1 - t0, 0, 0 -class GemmOptimizer(GraphRewriter): - """Graph optimizer for inserting Gemm operations.""" - - def __init__(self): - super().__init__() - self.warned = False - - def add_requirements(self, fgraph): - fgraph.attach_feature(ReplaceValidate()) - - def apply(self, fgraph): - did_something = True - nb_iter = 0 - nb_replacement = 0 - nb_replacement_didn_t_remove = 0 - nb_inconsistency_make = 0 - nb_inconsistency_replace = 0 - time_canonicalize = 0 - time_factor_can = 0 - time_factor_list = 0 - time_toposort = 0 - if fgraph.profile: - validate_before = fgraph.profile.validate_time - callbacks_before = fgraph.execute_callbacks_times.copy() - callback_before = fgraph.execute_callbacks_time - - def on_import(new_node): - if new_node is not node: - nodelist.append(new_node) - - u = pytensor.graph.rewriting.basic.DispatchingFeature( - on_import, None, None, name="GemmOptimizer" - ) - fgraph.attach_feature(u) - while did_something: - nb_iter += 1 - t0 = time.perf_counter() - nodelist = pytensor.graph.basic.io_toposort(fgraph.inputs, fgraph.outputs) - time_toposort += 
time.perf_counter() - t0 - did_something = False - nodelist.reverse() - for node in nodelist: - if not ( - isinstance(node.op, Elemwise) - and isinstance( - node.op.scalar_op, - ( - pytensor.scalar.Add, - pytensor.scalar.Sub, - pytensor.scalar.Neg, - pytensor.scalar.Mul, - ), - ) - ): - continue - if node not in fgraph.apply_nodes: - # This mean that we already removed this node from - # the graph - continue - try: - new_outputs, time1, time2, time3 = _gemm_from_node2(fgraph, node) - time_canonicalize += time1 - time_factor_can += time2 - time_factor_list += time3 - except InconsistencyError: - nb_inconsistency_make += 1 - continue - if new_outputs: - new_outputs, old_dot22 = new_outputs - assert len(new_outputs) == len(node.outputs) - new_outputs[ - 0 - ].tag.values_eq_approx = values_eq_approx_remove_inf_nan - try: - fgraph.replace_all_validate_remove( - list(zip(node.outputs, new_outputs)), - [old_dot22], - reason="GemmOptimizer", - # For now we disable the warning as we know case - # that we need to fix. - warn=False, # warn=not self.warned - ) - did_something = True - nb_replacement += 1 - except InconsistencyError: - # TODO: retry other applications of gemm (see comment - # in _gemm_from_node) - nb_inconsistency_replace += 1 - except ReplacementDidNotRemoveError: - nb_replacement_didn_t_remove += 1 - self.warned = True - fgraph.remove_feature(u) - if fgraph.profile: - validate_time = fgraph.profile.validate_time - validate_before - callback_time = fgraph.execute_callbacks_time - callback_before - callbacks_time = {} - for k, v in fgraph.execute_callbacks_times.items(): - if k in callbacks_before: - callbacks_time[k] = v - callbacks_before[k] - else: - callbacks_time[k] = v - else: - validate_time = None - callback_time = None - callbacks_time = {} - - return ( - self, - nb_iter, - nb_replacement, - nb_replacement_didn_t_remove, - nb_inconsistency_make, - nb_inconsistency_replace, - time_canonicalize, - time_factor_can, - time_factor_list, - time_toposort, - validate_time, - callback_time, - callbacks_time, - ) - - @classmethod - def print_profile(cls, stream, prof, level=0): - blanc = " " * level - print(blanc, cls.__name__, file=stream) - print(blanc, " nb_iter", prof[1], file=stream) - print(blanc, " nb_replacement", prof[2], file=stream) - print(blanc, " nb_replacement_didn_t_remove", prof[3], file=stream) - print(blanc, " nb_inconsistency_make", prof[4], file=stream) - print(blanc, " nb_inconsistency_replace", prof[5], file=stream) - print(blanc, " time_canonicalize", prof[6], file=stream) - print(blanc, " time_factor_can", prof[7], file=stream) - print(blanc, " time_factor_list", prof[8], file=stream) - print(blanc, " time_toposort", prof[9], file=stream) - print(blanc, " validate_time", prof[10], file=stream) - print(blanc, " callback_time", prof[11], file=stream) - if prof[11] > 1: - print(blanc, " callbacks_time", file=stream) - for i in sorted(prof[12].items(), key=lambda a: a[1]): - if i[1] > 0: - print(i) - - class Dot22(GemmRelated): """Compute a matrix-matrix product. 
@@ -1750,207 +1538,6 @@ def c_code_cache_version(self): _dot22 = Dot22() -@node_rewriter([Dot]) -def local_dot_to_dot22(fgraph, node): - # This works for tensor.outer too because basic.outer is a macro that - # produces a dot(dimshuffle,dimshuffle) of form 4 below - if not isinstance(node.op, Dot): - return - - if any(not isinstance(i.type, DenseTensorType) for i in node.inputs): - return False - - x, y = node.inputs - if y.type.dtype != x.type.dtype: - # TODO: upcast one so the types match - _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}") - return - - if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"): - if x.ndim == 2 and y.ndim == 2: - new_out = [_dot22(*node.inputs)] - elif x.ndim == 2 and y.ndim == 1: - new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)] - elif x.ndim == 1 and y.ndim == 2: - new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)] - elif x.ndim == 1 and y.ndim == 1: - new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()] - else: - return - copy_stack_trace(node.outputs, new_out) - return new_out - - _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}") - - -@node_rewriter([gemm_no_inplace], inplace=True) -def local_inplace_gemm(fgraph, node): - if node.op == gemm_no_inplace: - new_out = [gemm_inplace(*node.inputs)] - copy_stack_trace(node.outputs, new_out) - return new_out - - -@node_rewriter([gemv_no_inplace], inplace=True) -def local_inplace_gemv(fgraph, node): - if node.op == gemv_no_inplace: - new_out = [gemv_inplace(*node.inputs)] - copy_stack_trace(node.outputs, new_out) - return new_out - - -@node_rewriter([ger], inplace=True) -def local_inplace_ger(fgraph, node): - if node.op == ger: - new_out = [ger_destructive(*node.inputs)] - copy_stack_trace(node.outputs, new_out) - return new_out - - -@node_rewriter([gemm_no_inplace]) -def local_gemm_to_gemv(fgraph, node): - """GEMM acting on row or column matrices -> GEMV.""" - if node.op == gemm_no_inplace: - z, a, x, y, b = node.inputs - if z.broadcastable == x.broadcastable == (True, False): - r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b) - new_out = [r.dimshuffle("x", 0)] - elif z.broadcastable == y.broadcastable == (False, True): - r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b) - new_out = [r.dimshuffle(0, "x")] - else: - return - copy_stack_trace(node.outputs, new_out) - return new_out - - -@node_rewriter([gemm_no_inplace]) -def local_gemm_to_ger(fgraph, node): - """GEMM computing an outer-product -> GER.""" - if node.op == gemm_no_inplace: - z, a, x, y, b = node.inputs - if x.broadcastable[1] and y.broadcastable[0]: - # x and y are both vectors so this might qualifies for a GER - xv = x.dimshuffle(0) - yv = y.dimshuffle(1) - try: - bval = at.get_underlying_scalar_constant_value(b) - except NotScalarConstantError: - # b isn't a constant, GEMM is doing useful pre-scaling - return - - if bval == 1: # best case a natural GER - rval = ger(z, a, xv, yv) - new_out = [rval] - elif bval == 0: # GER on zeros_like should be faster than GEMM - zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype) - rval = ger(zeros, a, xv, yv) - new_out = [rval] - else: - # if bval is another constant, then z is being usefully - # pre-scaled and GER isn't really the right tool for the job. 
- return - copy_stack_trace(node.outputs, new_out) - return new_out - - -# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline -# working -@node_rewriter([_dot22]) -def local_dot22_to_ger_or_gemv(fgraph, node): - """dot22 computing an outer-product -> GER.""" - if node.op == _dot22: - x, y = node.inputs - xb = x.broadcastable - yb = y.broadcastable - one = at.as_tensor_variable(np.asarray(1, dtype=x.dtype)) - zero = at.as_tensor_variable(np.asarray(0, dtype=x.dtype)) - if xb[1] and yb[0]: - # x and y are both vectors so this might qualifies for a GER - xv = x.dimshuffle(0) - yv = y.dimshuffle(1) - zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype) - rval = ger(zeros, one, xv, yv) - new_out = [rval] - elif xb[0] and yb[1]: - # x and y are both vectors so this qualifies for a sdot / ddot - # TODO: PyTensor doesn't have a sdot, but gemv is better than _dot22 - xv = x.dimshuffle(1) - zeros = at.AllocEmpty(x.dtype)(1) - rval = gemv_no_inplace(zeros, one, y.T, xv, zero) - new_out = [rval.dimshuffle("x", 0)] - elif xb[0] and not yb[0] and not yb[1]: - # x is vector, y is matrix so try gemv - xv = x.dimshuffle(1) - zeros = at.AllocEmpty(x.dtype)(y.shape[1]) - rval = gemv_no_inplace(zeros, one, y.T, xv, zero) - new_out = [rval.dimshuffle("x", 0)] - elif not xb[0] and not xb[1] and yb[1]: - # x is matrix, y is vector, try gemv - yv = y.dimshuffle(0) - zeros = at.AllocEmpty(x.dtype)(x.shape[0]) - rval = gemv_no_inplace(zeros, one, x, yv, zero) - new_out = [rval.dimshuffle(0, "x")] - else: - return - copy_stack_trace(node.outputs, new_out) - return new_out - - -################################# -# -# Set up the BlasOpt optimizer -# -################################# - -blas_optdb = SequenceDB() - -# run after numerical stability optimizations (1.5) -optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7) -# run before specialize (2.0) because specialize is basically a -# free-for-all that makes the graph crazy. - -# fast_compile is needed to have GpuDot22 created. -blas_optdb.register( - "local_dot_to_dot22", - in2out(local_dot_to_dot22), - "fast_run", - "fast_compile", - position=0, -) -blas_optdb.register("gemm_optimizer", GemmOptimizer(), "fast_run", position=10) -blas_optdb.register( - "local_gemm_to_gemv", - EquilibriumGraphRewriter( - [ - local_gemm_to_gemv, - local_gemm_to_ger, - local_dot22_to_ger_or_gemv, - local_dimshuffle_lift, - ], - max_use_ratio=5, - ignore_newtrees=False, - ), - "fast_run", - position=15, -) - - -# After destroyhandler(49.5) but before we try to make elemwise things -# inplace (75) -blas_opt_inplace = in2out( - local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace" -) -optdb.register( - "InplaceBlasOpt", - blas_opt_inplace, - "fast_run", - "inplace", - "blas_opt_inplace", - position=70.0, -) - - class Dot22Scalar(GemmRelated): """Compute a matrix-matrix product. @@ -2049,133 +1636,6 @@ def c_code_cache_version(self): _dot22scalar = Dot22Scalar() -@node_rewriter([mul]) -def local_dot22_to_dot22scalar(fgraph, node): - """ - Notes - ----- - Previous attempts to alter this optimization to replace dot22 with - gemm instead of dot22scalar resulted in some Scan nodes being - duplicated and the ScanSaveMem optimization never running on them, - resulting in highly increased memory usage. Until this issue is - resolved, this optimization should keep using dot22scalar instead of - gemm. - - We upcast the scalar if after the multiplication with the dot this give - the same type. 
- - We execute this optimizer after the gemm optimizer. This - allow to give more priority to gemm that give more speed up - then this optimizer, but allow the gemm optimizer to ignore - this op. - - TODO: support when we can reorder the mul to generate a - dot22scalar or fix the canonizer to merge them(1 mul with multiple - inputs) - - """ - if node.op != mul: - return False - i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs] - if not any(i_dot22): - return False # no dot22 - if i_dot22.count(True) > 1: - # TODO: try each of them. - pass - # return False #TODO fix - dot22_idx = i_dot22.index(True) - d = node.inputs[dot22_idx] - i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs] - if not any(i_scalar): - # Check if we can reorder the graph as this mul have a mul in inputs. - # We support only 1 additional level of mul. - # The canonizer should have merged those mul together. - i_mul = [ - x.owner - and x.owner.op == mul - and any(_as_scalar(x_i, dtype=d.dtype) for x_i in x.owner.inputs) - for x in node.inputs - ] - if not any(i_mul): - # no scalar in input and no multiplication - # if their was a multiplication we couls reorder the graph - # by the associativity of the graph. - return False - - mul_idx = i_mul.index(True) # The first one should always work - m = node.inputs[mul_idx] - - scalar_idx = -1 - for i, x in enumerate(m.owner.inputs): - if _as_scalar(x, dtype=d.dtype) and ( - pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype - ): - scalar_idx = i - break - - if scalar_idx < 0: - _logger.info( - f"Not optimizing dot22 with inputs {node.inputs} {[x.type for x in node.inputs]}, as the" - " type of the scalar cannot be upcasted to the" - " matrix type" - ) - return False - a = at.cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype) - assert not a.type.ndim - dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a) - - # The other inputs to the original node that were - # neither part of the dot22 or this mul should be - # factors in the returned "mul" node. 
- assert dot22_idx != mul_idx - other_factors = [ - inpt for i, inpt in enumerate(node.inputs) if i not in (dot22_idx, mul_idx) - ] - other_m_inputs = [ - inpt for i, inpt in enumerate(m.owner.inputs) if i != scalar_idx - ] - - return [mul(dot, *(other_factors + other_m_inputs))] - - scalar_idx = -1 - for i, x in enumerate(node.inputs): - if ( - i != dot22_idx - and i_scalar[i] is not None - and (pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype) - ): - scalar_idx = i - break - if scalar_idx < 0: - _logger.info( - f"Not optimizing dot22 with inputs {node.inputs} {[x.type for x in node.inputs]}, as the type " - "of the scalar cannot be upcasted to the matrix type" - ) - return False - assert scalar_idx < len(node.inputs) - s = node.inputs[scalar_idx] - o = copy.copy(node.inputs) - o.remove(d) - o.remove(s) - - a = at.cast(i_scalar[scalar_idx], d.type.dtype) - assert not a.type.ndim - if len(o) == 0: - return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)] - else: - return [mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a), *o)] - - -# must happen after gemm as the gemm optimizer don't understant -# dot22scalar and gemm give more speed up then dot22scalar -blas_optdb.register( - "local_dot22_to_dot22scalar", - in2out(local_dot22_to_dot22scalar), - "fast_run", - position=11, -) - - class BatchedDot(COp): """ Computes the batched dot product of two variables: @@ -2669,14 +2129,6 @@ def infer_shape(self, fgraph, node, shapes): _batched_dot = BatchedDot() -# from opt import register_specialize, register_canonicalize -# @register_specialize -@node_rewriter([sub, add]) -def local_print_as_we_go_along(fgraph, node): - if node.op in (sub, add): - debugprint(node) - - def batched_dot(a, b): """Compute the batched dot product of two variables. diff --git a/pytensor/tensor/blas_c.py b/pytensor/tensor/blas_c.py index e4e90066b0..704970b5ef 100644 --- a/pytensor/tensor/blas_c.py +++ b/pytensor/tensor/blas_c.py @@ -1,22 +1,12 @@ -from pytensor.configdefaults import config -from pytensor.graph.rewriting.basic import in2out from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.scalar import bool as bool_t -from pytensor.tensor import basic as at from pytensor.tensor.blas import ( Gemv, Ger, blas_header_text, blas_header_version, - blas_optdb, - gemv_inplace, - gemv_no_inplace, - ger, - ger_destructive, ldflags, - node_rewriter, - optdb, ) @@ -344,23 +334,6 @@ def c_code_cache_version(self): cger_no_inplace = CGer(False) -@node_rewriter([ger, ger_destructive]) -def use_c_ger(fgraph, node): - if not config.blas__ldflags: - return - # Only float32 and float64 are supported for now. - if node.op == ger and node.outputs[0].dtype in ("float32", "float64"): - return [CGer(False)(*node.inputs)] - if node.op == ger_destructive and node.outputs[0].dtype in ("float32", "float64"): - return [CGer(True)(*node.inputs)] - - -@node_rewriter([CGer(False)]) -def make_c_ger_destructive(fgraph, node): - if isinstance(node.op, CGer) and not node.op.destructive: - return [cger_inplace(*node.inputs)] - - # ##### ####### ####### # GEMV # ##### ####### ####### @@ -697,48 +670,3 @@ def check_force_gemv_init(): check_force_gemv_init._force_init_beta = None - - -@node_rewriter([gemv_inplace, gemv_no_inplace]) -def use_c_gemv(fgraph, node): - if not config.blas__ldflags: - return - # Only float32 and float64 are supported for now. 
- if node.op == gemv_no_inplace and node.outputs[0].dtype in ("float32", "float64"): - return [cgemv_no_inplace(*node.inputs)] - if node.op == gemv_inplace and node.outputs[0].dtype in ("float32", "float64"): - return [cgemv_inplace(*node.inputs)] - - -@node_rewriter([CGemv(inplace=False)]) -def make_c_gemv_destructive(fgraph, node): - if isinstance(node.op, CGemv) and not node.op.inplace: - inputs = list(node.inputs) - dest = inputs[0] - if ( - dest.owner - and isinstance(dest.owner.op, at.AllocEmpty) - and len(fgraph.clients[dest]) > 1 - ): - inputs[0] = at.AllocEmpty(dest.dtype)(*dest.owner.inputs) - - return [cgemv_inplace(*inputs)] - - -# ##### ####### ####### -# Optimizers -# ##### ####### ####### - -blas_optdb.register( - "use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20 -) - -# this matches the InplaceBlasOpt defined in blas.py -optdb.register( - "c_blas_destructive", - in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"), - "fast_run", - "inplace", - "c_blas", - position=70.0, -) diff --git a/pytensor/tensor/blas_scipy.py b/pytensor/tensor/blas_scipy.py index 4d1be6e322..527d5150a1 100644 --- a/pytensor/tensor/blas_scipy.py +++ b/pytensor/tensor/blas_scipy.py @@ -4,16 +4,7 @@ import numpy as np -from pytensor.graph.rewriting.basic import in2out -from pytensor.tensor.blas import ( - Ger, - blas_optdb, - ger, - ger_destructive, - have_fblas, - node_rewriter, - optdb, -) +from pytensor.tensor.blas import Ger, have_fblas if have_fblas: @@ -56,36 +47,3 @@ def perform(self, node, inputs, output_storage): scipy_ger_no_inplace = ScipyGer(False) scipy_ger_inplace = ScipyGer(True) - - -@node_rewriter([ger, ger_destructive]) -def use_scipy_ger(fgraph, node): - if node.op == ger: - return [scipy_ger_no_inplace(*node.inputs)] - - -@node_rewriter([scipy_ger_no_inplace]) -def make_ger_destructive(fgraph, node): - if node.op == scipy_ger_no_inplace: - return [scipy_ger_inplace(*node.inputs)] - - -use_scipy_blas = in2out(use_scipy_ger) -make_scipy_blas_destructive = in2out(make_ger_destructive) - -if have_fblas: - # scipy_blas is scheduled in the blas_optdb very late, because scipy sortof - # sucks, but it is almost always present. - # C implementations should be scheduled earlier than this, so that they take - # precedence. Once the original Ger is replaced, then these optimizations - # have no effect. 
- blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100) - - # this matches the InplaceBlasOpt defined in blas.py - optdb.register( - "make_scipy_blas_destructive", - make_scipy_blas_destructive, - "fast_run", - "inplace", - position=70.0, - ) diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index cb244afb7e..80946d524c 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -1,9 +1,13 @@ import pytensor.tensor.rewriting.basic +import pytensor.tensor.rewriting.blas +import pytensor.tensor.rewriting.blas_c +import pytensor.tensor.rewriting.blas_scipy import pytensor.tensor.rewriting.elemwise import pytensor.tensor.rewriting.extra_ops # Register JAX specializations import pytensor.tensor.rewriting.jax +import pytensor.tensor.rewriting.linalg import pytensor.tensor.rewriting.math import pytensor.tensor.rewriting.shape import pytensor.tensor.rewriting.special diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py new file mode 100644 index 0000000000..a310cb5837 --- /dev/null +++ b/pytensor/tensor/rewriting/blas.py @@ -0,0 +1,907 @@ +"""optimizations for using BLAS calls + +Optimizations +============= + +The optimization pipeline works something like this: + + 1. identify dot22 from dot + 2. identify gemm from dot22 + 3. identify dot22scalar from dot22 that are not gemm + 4. specialize gemm to gemv where applicable + 5. specialize gemm to ger where applicable + 6. specialize dot22 -> gemv or ger where applicable + +:note: GEMM is the most canonical BLAS signature that we deal with so far, it + would be good to turn most things into GEMM (dot, inner, outer, dot22, + dot22scalar), and then to specialize from gemm to the various other L2 and + L3 operations. + +Identify Dot22 +-------------- + +Numpy's dot supports arguments that are of any rank, and we should support that +too (just for compatibility). The BLAS optimizations work with Dot Ops whose +inputs are each either vector or matrix. So the first part of the optimization +pipeline is to transform qualifying Dot Ops to Dot22 Ops. Dot22 Ops may be +transformed further, but they will get implemented by a BLAS call. + +More precisely, Dot nodes whose inputs are all vectors or matrices and whose +inputs both have the same dtype, and whose dtype is float or complex, become +Dot22. This is implemented in `local_dot_to_dot22`. + + +Identify Gemm from Dot22 +------------------------ + +This is complicated, done in GemmOptimizer. + +Identify Dot22Scalar from Dot22 +------------------------------- + +Dot22 Ops that remain after the GemmOptimizer is done have not +qualified as GEMM Ops. Still they might be scaled by a factor, in +which case we use Dot22Scalar which is like Gemm, but without the b +and the Z. In the future it would be good to merge this into the +GemmOptimizer. + +Specialize Gemm to Gemv +----------------------- + +If arguments to GEMM are dimshuffled vectors, then we can use GEMV +instead. This optimization is `local_gemm_to_gemv`. 
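+
+For example (an illustrative sketch, not part of the pipeline itself; the
+exact ops chosen depend on shapes, dtypes, and the active mode)::
+
+    import pytensor
+    import pytensor.tensor as at
+
+    x = at.matrix("x")
+    y = at.matrix("y")
+    z = at.matrix("z")
+    f = pytensor.function([x, y, z], z + at.dot(x, y), mode="FAST_RUN")
+    pytensor.dprint(f)  # expect a Gemm variant instead of the generic Dot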
+ +""" + +import copy +import logging +import time + +import numpy as np + + +try: + import numpy.__config__ # noqa +except ImportError: + pass + + +import pytensor.scalar +from pytensor.compile.mode import optdb +from pytensor.configdefaults import config +from pytensor.graph.features import ReplacementDidNotRemoveError, ReplaceValidate +from pytensor.graph.rewriting.basic import ( + EquilibriumGraphRewriter, + GraphRewriter, + copy_stack_trace, + in2out, + node_rewriter, +) +from pytensor.graph.rewriting.db import SequenceDB +from pytensor.graph.utils import InconsistencyError +from pytensor.printing import debugprint +from pytensor.tensor import basic as at +from pytensor.tensor.blas import ( + Dot22, + _dot22, + _dot22scalar, + gemm_inplace, + gemm_no_inplace, + gemv_inplace, + gemv_no_inplace, + ger, + ger_destructive, +) +from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.math import Dot, add, mul, neg, sub +from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift +from pytensor.tensor.type import ( + DenseTensorType, + TensorType, + integer_dtypes, + values_eq_approx_remove_inf_nan, +) + + +_logger = logging.getLogger("pytensor.tensor.rewriting.blas") + + +def res_is_a(fgraph, var, op, maxclients=None): + if maxclients is not None and var in fgraph.clients: + retval = len(fgraph.get_clients(var)) <= maxclients + else: + retval = True + + return var.owner and var.owner.op == op and retval + + +def _as_scalar(res, dtype=None): + """Return ``None`` or a `TensorVariable` of float type""" + if dtype is None: + dtype = config.floatX + if all(s == 1 for s in res.type.shape): + while res.owner and isinstance(res.owner.op, DimShuffle): + res = res.owner.inputs[0] + # may still have some number of True's + if res.type.ndim > 0: + rval = res.dimshuffle() + else: + rval = res + if rval.type.dtype in integer_dtypes: + # We check that the upcast of res and dtype won't change dtype. + # If dtype is float64, we will cast int64 to float64. + # This is valid when res is a scalar used as input to a dot22 + # as the cast of the scalar can be done before or after the dot22 + # and this will give the same result. + if pytensor.scalar.upcast(res.dtype, dtype) == dtype: + return at.cast(rval, dtype) + else: + return None + + return rval + + +def _is_real_matrix(res): + return ( + res.type.dtype in ("float16", "float32", "float64") + and res.type.ndim == 2 + and res.type.shape[0] != 1 + and res.type.shape[1] != 1 + ) # cope with tuple vs. list + + +def _is_real_vector(res): + return ( + res.type.dtype in ("float16", "float32", "float64") + and res.type.ndim == 1 + and res.type.shape[0] != 1 + ) + + +def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True): + # print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip + # EXPRESSION: (beta * L) + (alpha * M) + + # we've already checked the client counts, now just make the type check. + # if res_is_a(M, _dot22, 1): + if M.owner and M.owner.op == _dot22: + Ml, Mr = M.owner.inputs + rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)] + return rval, M + + # it also might be the case that there is a dimshuffle between the + + # and the dot22. local_dot_to_dot22 in particular will put in such things. 
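+    # (e.g. local_dot_to_dot22 rewrites dot(matrix, vector) as
+    # dot22(x, y.dimshuffle(0, "x")).dimshuffle(0), so the addition sees
+    # the DimShuffle, with the Dot22 one level below it.)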
+    if (
+        M.owner
+        and isinstance(M.owner.op, DimShuffle)
+        and M.owner.inputs[0].owner
+        and isinstance(M.owner.inputs[0].owner.op, Dot22)
+    ):
+        MM = M.owner.inputs[0]
+        if M.owner.op.new_order == (0,):
+            # it is making a column MM into a vector
+            MMl, MMr = MM.owner.inputs
+            g = gemm_no_inplace(L.dimshuffle(0, "x"), alpha, MMl, MMr, beta)
+            rval = [g.dimshuffle(0)]
+            return rval, MM
+        if M.owner.op.new_order == (1,):
+            # it is making a row MM into a vector
+            MMl, MMr = MM.owner.inputs
+            g = gemm_no_inplace(L.dimshuffle("x", 0), alpha, MMl, MMr, beta)
+            rval = [g.dimshuffle(1)]
+            return rval, MM
+        if len(M.owner.op.new_order) == 0:
+            # it is making a 1x1 matrix MM into a scalar
+            MMl, MMr = MM.owner.inputs
+            g = gemm_no_inplace(L.dimshuffle("x", "x"), alpha, MMl, MMr, beta)
+            rval = [g.dimshuffle()]
+            return rval, MM
+
+    if recurse_flip:
+        return _beta_L_plus_alpha_M(fgraph, alpha, M, beta, L, recurse_flip=False)
+    else:
+        return False, False
+
+
+def _gemm_canonicalize(fgraph, r, scale, rval, maxclients):
+    # Tries to interpret node as a sum of scalars * (vectors or matrices)
+    def scaled(thing):
+        if scale == 1:
+            return thing
+        if scale == -1 and thing.type.dtype != "bool":
+            return -thing
+        else:
+            return scale * thing
+
+    if not isinstance(r.type, TensorType):
+        return None
+
+    if (r.type.ndim not in (1, 2)) or r.type.dtype not in (
+        "float16",
+        "float32",
+        "float64",
+        "complex64",
+        "complex128",
+    ):
+        rval.append(scaled(r))
+        return rval
+
+    if maxclients and len(fgraph.clients[r]) > maxclients:
+        rval.append((scale, r))
+        return rval
+
+    if r.owner and r.owner.op == sub:
+        _gemm_canonicalize(fgraph, r.owner.inputs[0], scale, rval, 1)
+        _gemm_canonicalize(fgraph, r.owner.inputs[1], -scale, rval, 1)
+
+    elif r.owner and r.owner.op == add:
+        for i in r.owner.inputs:
+            _gemm_canonicalize(fgraph, i, scale, rval, 1)
+
+    elif r.owner and r.owner.op == neg:
+        _gemm_canonicalize(fgraph, r.owner.inputs[0], -scale, rval, 1)
+
+    elif r.owner and r.owner.op == mul:
+        scalars = []
+        vectors = []
+        matrices = []
+        for i in r.owner.inputs:
+            if all(s == 1 for s in i.type.shape):
+                while i.owner and isinstance(i.owner.op, DimShuffle):
+                    i = i.owner.inputs[0]
+                if i.type.ndim > 0:
+                    scalars.append(i.dimshuffle())
+                else:
+                    scalars.append(i)
+            elif _is_real_vector(i):
+                vectors.append(i)
+            elif _is_real_matrix(i):
+                matrices.append(i)
+            else:
+                # just put the original arguments as in the base case
+                rval.append((scale, r))
+                return rval
+        if len(matrices) == 1:
+            assert len(vectors) == 0
+            m = matrices[0]
+            if len(scalars) == 0:
+                _gemm_canonicalize(fgraph, m, scale, rval, 1)
+            elif len(scalars) == 1:
+                _gemm_canonicalize(fgraph, m, scaled(scalars[0]), rval, 1)
+            else:
+                _gemm_canonicalize(
+                    fgraph, m, mul(scaled(scalars[0]), *scalars[1:]), rval, 1
+                )
+        elif len(vectors) == 1:
+            assert len(matrices) == 0
+            v = vectors[0]
+            if len(scalars) == 0:
+                _gemm_canonicalize(fgraph, v, scale, rval, 1)
+            elif len(scalars) == 1:
+                _gemm_canonicalize(fgraph, v, scaled(scalars[0]), rval, 1)
+            else:
+                _gemm_canonicalize(
+                    fgraph, v, mul(scaled(scalars[0]), *scalars[1:]), rval, 1
+                )
+        else:  # let's not open this up
+            rval.append((scale, r))
+    else:
+        rval.append((scale, r))
+    return rval
+
+
+def _factor_canonicalized(lst):
+    # remove duplicates from canonicalized list
+
+    # we only delete out of the right end of the list,
+    # once i has touched a list element, it is permanent
+    lst = list(lst)
+    # print 'FACTOR', lst
+    # for t in lst:
+    #    if not isinstance(t, (list, tuple)):
+    #        t = (t,)
+    #    for e in t:
+    #        try:
+    #            pytensor.printing.debugprint(e)
+    #        except TypeError:
+    #            print e, type(e)
+    i = 0
+    while i < len(lst) - 1:
+        try:
+            s_i, M_i = lst[i]
+        except Exception:
+            i += 1
+            continue
+
+        j = i + 1
+        while j < len(lst):
+            try:
+                s_j, M_j = lst[j]
+            except Exception:
+                j += 1
+                continue
+
+            if M_i is M_j:
+                s_i = s_i + s_j
+                lst[i] = (s_i, M_i)
+                del lst[j]
+            else:
+                j += 1
+        i += 1
+    return lst
+
+
+def _gemm_from_factored_list(fgraph, lst):
+    """
+    Returns None, or a list to replace node.outputs.
+
+    """
+    lst2 = []
+    # Remove the tuples that can't be cast correctly.
+    # This can happen when we try to cast a complex to a real
+    for sM in lst:
+        # Make every pair in list have matching dtypes
+        # sM can be a tuple of 2 elements or a PyTensor variable.
+        if isinstance(sM, tuple):
+            sm0, sm1 = sM
+            sm0 = at.as_tensor_variable(sm0)
+            if pytensor.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype:
+                lst2.append((at.cast(sm0, sm1.dtype), sM[1]))
+
+    lst = lst2
+
+    def item_to_var(t):
+        try:
+            s, M = t
+        except Exception:
+            return t
+        if s == 1:
+            return M
+        if s == -1:
+            return -M
+        return s * M
+
+    # Try every pair in the sM_list, trying to turn it into a gemm operation
+    for i in range(len(lst) - 1):
+        s_i, M_i = lst[i]
+
+        for j in range(i + 1, len(lst)):
+            s_j, M_j = lst[j]
+
+            if not M_j.type.in_same_class(M_i.type):
+                continue
+
+            # print 'TRYING', (s_i, M_i, s_j, M_j)
+
+            gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M(
+                fgraph, s_i, M_i, s_j, M_j
+            )
+            # print 'GOT IT', gemm_of_sM_list
+            if gemm_of_sM_list:
+                assert len(gemm_of_sM_list) == 1
+                add_inputs = [
+                    item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
+                ]
+                add_inputs.extend(gemm_of_sM_list)
+                if len(add_inputs) > 1:
+                    rval = [add(*add_inputs)]
+                else:
+                    rval = add_inputs
+                # print "RETURNING GEMM THING", rval
+                return rval, old_dot22
+
+
+def _gemm_from_node2(fgraph, node):
+    """
+
+    TODO: In many expressions, there are many ways to turn it into a
+    gemm. For example dot(a,b) + c + d. This function should return all
+    of them, so that if one version of gemm causes a cycle in the graph, then
+    another application of gemm can be tried.
+
+    """
+    lst = []
+    t0 = time.perf_counter()
+    _gemm_canonicalize(fgraph, node.outputs[0], 1.0, lst, 0)
+    t1 = time.perf_counter()
+
+    if len(lst) > 1:
+        lst = _factor_canonicalized(lst)
+        t2 = time.perf_counter()
+        rval = _gemm_from_factored_list(fgraph, lst)
+        t3 = time.perf_counter()
+
+        # It can happen that _factor_canonicalized and
+        # _gemm_from_factored_list return a node with an incorrect
+        # type. This happens in particular when one of the scalar
+        # factors forces the upcast of the whole expression. In that
+        # case, we simply skip that candidate for Gemm. This was
+        # discussed in
+        # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
+        # but never made it into a trac ticket.
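+        # (e.g. a float64 scalar factor upcasts the whole sum: the candidate
+        # Gemm output would be float64 while node.outputs[0] is float32, and
+        # the in_same_class check below rejects it.)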
+
+        if rval and rval[0][0].type.in_same_class(node.outputs[0].type):
+            return rval, t1 - t0, t2 - t1, t3 - t2
+
+    return None, t1 - t0, 0, 0
+
+
+class GemmOptimizer(GraphRewriter):
+    """Graph optimizer for inserting Gemm operations."""
+
+    def __init__(self):
+        super().__init__()
+        self.warned = False
+
+    def add_requirements(self, fgraph):
+        fgraph.attach_feature(ReplaceValidate())
+
+    def apply(self, fgraph):
+        did_something = True
+        nb_iter = 0
+        nb_replacement = 0
+        nb_replacement_didn_t_remove = 0
+        nb_inconsistency_make = 0
+        nb_inconsistency_replace = 0
+        time_canonicalize = 0
+        time_factor_can = 0
+        time_factor_list = 0
+        time_toposort = 0
+        if fgraph.profile:
+            validate_before = fgraph.profile.validate_time
+            callbacks_before = fgraph.execute_callbacks_times.copy()
+            callback_before = fgraph.execute_callbacks_time
+
+        def on_import(new_node):
+            if new_node is not node:
+                nodelist.append(new_node)
+
+        u = pytensor.graph.rewriting.basic.DispatchingFeature(
+            on_import, None, None, name="GemmOptimizer"
+        )
+        fgraph.attach_feature(u)
+        while did_something:
+            nb_iter += 1
+            t0 = time.perf_counter()
+            nodelist = pytensor.graph.basic.io_toposort(fgraph.inputs, fgraph.outputs)
+            time_toposort += time.perf_counter() - t0
+            did_something = False
+            nodelist.reverse()
+            for node in nodelist:
+                if not (
+                    isinstance(node.op, Elemwise)
+                    and isinstance(
+                        node.op.scalar_op,
+                        (
+                            pytensor.scalar.Add,
+                            pytensor.scalar.Sub,
+                            pytensor.scalar.Neg,
+                            pytensor.scalar.Mul,
+                        ),
+                    )
+                ):
+                    continue
+                if node not in fgraph.apply_nodes:
+                    # This means that we already removed this node from
+                    # the graph
+                    continue
+                try:
+                    new_outputs, time1, time2, time3 = _gemm_from_node2(fgraph, node)
+                    time_canonicalize += time1
+                    time_factor_can += time2
+                    time_factor_list += time3
+                except InconsistencyError:
+                    nb_inconsistency_make += 1
+                    continue
+                if new_outputs:
+                    new_outputs, old_dot22 = new_outputs
+                    assert len(new_outputs) == len(node.outputs)
+                    new_outputs[
+                        0
+                    ].tag.values_eq_approx = values_eq_approx_remove_inf_nan
+                    try:
+                        fgraph.replace_all_validate_remove(
+                            list(zip(node.outputs, new_outputs)),
+                            [old_dot22],
+                            reason="GemmOptimizer",
+                            # For now we disable the warning as we know of cases
+                            # that we need to fix.
+ warn=False, # warn=not self.warned + ) + did_something = True + nb_replacement += 1 + except InconsistencyError: + # TODO: retry other applications of gemm (see comment + # in _gemm_from_node) + nb_inconsistency_replace += 1 + except ReplacementDidNotRemoveError: + nb_replacement_didn_t_remove += 1 + self.warned = True + fgraph.remove_feature(u) + if fgraph.profile: + validate_time = fgraph.profile.validate_time - validate_before + callback_time = fgraph.execute_callbacks_time - callback_before + callbacks_time = {} + for k, v in fgraph.execute_callbacks_times.items(): + if k in callbacks_before: + callbacks_time[k] = v - callbacks_before[k] + else: + callbacks_time[k] = v + else: + validate_time = None + callback_time = None + callbacks_time = {} + + return ( + self, + nb_iter, + nb_replacement, + nb_replacement_didn_t_remove, + nb_inconsistency_make, + nb_inconsistency_replace, + time_canonicalize, + time_factor_can, + time_factor_list, + time_toposort, + validate_time, + callback_time, + callbacks_time, + ) + + @classmethod + def print_profile(cls, stream, prof, level=0): + blanc = " " * level + print(blanc, cls.__name__, file=stream) + print(blanc, " nb_iter", prof[1], file=stream) + print(blanc, " nb_replacement", prof[2], file=stream) + print(blanc, " nb_replacement_didn_t_remove", prof[3], file=stream) + print(blanc, " nb_inconsistency_make", prof[4], file=stream) + print(blanc, " nb_inconsistency_replace", prof[5], file=stream) + print(blanc, " time_canonicalize", prof[6], file=stream) + print(blanc, " time_factor_can", prof[7], file=stream) + print(blanc, " time_factor_list", prof[8], file=stream) + print(blanc, " time_toposort", prof[9], file=stream) + print(blanc, " validate_time", prof[10], file=stream) + print(blanc, " callback_time", prof[11], file=stream) + if prof[11] > 1: + print(blanc, " callbacks_time", file=stream) + for i in sorted(prof[12].items(), key=lambda a: a[1]): + if i[1] > 0: + print(i) + + +@node_rewriter([Dot]) +def local_dot_to_dot22(fgraph, node): + # This works for tensor.outer too because basic.outer is a macro that + # produces a dot(dimshuffle,dimshuffle) of form 4 below + if not isinstance(node.op, Dot): + return + + if any(not isinstance(i.type, DenseTensorType) for i in node.inputs): + return False + + x, y = node.inputs + if y.type.dtype != x.type.dtype: + # TODO: upcast one so the types match + _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}") + return + + if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"): + if x.ndim == 2 and y.ndim == 2: + new_out = [_dot22(*node.inputs)] + elif x.ndim == 2 and y.ndim == 1: + new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)] + elif x.ndim == 1 and y.ndim == 2: + new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)] + elif x.ndim == 1 and y.ndim == 1: + new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()] + else: + return + copy_stack_trace(node.outputs, new_out) + return new_out + + _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}") + + +@node_rewriter([gemm_no_inplace], inplace=True) +def local_inplace_gemm(fgraph, node): + if node.op == gemm_no_inplace: + new_out = [gemm_inplace(*node.inputs)] + copy_stack_trace(node.outputs, new_out) + return new_out + + +@node_rewriter([gemv_no_inplace], inplace=True) +def local_inplace_gemv(fgraph, node): + if node.op == gemv_no_inplace: + new_out = [gemv_inplace(*node.inputs)] + copy_stack_trace(node.outputs, new_out) + return new_out + + +@node_rewriter([ger], 
inplace=True)
+def local_inplace_ger(fgraph, node):
+    if node.op == ger:
+        new_out = [ger_destructive(*node.inputs)]
+        copy_stack_trace(node.outputs, new_out)
+        return new_out
+
+
+@node_rewriter([gemm_no_inplace])
+def local_gemm_to_gemv(fgraph, node):
+    """GEMM acting on row or column matrices -> GEMV."""
+    if node.op == gemm_no_inplace:
+        z, a, x, y, b = node.inputs
+        if z.broadcastable == x.broadcastable == (True, False):
+            r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
+            new_out = [r.dimshuffle("x", 0)]
+        elif z.broadcastable == y.broadcastable == (False, True):
+            r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
+            new_out = [r.dimshuffle(0, "x")]
+        else:
+            return
+        copy_stack_trace(node.outputs, new_out)
+        return new_out
+
+
+@node_rewriter([gemm_no_inplace])
+def local_gemm_to_ger(fgraph, node):
+    """GEMM computing an outer-product -> GER."""
+    if node.op == gemm_no_inplace:
+        z, a, x, y, b = node.inputs
+        if x.broadcastable[1] and y.broadcastable[0]:
+            # x and y are both vectors so this might qualify for a GER
+            xv = x.dimshuffle(0)
+            yv = y.dimshuffle(1)
+            try:
+                bval = at.get_underlying_scalar_constant_value(b)
+            except NotScalarConstantError:
+                # b isn't a constant, GEMM is doing useful pre-scaling
+                return
+
+            if bval == 1:  # best case a natural GER
+                rval = ger(z, a, xv, yv)
+                new_out = [rval]
+            elif bval == 0:  # GER on zeros_like should be faster than GEMM
+                zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype)
+                rval = ger(zeros, a, xv, yv)
+                new_out = [rval]
+            else:
+                # if bval is another constant, then z is being usefully
+                # pre-scaled and GER isn't really the right tool for the job.
+                return
+            copy_stack_trace(node.outputs, new_out)
+            return new_out
+
+
+# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
+# working
+@node_rewriter([_dot22])
+def local_dot22_to_ger_or_gemv(fgraph, node):
+    """dot22 computing an outer-product -> GER."""
+    if node.op == _dot22:
+        x, y = node.inputs
+        xb = x.broadcastable
+        yb = y.broadcastable
+        one = at.as_tensor_variable(np.asarray(1, dtype=x.dtype))
+        zero = at.as_tensor_variable(np.asarray(0, dtype=x.dtype))
+        if xb[1] and yb[0]:
+            # x and y are both vectors so this might qualify for a GER
+            xv = x.dimshuffle(0)
+            yv = y.dimshuffle(1)
+            zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
+            rval = ger(zeros, one, xv, yv)
+            new_out = [rval]
+        elif xb[0] and yb[1]:
+            # x and y are both vectors so this qualifies for a sdot / ddot
+            # TODO: PyTensor doesn't have a sdot, but gemv is better than _dot22
+            xv = x.dimshuffle(1)
+            zeros = at.AllocEmpty(x.dtype)(1)
+            rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
+            new_out = [rval.dimshuffle("x", 0)]
+        elif xb[0] and not yb[0] and not yb[1]:
+            # x is vector, y is matrix so try gemv
+            xv = x.dimshuffle(1)
+            zeros = at.AllocEmpty(x.dtype)(y.shape[1])
+            rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
+            new_out = [rval.dimshuffle("x", 0)]
+        elif not xb[0] and not xb[1] and yb[1]:
+            # x is matrix, y is vector, try gemv
+            yv = y.dimshuffle(0)
+            zeros = at.AllocEmpty(x.dtype)(x.shape[0])
+            rval = gemv_no_inplace(zeros, one, x, yv, zero)
+            new_out = [rval.dimshuffle(0, "x")]
+        else:
+            return
+        copy_stack_trace(node.outputs, new_out)
+        return new_out
+
+
+#################################
+#
+# Set up the BlasOpt optimizer
+#
+#################################
+
+blas_optdb = SequenceDB()
+
+# run after numerical stability optimizations (1.5)
+optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7)
+# run before specialize (2.0) because specialize is basically a
+# free-for-all that makes the graph crazy.
+
+# fast_compile is needed to have GpuDot22 created.
+blas_optdb.register(
+    "local_dot_to_dot22",
+    in2out(local_dot_to_dot22),
+    "fast_run",
+    "fast_compile",
+    position=0,
+)
+blas_optdb.register("gemm_optimizer", GemmOptimizer(), "fast_run", position=10)
+blas_optdb.register(
+    "local_gemm_to_gemv",
+    EquilibriumGraphRewriter(
+        [
+            local_gemm_to_gemv,
+            local_gemm_to_ger,
+            local_dot22_to_ger_or_gemv,
+            local_dimshuffle_lift,
+        ],
+        max_use_ratio=5,
+        ignore_newtrees=False,
+    ),
+    "fast_run",
+    position=15,
+)
+
+
+# After destroyhandler(49.5) but before we try to make elemwise things
+# inplace (75)
+blas_opt_inplace = in2out(
+    local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
+)
+optdb.register(
+    "InplaceBlasOpt",
+    blas_opt_inplace,
+    "fast_run",
+    "inplace",
+    "blas_opt_inplace",
+    position=70.0,
+)
+
+
+@node_rewriter([mul])
+def local_dot22_to_dot22scalar(fgraph, node):
+    """
+    Notes
+    -----
+    Previous attempts to alter this optimization to replace dot22 with
+    gemm instead of dot22scalar resulted in some Scan nodes being
+    duplicated and the ScanSaveMem optimization never running on them,
+    resulting in highly increased memory usage. Until this issue is
+    resolved, this optimization should keep using dot22scalar instead of
+    gemm.
+
+    We upcast the scalar if, after the multiplication with the dot, this
+    gives the same type.
+
+    We execute this optimizer after the gemm optimizer. This gives more
+    priority to gemm, which provides a bigger speed-up than this optimizer,
+    while still allowing the gemm optimizer to ignore this op.
+
+    TODO: support the case where we can reorder the mul to generate a
+    dot22scalar, or fix the canonicalizer to merge them (one mul with
+    multiple inputs).
+
+    """
+    if node.op != mul:
+        return False
+    i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs]
+    if not any(i_dot22):
+        return False  # no dot22
+    if i_dot22.count(True) > 1:
+        # TODO: try each of them.
+        pass
+        # return False #TODO fix
+    dot22_idx = i_dot22.index(True)
+    d = node.inputs[dot22_idx]
+    i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs]
+    if not any(i_scalar):
+        # Check if we can reorder the graph, as this mul has a mul in its inputs.
+        # We support only 1 additional level of mul.
+        # The canonicalizer should have merged those muls together.
+        i_mul = [
+            x.owner
+            and x.owner.op == mul
+            and any(_as_scalar(x_i, dtype=d.dtype) for x_i in x.owner.inputs)
+            for x in node.inputs
+        ]
+        if not any(i_mul):
+            # no scalar in input and no multiplication
+            # if there was a multiplication we could reorder the graph
+            # by the associativity of multiplication.
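+            # (e.g. mul(dot22(x, y), mul(a, w)) can be reordered into
+            # mul(dot22scalar(x, y, a), w) by pulling the scalar factor
+            # out of the inner mul.)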
+            return False
+
+        mul_idx = i_mul.index(True)  # The first one should always work
+        m = node.inputs[mul_idx]
+
+        scalar_idx = -1
+        for i, x in enumerate(m.owner.inputs):
+            if _as_scalar(x, dtype=d.dtype) and (
+                pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype
+            ):
+                scalar_idx = i
+                break
+
+        if scalar_idx < 0:
+            _logger.info(
+                f"Not optimizing dot22 with inputs {node.inputs} {[x.type for x in node.inputs]}, as the"
+                " type of the scalar cannot be upcasted to the"
+                " matrix type"
+            )
+            return False
+        a = at.cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype)
+        assert not a.type.ndim
+        dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
+
+        # The other inputs to the original node that were
+        # neither part of the dot22 nor this mul should be
+        # factors in the returned "mul" node.
+        assert dot22_idx != mul_idx
+        other_factors = [
+            inpt for i, inpt in enumerate(node.inputs) if i not in (dot22_idx, mul_idx)
+        ]
+        other_m_inputs = [
+            inpt for i, inpt in enumerate(m.owner.inputs) if i != scalar_idx
+        ]
+
+        return [mul(dot, *(other_factors + other_m_inputs))]
+
+    scalar_idx = -1
+    for i, x in enumerate(node.inputs):
+        if (
+            i != dot22_idx
+            and i_scalar[i] is not None
+            and (pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype)
+        ):
+            scalar_idx = i
+            break
+    if scalar_idx < 0:
+        _logger.info(
+            f"Not optimizing dot22 with inputs {node.inputs} {[x.type for x in node.inputs]}, as the type "
+            "of the scalar cannot be upcasted to the matrix type"
+        )
+        return False
+    assert scalar_idx < len(node.inputs)
+    s = node.inputs[scalar_idx]
+    o = copy.copy(node.inputs)
+    o.remove(d)
+    o.remove(s)
+
+    a = at.cast(i_scalar[scalar_idx], d.type.dtype)
+    assert not a.type.ndim
+    if len(o) == 0:
+        return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)]
+    else:
+        return [mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a), *o)]
+
+
+# must happen after gemm, as the gemm optimizer doesn't understand
+# dot22scalar and gemm gives a bigger speed-up than dot22scalar
+blas_optdb.register(
+    "local_dot22_to_dot22scalar",
+    in2out(local_dot22_to_dot22scalar),
+    "fast_run",
+    position=11,
+)
+
+
+# from opt import register_specialize, register_canonicalize
+# @register_specialize
+@node_rewriter([sub, add])
+def local_print_as_we_go_along(fgraph, node):
+    if node.op in (sub, add):
+        debugprint(node)
diff --git a/pytensor/tensor/rewriting/blas_c.py b/pytensor/tensor/rewriting/blas_c.py
new file mode 100644
index 0000000000..77629dccca
--- /dev/null
+++ b/pytensor/tensor/rewriting/blas_c.py
@@ -0,0 +1,70 @@
+from pytensor.configdefaults import config
+from pytensor.graph.rewriting.basic import in2out
+from pytensor.tensor import basic as at
+from pytensor.tensor.blas import gemv_inplace, gemv_no_inplace, ger, ger_destructive
+from pytensor.tensor.blas_c import (
+    CGemv,
+    CGer,
+    cgemv_inplace,
+    cgemv_no_inplace,
+    cger_inplace,
+)
+from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb
+
+
+@node_rewriter([ger, ger_destructive])
+def use_c_ger(fgraph, node):
+    if not config.blas__ldflags:
+        return
+    # Only float32 and float64 are supported for now.
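+    # Other dtypes fall through unchanged here; such nodes can still be
+    # picked up by the scipy-based Ger rewrite, which is registered later
+    # (position 100) in blas_optdb.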
+ if node.op == ger and node.outputs[0].dtype in ("float32", "float64"): + return [CGer(False)(*node.inputs)] + if node.op == ger_destructive and node.outputs[0].dtype in ("float32", "float64"): + return [CGer(True)(*node.inputs)] + + +@node_rewriter([CGer(False)]) +def make_c_ger_destructive(fgraph, node): + if isinstance(node.op, CGer) and not node.op.destructive: + return [cger_inplace(*node.inputs)] + + +@node_rewriter([gemv_inplace, gemv_no_inplace]) +def use_c_gemv(fgraph, node): + if not config.blas__ldflags: + return + # Only float32 and float64 are supported for now. + if node.op == gemv_no_inplace and node.outputs[0].dtype in ("float32", "float64"): + return [cgemv_no_inplace(*node.inputs)] + if node.op == gemv_inplace and node.outputs[0].dtype in ("float32", "float64"): + return [cgemv_inplace(*node.inputs)] + + +@node_rewriter([CGemv(inplace=False)]) +def make_c_gemv_destructive(fgraph, node): + if isinstance(node.op, CGemv) and not node.op.inplace: + inputs = list(node.inputs) + dest = inputs[0] + if ( + dest.owner + and isinstance(dest.owner.op, at.AllocEmpty) + and len(fgraph.clients[dest]) > 1 + ): + inputs[0] = at.AllocEmpty(dest.dtype)(*dest.owner.inputs) + + return [cgemv_inplace(*inputs)] + + +blas_optdb.register( + "use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20 +) + +# this matches the InplaceBlasOpt defined in blas.py +optdb.register( + "c_blas_destructive", + in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"), + "fast_run", + "inplace", + "c_blas", + position=70.0, +) diff --git a/pytensor/tensor/rewriting/blas_scipy.py b/pytensor/tensor/rewriting/blas_scipy.py new file mode 100644 index 0000000000..2b2aa94eef --- /dev/null +++ b/pytensor/tensor/rewriting/blas_scipy.py @@ -0,0 +1,37 @@ +from pytensor.graph.rewriting.basic import in2out +from pytensor.tensor.blas import ger, ger_destructive, have_fblas +from pytensor.tensor.blas_scipy import scipy_ger_inplace, scipy_ger_no_inplace +from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb + + +@node_rewriter([ger, ger_destructive]) +def use_scipy_ger(fgraph, node): + if node.op == ger: + return [scipy_ger_no_inplace(*node.inputs)] + + +@node_rewriter([scipy_ger_no_inplace]) +def make_ger_destructive(fgraph, node): + if node.op == scipy_ger_no_inplace: + return [scipy_ger_inplace(*node.inputs)] + + +use_scipy_blas = in2out(use_scipy_ger) +make_scipy_blas_destructive = in2out(make_ger_destructive) + +if have_fblas: + # scipy_blas is scheduled in the blas_optdb very late, because scipy sortof + # sucks, but it is almost always present. + # C implementations should be scheduled earlier than this, so that they take + # precedence. Once the original Ger is replaced, then these optimizations + # have no effect. 
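+    # (When the C implementations are usable, use_c_blas at position 20 has
+    # already replaced Ger, so this rewrite at position 100 finds nothing
+    # left to do.)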
+ blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100) + + # this matches the InplaceBlasOpt defined in blas.py + optdb.register( + "make_scipy_blas_destructive", + make_scipy_blas_destructive, + "fast_run", + "inplace", + position=70.0, + ) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index afc51a9e3c..6bf4b5b902 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -349,7 +349,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1): inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise) -compile.optdb.register( # type: ignore +compile.optdb.register( "inplace_elemwise_opt", inplace_elemwise_optimizer, "inplace_opt", # for historic reason @@ -1097,7 +1097,7 @@ def print_profile(stream, prof, level=0): "fusion", position=1, ) - compile.optdb.register( # type: ignore + compile.optdb.register( "elemwise_fusion", fuse_seqopt, "fast_run", @@ -1211,7 +1211,7 @@ def local_careduce_fusion(fgraph, node): return [new_car_op(*elm_inputs)] -compile.optdb.register( # type: ignore +compile.optdb.register( "local_careduce_fusion", in2out(local_careduce_fusion), "fusion", @@ -1321,7 +1321,7 @@ def split_2f1grad_loop(fgraph, node): return replacements -compile.optdb["py_only"].register( # type: ignore +compile.optdb["py_only"].register( "split_2f1grad_loop", split_2f1grad_loop, "fast_compile", diff --git a/pytensor/sandbox/linalg/ops.py b/pytensor/tensor/rewriting/linalg.py similarity index 83% rename from pytensor/sandbox/linalg/ops.py rename to pytensor/tensor/rewriting/linalg.py index 0a53924801..8f09e52261 100644 --- a/pytensor/sandbox/linalg/ops.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -109,6 +109,50 @@ def psd_solve_with_chol(fgraph, node): return [x] +@register_canonicalize +@register_stabilize +@node_rewriter([Cholesky]) +def cholesky_ldotlt(fgraph, node): + """ + rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular, + or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular. + + This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices. 
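+
+    For example (hypothetical user code): if ``L.tag.lower_triangular = True``
+    has been set on ``L``, then ``cholesky(L.dot(L.T))`` is rewritten to ``L``.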
+ """ + if not isinstance(node.op, Cholesky): + return + + A = node.inputs[0] + if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))): + return + + l, r = A.owner.inputs + + # cholesky(dot(L,L.T)) case + if ( + getattr(l.tag, "lower_triangular", False) + and r.owner + and isinstance(r.owner.op, DimShuffle) + and r.owner.op.new_order == (1, 0) + and r.owner.inputs[0] == l + ): + if node.op.lower: + return [l] + return [r] + + # cholesky(dot(U.T,U)) case + if ( + getattr(r.tag, "upper_triangular", False) + and l.owner + and isinstance(l.owner.op, DimShuffle) + and l.owner.op.new_order == (1, 0) + and l.owner.inputs[0] == r + ): + if node.op.lower: + return [l] + return [r] + + @register_stabilize @register_specialize @node_rewriter([Det]) diff --git a/tests/sandbox/__init__.py b/tests/sandbox/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/sandbox/linalg/__init__.py b/tests/sandbox/linalg/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/sandbox/test_minimal.py b/tests/sandbox/test_minimal.py deleted file mode 100644 index 82e346eaf3..0000000000 --- a/tests/sandbox/test_minimal.py +++ /dev/null @@ -1,32 +0,0 @@ -import numpy as np -import pytest - -from pytensor import function -from pytensor.sandbox.minimal import minimal -from pytensor.tensor.type import matrix, vector -from tests import unittest_tools as utt - - -@pytest.mark.skip(reason="Unfinished test") -class TestMinimal: - """ - TODO: test dtype conversion - TODO: test that invalid types are rejected by make_node - TODO: test that each valid type for A and b works correctly - """ - - def setup_method(self): - self.rng = np.random.default_rng(utt.fetch_seed(666)) - - def test_minimal(self): - A = matrix() - b = vector() - - print("building function") - f = function([A, b], minimal(A, A, b, b, A)) - print("built") - - Aval = self.rng.standard_normal((5, 5)) - bval = np.arange(5, dtype=float) - f(Aval, bval) - print("done") diff --git a/tests/sandbox/linalg/test_linalg.py b/tests/tensor/rewriting/test_linalg.py similarity index 59% rename from tests/sandbox/linalg/test_linalg.py rename to tests/tensor/rewriting/test_linalg.py index f2cb67221c..9ec182cb21 100644 --- a/tests/sandbox/linalg/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -1,14 +1,17 @@ import numpy as np import numpy.linalg +import pytest +import scipy.linalg import pytensor from pytensor import function from pytensor import tensor as at +from pytensor.compile import get_default_mode from pytensor.configdefaults import config -from pytensor.sandbox.linalg.ops import inv_as_solve, spectral_radius_bound from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import _allclose from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse +from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.slinalg import Cholesky, Solve, solve from pytensor.tensor.type import dmatrix, matrix, vector from tests import unittest_tools as utt @@ -65,53 +68,6 @@ def test_rop_lop(): assert _allclose(v1, v2), f"LOP mismatch: {v1} {v2}" -def test_spectral_radius_bound(): - tol = 10 ** (-6) - rng = np.random.default_rng(utt.fetch_seed()) - x = matrix() - radius_bound = spectral_radius_bound(x, 5) - f = pytensor.function([x], radius_bound) - - shp = (3, 4) - m = rng.random(shp) - m = np.cov(m).astype(config.floatX) - radius_bound_pytensor = f(m) - - # test the approximation - mm = m - for i in range(5): - mm = np.dot(mm, mm) - radius_bound_numpy = np.trace(mm) ** (2 
diff --git a/tests/sandbox/__init__.py b/tests/sandbox/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/sandbox/linalg/__init__.py b/tests/sandbox/linalg/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/sandbox/test_minimal.py b/tests/sandbox/test_minimal.py
deleted file mode 100644
index 82e346eaf3..0000000000
--- a/tests/sandbox/test_minimal.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import numpy as np
-import pytest
-
-from pytensor import function
-from pytensor.sandbox.minimal import minimal
-from pytensor.tensor.type import matrix, vector
-from tests import unittest_tools as utt
-
-
-@pytest.mark.skip(reason="Unfinished test")
-class TestMinimal:
-    """
-    TODO: test dtype conversion
-    TODO: test that invalid types are rejected by make_node
-    TODO: test that each valid type for A and b works correctly
-    """
-
-    def setup_method(self):
-        self.rng = np.random.default_rng(utt.fetch_seed(666))
-
-    def test_minimal(self):
-        A = matrix()
-        b = vector()
-
-        print("building function")
-        f = function([A, b], minimal(A, A, b, b, A))
-        print("built")
-
-        Aval = self.rng.standard_normal((5, 5))
-        bval = np.arange(5, dtype=float)
-        f(Aval, bval)
-        print("done")
diff --git a/tests/sandbox/linalg/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
similarity index 59%
rename from tests/sandbox/linalg/test_linalg.py
rename to tests/tensor/rewriting/test_linalg.py
index f2cb67221c..9ec182cb21 100644
--- a/tests/sandbox/linalg/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -1,14 +1,17 @@
 import numpy as np
 import numpy.linalg
+import pytest
+import scipy.linalg
 
 import pytensor
 from pytensor import function
 from pytensor import tensor as at
+from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
-from pytensor.sandbox.linalg.ops import inv_as_solve, spectral_radius_bound
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import _allclose
 from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse
+from pytensor.tensor.rewriting.linalg import inv_as_solve
 from pytensor.tensor.slinalg import Cholesky, Solve, solve
 from pytensor.tensor.type import dmatrix, matrix, vector
 from tests import unittest_tools as utt
@@ -65,53 +68,6 @@ def test_rop_lop():
     assert _allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
 
 
-def test_spectral_radius_bound():
-    tol = 10 ** (-6)
-    rng = np.random.default_rng(utt.fetch_seed())
-    x = matrix()
-    radius_bound = spectral_radius_bound(x, 5)
-    f = pytensor.function([x], radius_bound)
-
-    shp = (3, 4)
-    m = rng.random(shp)
-    m = np.cov(m).astype(config.floatX)
-    radius_bound_pytensor = f(m)
-
-    # test the approximation
-    mm = m
-    for i in range(5):
-        mm = np.dot(mm, mm)
-    radius_bound_numpy = np.trace(mm) ** (2 ** (-5))
-    assert abs(radius_bound_numpy - radius_bound_pytensor) < tol
-
-    # test the bound
-    eigen_val = numpy.linalg.eig(m)
-    assert (eigen_val[0].max() - radius_bound_pytensor) < tol
-
-    # test type errors
-    xx = vector()
-    ok = False
-    try:
-        spectral_radius_bound(xx, 5)
-    except TypeError:
-        ok = True
-    assert ok
-    ok = False
-    try:
-        spectral_radius_bound(x, 5.0)
-    except TypeError:
-        ok = True
-    assert ok
-
-    # test value error
-    ok = False
-    try:
-        spectral_radius_bound(x, -5)
-    except ValueError:
-        ok = True
-    assert ok
-
-
 def test_transinv_to_invtrans():
     X = matrix("X")
     Y = matrix_inverse(X)
@@ -152,3 +108,75 @@ def test_matrix_inverse_solve():
     node = matrix_inverse(A).dot(b).owner
     [out] = inv_as_solve.transform(None, node)
     assert isinstance(out.owner.op, Solve)
+
+
+@pytest.mark.parametrize("tag", ("lower", "upper", None))
+@pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
+@pytest.mark.parametrize("product", ("lower", "upper", None))
+def test_cholesky_ldotlt(tag, cholesky_form, product):
+    cholesky = Cholesky(lower=(cholesky_form == "lower"))
+
+    transform_removes_chol = tag is not None and product == tag
+    transform_transposes = transform_removes_chol and cholesky_form != tag
+
+    A = matrix("L")
+    if tag:
+        setattr(A.tag, tag + "_triangular", True)
+
+    if product == "lower":
+        M = A.dot(A.T)
+    elif product == "upper":
+        M = A.T.dot(A)
+    else:
+        M = A
+
+    C = cholesky(M)
+    f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
+
+    no_cholesky_in_graph = not any(
+        isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
+    )
+
+    assert no_cholesky_in_graph == transform_removes_chol
+
+    if transform_transposes:
+        assert any(
+            isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)
+            for node in f.maker.fgraph.apply_nodes
+        )
+
+    # Test some concrete values through f
+    # these must be lower triangular (f assumes they are)
+    Avs = [
+        np.eye(1, dtype=pytensor.config.floatX),
+        np.eye(10, dtype=pytensor.config.floatX),
+        np.array([[2, 0], [1, 4]], dtype=pytensor.config.floatX),
+    ]
+    if not tag:
+        # these must be positive definite
+        Avs.extend(
+            [
+                np.ones((4, 4), dtype=pytensor.config.floatX)
+                + np.eye(4, dtype=pytensor.config.floatX),
+            ]
+        )
+
+    for Av in Avs:
+        if tag == "upper":
+            Av = Av.T
+
+        if product == "lower":
+            Mv = Av.dot(Av.T)
+        elif product == "upper":
+            Mv = Av.T.dot(Av)
+        else:
+            Mv = Av
+
+        assert np.all(
+            np.isclose(
+                scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")),
+                f(Av),
+            )
+        )
diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py
index 0ce7640d38..035f9e036b 100644
--- a/tests/tensor/test_blas.py
+++ b/tests/tensor/test_blas.py
@@ -44,12 +44,11 @@
     gemv_no_inplace,
     ger,
     ger_destructive,
-    local_dot22_to_dot22scalar,
-    local_gemm_to_ger,
     res_is_a,
 )
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot, dot, mean, mul, neg, outer, sigmoid, sqrt
+from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger
 from pytensor.tensor.type import (
     cmatrix,
     col,