
Ops and optimizations for using BLAS calls

BLAS = Basic Linear Algebra Subprograms
Learn more about BLAS here:
    http://www.netlib.org/blas/blast-forum/
The standard BLAS libraries implement what is called "legacy BLAS" in that
document.

This documentation describes Theano's BLAS optimization pipeline.

Where there is a discrepancy between how things do work and how they *should*
work, both aspects should be documented.

There are four kinds of BLAS Ops in Theano:
    - Python implementations (this file)
    - SciPy-based (blas_scipy)
    - C-based (blas_c)
    - CUDA-based (theano.sandbox.cuda.blas)

Notes
-----
Unfortunately (and confusingly), this file currently contains Ops
that have both Python and C implementations.  I think it would be better to
move the C implementations to blas_c so that this file is pure Python.
-JB


Ops
===

GEMM: Dot22, Dot22Scalar, GemmRelated, Gemm
-------------------------------------------

The BLAS GEMM operation implements Z <- a X Y + b Z,
where Z, X and Y are matrices, and a and b are scalars.

Dot22 is a GEMM where a=1, b=0, and Z is allocated every time.

Dot22Scalar is a GEMM where b=0 and Z is allocated every time.

Gemm is a GEMM in all its generality.
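
As a rough reference for their semantics, here is a NumPy sketch (not the
actual implementations, which go through BLAS)::

    import numpy

    def gemm_ref(Z, a, X, Y, b):
        # Gemm: b * Z + a * dot(X, Y); the inplace variant overwrites Z.
        return b * Z + a * numpy.dot(X, Y)

    def dot22_ref(X, Y):
        # Dot22: a = 1, b = 0, output freshly allocated.
        return numpy.dot(X, Y)

    def dot22scalar_ref(X, Y, a):
        # Dot22Scalar: b = 0, output freshly allocated.
        return a * numpy.dot(X, Y)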

In the future we can refactor the GemmRelated, Gemm, Dot22 and
Dot22Scalar Ops into a single Op.  That new Op (Gemm2) is basically a
normal Gemm, but with an additional configuration variable that says
to ignore the input Z.  Setting that configuration variable to True
would make Gemm2 equivalent to the current Dot22 and Dot22Scalar.
This would make the file a lot easier to read, and save a few hundred
lines of library code, to say nothing of testing and documentation.


GEMV: Gemv
----------

The BLAS GEMV operation implements Z <- a X Y + b Z,
where X is a matrix, Y and Z are vectors, and a and b are scalars.
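
In NumPy terms (a sketch of the semantics only)::

    import numpy

    def gemv_ref(y, alpha, A, x, beta):
        # Gemv: beta * y + alpha * dot(A, x); may be computed in place on y.
        return beta * y + alpha * numpy.dot(A, x)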


GER: Ger
--------

The BLAS GER operation implements Z <- a X Y' + Z,
where X and Y are vectors, and matrix Z gets a rank-1 update.
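
In NumPy terms (a sketch of the semantics only)::

    import numpy

    def ger_ref(A, alpha, x, y):
        # Ger: A + alpha * outer(x, y); the destructive variant updates A in place.
        return A + alpha * numpy.outer(x, y)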


Other Notable BLAS-related Ops
------------------------------

SYRK is another useful special case of GEMM. In particular, SYRK preserves
symmetry in the matrix that it updates.  Before implementing this Op, see how
the linear-algebra module uses symmetry hints, so that the new Op stays
compatible with that system.


Optimizations
=============

The optimization pipeline works something like this:

    1. identify dot22 from dot
    2. identify gemm from dot22
    3. identify dot22scalar from dot22 that are not gemm
    4. specialize gemm to gemv where applicable
    5. specialize gemm to ger where applicable
    6. specialize dot22 -> gemv or ger where applicable

:note: GEMM is the most canonical BLAS signature that we deal with so far; it
    would be good to turn most things (dot, inner, outer, dot22, dot22scalar)
    into GEMM, and then to specialize from gemm to the various other L2 and
    L3 operations.
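
For instance, under the default fast_run optimizations, an expression of the
form b*z + a*dot(x, y) on matrices is typically compiled down to a single
Gemm node (a sketch; the variable names are illustrative)::

    import theano
    import theano.tensor as T

    z, x, y = T.matrix('z'), T.matrix('x'), T.matrix('y')
    f = theano.function([z, x, y], 0.1 * z + 0.5 * T.dot(x, y))
    # The compiled graph usually shows Gemm{inplace} or Gemm{no_inplace}.
    theano.printing.debugprint(f)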

Identify Dot22
--------------

NumPy's dot supports arguments of any rank, and we should support that
too (just for compatibility).  The BLAS optimizations work with Dot Ops whose
inputs are each either a vector or a matrix.  So the first part of the
optimization pipeline is to transform qualifying Dot Ops into Dot22 Ops.
Dot22 Ops may be transformed further, but they will ultimately be implemented
by a BLAS call.

More precisely, Dot nodes whose inputs are all vectors or matrices, whose
inputs share the same dtype, and whose dtype is float or complex become
Dot22.  This is implemented in `local_dot_to_dot22`.
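
A rough sketch of that qualification test (illustrative only; the real check
lives in `local_dot_to_dot22`)::

    def qualifies_for_dot22(x, y):
        ok_rank = x.ndim in (1, 2) and y.ndim in (1, 2)
        ok_dtype = (x.dtype == y.dtype and
                    x.dtype in ('float16', 'float32', 'float64',
                                'complex64', 'complex128'))
        return ok_rank and ok_dtype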


Identify Gemm from Dot22
------------------------

This is complicated, and is done in GemmOptimizer.

Identify Dot22Scalar from Dot22
-------------------------------

Dot22 Ops that remain after the GemmOptimizer has run did not qualify as
GEMM Ops. They might still be scaled by a factor, in which case we use
Dot22Scalar, which is like Gemm but without the b and the Z.  In the future
it would be good to merge this step into the GemmOptimizer.
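
A sketch of the kind of graph this targets (variable names are illustrative)::

    import theano
    import theano.tensor as T

    x, y = T.matrix('x'), T.matrix('y')
    f = theano.function([x, y], 0.5 * T.dot(x, y))
    # The compiled graph typically shows a Dot22Scalar node (printed as
    # _dot22scalar) rather than separate Dot22 and Elemwise{mul} nodes.
    theano.printing.debugprint(f)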

Specialize Gemm to Gemv
-----------------------

If arguments to GEMM are dimshuffled vectors, then we can use GEMV
instead. This optimization is `local_gemm_to_gemv`.
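
A NumPy sketch of why this is valid: a gemm whose x (or y) argument is a
column (or row) matrix computes the same values as the corresponding gemv::

    import numpy

    A = numpy.random.rand(3, 4)
    x = numpy.random.rand(4)
    y = numpy.random.rand(3)

    as_gemm = 0.1 * y[:, None] + 2.0 * numpy.dot(A, x[:, None])
    as_gemv = 0.1 * y + 2.0 * numpy.dot(A, x)
    assert numpy.allclose(as_gemm[:, 0], as_gemv)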

i    (   t   absolute_importt   print_functiont   divisionN(   t	   iteritems(   t   reducet   xrange(   t   config(   t   utilst   Opt
   view_rootst   local_optimizert	   Optimizert   InconsistencyErrort   toolboxt
   SequenceDBt   EquilibriumOptimizert   Applyt   ReplacementDidntRemovedError(   t   pprintt   FunctionPrintert
   debugprint(   t   optdb(   t   basic(   t   blas_header_text(   t   blas_header_version(   t   in2outt   local_dimshuffle_lift(   t   values_eq_approx_remove_inf_nans   theano.tensor.blast   float32t   float64t	   complex64t
   complex128s   Failed to import scipy.linalg.blas, and Theano flag blas.ldflags is empty. Falling back on slower implementations for dot(matrix, vector), dot(vector, matrix) and dot(vector, vector) (%s)c       
   C` s   t  j d  k r t s! t t  _ n  t d  t j d  }  t j d  } t j d	  } t |  j	 } | d | j
 | d |  d t d t t j |   j   t  _ n  t  j S(
   Nt   NaNi   g      ?g        t   overwrite_yt   trans(   i   (   i   (   i   i   (   t   check_init_yt   _resultt   Nonet
   have_fblast   Falset   floatt   numpyt   onest   _blas_gemv_fnst   dtypet   Tt   Truet   isnant   any(   t   yt   xt   At   gemv(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR#      s    %t   Gemvc           B` sA   e  Z d  Z d Z d   Z d   Z d   Z d   Z d   Z RS(   s   
    expression is beta * y + alpha * A x

    A is a matrix
    x and y are vectors
    alpha and beta are scalars
    output is a vector that can be computed in place on y

    t   inplacec         C` s)   | |  _  | r% i d g d 6|  _ n  d  S(   Ni    (   R6   t   destroy_map(   t   selfR6   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   __init__   s    	c         C` s)   |  j  r d |  j j Sd |  j j Sd  S(   Ns   %s{inplace}s   %s{no_inplace}(   R6   t	   __class__t   __name__(   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   __str__   s    	c         C` s'  t  j |  } t  j |  } t  j |  } t  j |  } t  j |  } | j | j k so | j | j k r t d | j | j | j f   n  | j d k r t d | j   n  | j d k r t d | j   n  | j d k r t d | j   n  t |  | | | | | g | j   g  S(   Ns   Gemv requires matching dtypesi   s   gemv requires matrix for Ai   s   gemv requires vector for xs   gemv requires vector for y(   R-   t   as_tensor_variableR,   t	   TypeErrort   ndimt   typeR   (   R8   R1   t   alphaR3   R2   t   beta(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt	   make_node   s    $c      
   C` s  | \ } } } } } t  r| j d d k r| j d d k r| j t k rt | j }	 | j d | j d k s | j d | j d k r t d | j | j | j f   n  | d k r t   r | j d  n  |	 | | j | | | d |  j d t	 | d d <n t
 j | |  }
 | d k r;|
 | 9}
 n  | d k rq| d k rd|
 | | 7}
 qq|
 | 7}
 n  t
 j |
 d | j | d d <d  S(   Ni    i   sQ   Incompatible shapes for gemv (beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s R!   R"   R,   (   R&   t   shapeR,   R+   t
   ValueErrorR#   t   fillR-   R6   R.   R)   t   dott   asarray(   R8   t   nodet   inputst   out_storageR1   RA   R3   R2   RB   R4   t   out(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   perform   s(    ,4
c         C` s   | d g S(   Ni    (    (   R8   RI   t   input_shapes(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   infer_shape  s    (   s   inplace(	   R;   t
   __module__t   __doc__t	   __props__R9   R<   RC   RM   RO   (    (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR5      s   					$R6   t   Gerc           B` sA   e  Z d  Z d Z d   Z d   Z d   Z d   Z d   Z RS(   s   
    BLAS defines general rank-1 update GER as A <- A + alpha x y'

    for matrix A, scalar alpha, vectors x and y.

    This interface to GER allows non-destructive operation on A via the
    `destructive` argument to the constructor.

    t   destructivec         C` s)   | |  _  | r% i d g d 6|  _ n  d  S(   Ni    (   RT   R7   (   R8   RT   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR9   0  s    	c         C` s)   |  j  r d |  j j Sd |  j j Sd  S(   Ns   %s{destructive}s   %s{non-destructive}(   RT   R:   R;   (   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR<   5  s    	c         C` so  t  j |  } t  j |  } t  j |  } t  j |  } t t | j | j | j | j g   d k r t d | j | j | j | j f   n  | j d k r t d | j   n  | j d k r t d | j   n  | j d k rt d | j   n  | j d k r&t d | j   n  | j d k rJt d | j   n  t |  | | | | g | j   g  S(   Ni   s   ger requires matching dtypesi    s   ger requires scalar alphai   s   ger requires matrix for As   ger requires vector for xs   ger requires vector for yR   R   R   R   s&   only float and complex types supported(   R   R   R   R   (	   R-   R=   t   lent   setR,   R>   R?   R@   R   (   R8   R3   RA   R2   R1   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRC   ;  s$    0$c   
      C` s   | \ } } } } | \ } |  j  r- | }	 n | j   }	 | d k rb |	 | t j | |  7}	 n |	 t j | |  7}	 |	 | d <d  S(   Ni   i    (   RT   t   copyR)   t   outer(
   R8   RI   t   inpRL   t   cAt   calphat   cxt   cyt   cZR3   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRM   P  s    			c         C` s   | d g S(   Ni    (    (   R8   RI   RN   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRO   ]  s    (   s   destructive(	   R;   RP   RQ   RR   R9   R<   RC   RM   RO   (    (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRS   #  s   					RT   c         C` s4   t  j j j } t d | d |  d | d | d |  S(   s  Extract a list of compilation flags from config.blas.ldflags.

    Depending on the options, different types of flags will be kept.
    It returns a list of libraries against which an Op's object file
    should be linked to benefit from a BLAS implementation.

    Parameters
    ----------
    libs : bool, optional
        Extract flags starting with "-l" (the default is True).
    libs_dir : bool, optional
        Extract flags starting with "-L" (the default is False).
    include_dir : bool, optional
        Extract flags starting with "-I" (the default is False).
    flags: bool, optional
        Extract all the other flags (the default is False).

    Returns
    -------
    list of strings
        Extracted flags.
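
    Examples
    --------
    Illustrative calls; the actual values depend on `theano.config.blas.ldflags`::

        ldflags()                        # e.g. ['blas']
        ldflags(libs=False, flags=True)  # e.g. ['-fopenmp']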

    t   ldflags_strt   libst   flagst   libs_dirt   include_dir(   t   theanoR   t   blast   ldflagst   _ldflags(   R`   Ra   Rb   Rc   R_   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRf   e  s    c         C` s  g  } | r&t  } g  |  j   D] } | j d  r | d ^ q } t d |  d t d t  d t  d t   }	 x | D] }
 x t j |
 j d   D]m } | j d	  s | j d
  s | j d  r t	 g  |	 D] } | j
 |  d k ^ q  r t } q q q Wqr W| r&| r&t j d  q&n  x|  j   D]u} | j d  rW| j d  su| j d  r| j d  r| d d !} n  y, | d d !\ } } } | d k st  Wn' t k
 rt d | |  f   n X| r| d k r| j | d  q3| r6| d k r6t d |   | j | d  q3| r\| d k r\| j | d  q3| r~| d k r~| j |  q3| r3| d k r3| j d | d  q3q3W| S(   sQ  Extract list of compilation flags from a string.

    Depending on the options, different types of flags will be kept.

    Parameters
    ----------
    ldflags_str : string
        The string to process. Typically, this will be the content of
        `theano.config.blas.ldflags`.
    libs : bool
        Extract flags starting with "-l".
    flags: bool
        Extract all the other flags.
    libs_dir: bool
        Extract flags starting with "-L".
    include_dir: bool
        Extract flags starting with "-I".

    Returns
    -------
    list of strings
        Extracted flags.
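
    Examples
    --------
    A sketch of the parsing behaviour on a hypothetical flag string::

        _ldflags('-lblas -lcblas', libs=True, flags=False,
                 libs_dir=False, include_dir=False)
        # -> ['blas', 'cblas']
        _ldflags('-lblas -fopenmp', libs=False, flags=True,
                 libs_dir=False, include_dir=False)
        # -> ['-fopenmp']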

    s   -Li   R_   R`   Ra   Rb   Rc   t   "s   .sos   .dylibs   .dlli    s   We did not found a dynamic library into the library_dir of the library we use for blas. If you use ATLAS, make sure to compile it with dynamics library.t   'i   ii   t   -s'   invalid token "%s" in ldflags_str: "%s"t   Lt   Isj   Include dirs are not used for blas. We disable this as this can hide other headers and this is not wanted.t   ls   -Wl,-rpath,(   Rk   Rl   Rm   (   R'   t   splitt
   startswithRg   R.   t   ost   listdirt   stript   endswithR0   t   findt   _loggert   warningt   AssertionErrort	   ExceptionRE   t   append(   R_   R`   Ra   Rb   Rc   t   rvalt	   found_dynR2   t   dirsRm   t   dt   ft   llt   tt   t0t   t1t   t2(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRg     sL    .<	t   GemmRelatedc           B` s   e  Z d  Z d Z d   Z d   Z d   Z d   Z d   Z d   Z	 d Z
 d Z d	 Z d
 Z d Z d Z d Z d Z d Z d Z d Z d Z d Z d Z d   Z d   Z RS(   sZ   Base class for Gemm and Dot22.

    This class provides a kind of templated gemm Op.

    c         C` s   d } t    | S(   Ns&  
        #ifndef MOD
        #define MOD %
        #endif
        static double time_time() // a time function like time.time()
        {
            struct timeval tv;
            gettimeofday(&tv, 0);
            return (double) tv.tv_sec + (double) tv.tv_usec / 1000000.0;
        }
        (   R   (   R8   t   mod_str(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   c_support_code  s    c         C` s   d d d g S(   Ns
   <iostream>s   <time.h>s   <sys/time.h>(    (   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt	   c_headers  s    c         C` s   t    S(   N(   Rf   (   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   c_libraries  s    c         C` s   t  d t d t  S(   NR`   Ra   (   Rf   R'   R.   (   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   c_compile_args  s    c         C` s   t  d t d t  S(   NR`   Rb   (   Rf   R'   R.   (   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt
   c_lib_dirs  s    c         C` s   t  d t d t  S(   NR`   Rc   (   Rf   R'   R.   (   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   c_header_dirs  s    s&  
        int unit = 0;

        int type_num = PyArray_DESCR(%(_x)s)->type_num;
        int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes

        npy_intp* Nx = PyArray_DIMS(%(_x)s);
        npy_intp* Ny = PyArray_DIMS(%(_y)s);
        npy_intp* Nz = 0; //PyArray_DIMS(%(_zout)s);

        npy_intp* Sx = PyArray_STRIDES(%(_x)s);
        npy_intp* Sy = PyArray_STRIDES(%(_y)s);
        npy_intp* Sz = 0; //PyArray_STRIDES(%(_zout)s);

        //strides for x, y, z in dimensions 0, 1
        int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
        s  
        if (PyArray_NDIM(%(_x)s) != 2) {
            PyErr_Format(PyExc_NotImplementedError,
                         "rank(x) != 2. rank(x) is %%d.",
                         PyArray_NDIM(%(_x)s));
            %(fail)s;
        }
        if (PyArray_NDIM(%(_y)s) != 2) {
            PyErr_Format(PyExc_NotImplementedError,
                         "rank(y) != 2. rank(y) is %%d.", PyArray_NDIM(%(_y)s));
            %(fail)s;
        }
        if (%(_zout)s && PyArray_NDIM(%(_zout)s) != 2) {
            PyErr_Format(PyExc_NotImplementedError,
                         "rank(z) != 2. rank(z) is %%d.", PyArray_NDIM(%(_zout)s));
            %(fail)s;
        }
        s  
        if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE)
            && (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}

        if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE)
            && (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}

        if ((PyArray_DESCR(%(_zout)s)->type_num != NPY_DOUBLE)
            && (PyArray_DESCR(%(_zout)s)->type_num != NPY_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}

        if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num)
            ||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_zout)s)->type_num))
        { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }
        s  
        if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE)
            && (PyArray_DESCR(%(_a)s)->type_num != NPY_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}

        if ((PyArray_DESCR(%(_b)s)->type_num != NPY_DOUBLE)
            && (PyArray_DESCR(%(_b)s)->type_num != NPY_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
        s  
        if (Nx[0] != Nz[0])
        {
            PyErr_Format(PyExc_ValueError,
                "Shape mismatch: x has %%ld rows but z has %%ld rows",
                (long int)Nx[0], (long int)Nz[0]);
            %(fail)s;
        }
        if (Nx[1] != Ny[0])
        {
            PyErr_Format(PyExc_ValueError,
                "Shape mismatch: x has %%ld cols (and %%ld rows) but y has %%ld rows (and %%ld cols)",
                (long int)Nx[1], (long int)Nx[0], (long int)Ny[0], (long int)Ny[1]);
            %(fail)s;
        }
        if (Ny[1] != Nz[1])
        {
            PyErr_Format(PyExc_ValueError,
                "Shape mismatch: y has %%ld cols but z has %%ld cols",
                (long int)Ny[1], (long int)Nz[1]);
            %(fail)s;
        }

        // We must not raise an error when Nx[1] == 0. This would disable cases
        // that numpy.dot accept.
        sx  
        /*
        If some matrices are not contiguous on either dimensions,
        or have invalid strides, copy their content into a contiguous one
        */
        if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
            || ((Sx[0] != type_size) && (Sx[1] != type_size)))
        {
            PyArrayObject * _x_copy = (PyArrayObject *) PyArray_Copy(%(_x)s);
            if (!_x_copy)
                %(fail)s
            Py_XDECREF(%(_x)s);
            %(_x)s = _x_copy;
            Sx = PyArray_STRIDES(%(_x)s);
        }

        if ((Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
            || ((Sy[0] != type_size) && (Sy[1] != type_size)))
        {
            PyArrayObject * _y_copy = (PyArrayObject *) PyArray_Copy(%(_y)s);
            if (!_y_copy)
                %(fail)s
            Py_XDECREF(%(_y)s);
            %(_y)s = _y_copy;
            Sy = PyArray_STRIDES(%(_y)s);
        }

        if ((Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size)
            || ((Sz[0] != type_size) && (Sz[1] != type_size)))
        {
            PyArrayObject * _z_copy = (PyArrayObject *) PyArray_Copy(%(_zout)s);
            if (!_z_copy)
                %(fail)s
            Py_XDECREF(%(_zout)s);
            %(_zout)s = _z_copy;
            Sz = PyArray_STRIDES(%(_zout)s);
        }
        s  
        /*
        encode the stride structure of _x,_y,_zout into a single integer
        */
        unit |= ((Sx[1] == type_size || Nx[1]==1) ? 0x0 : (Sx[0] == type_size || Nx[0]==1) ? 0x1 : 0x2) << 8;
        unit |= ((Sy[1] == type_size || Ny[1]==1) ? 0x0 : (Sy[0] == type_size || Ny[0]==1) ? 0x1 : 0x2) << 4;
        unit |= ((Sz[1] == type_size || Nz[1]==1) ? 0x0 : (Sz[0] == type_size || Nz[0]==1) ? 0x1 : 0x2) << 0;
        s  
        /* create appropriate strides for malformed matrices that are row or column
         * vectors, or empty matrices.
         * In that case, the value of the stride does not really matter, but
         * some versions of BLAS insist that:
         *  - they are not smaller than the number of elements in the array,
         *  - they are not 0.
         */
        sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : (Nx[1] + 1);
        sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[0] + 1);
        sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : (Ny[1] + 1);
        sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[0] + 1);
        sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : (Nz[1] + 1);
        sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[0] + 1);
        s-   
        switch (type_num)
        {
        s3   
            case NPY_FLOAT:
            {
        sS  
                float* x = (float*)PyArray_DATA(%(_x)s);
                float* y = (float*)PyArray_DATA(%(_y)s);
                float* z = (float*)PyArray_DATA(%(_zout)s);
                char N = 'N';
                char T = 'T';
                int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
                //std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\n';
                //double t0 = time_time();
                switch(unit)
                {
                    case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
                    case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
                    case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
                    case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
                    case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
                    case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
                    case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
                    case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
                    default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
                };
                //fprintf(stderr, "Calling sgemm %%i %%i %%i %%i took %%f\n", unit, Nz1, Nz0, Nx1, time_time() - t0);
        sU   
            }
            break;
            case NPY_DOUBLE:
            {
        s  
                double* x = (double*)PyArray_DATA(%(_x)s);
                double* y = (double*)PyArray_DATA(%(_y)s);
                double* z = (double*)PyArray_DATA(%(_zout)s);
                char N = 'N';
                char T = 'T';
                int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
                //std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\n';
                //double t0 = time_time();
                //fprintf(stderr, "unit=%%x N= %%i %%i %%i S = %%i %%i %%i %%i %%i %%i\n", unit,
                //Nz1, Nz0, Nx1,
                //sy_0, sy_1,
                //sx_0, sx_1,
                //sz_0, sz_1
                //);
                switch(unit)
                {
                    case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y,
                                       &sy_0, x, &sx_0, &b, z, &sz_0); break;
                    case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y,
                                       &sy_0, x, &sx_1, &b, z, &sz_0); break;
                    case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y,
                                       &sy_1, x, &sx_0, &b, z, &sz_0); break;
                    case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y,
                                       &sy_1, x, &sx_1, &b, z, &sz_0); break;
                    case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x,
                                       &sx_0, y, &sy_0, &b, z, &sz_1); break;
                    case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x,
                                       &sx_1, y, &sy_0, &b, z, &sz_1); break;
                    case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x,
                                       &sx_0, y, &sy_1, &b, z, &sz_1); break;
                    case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x,
                                       &sx_1, y, &sy_1, &b, z, &sz_1); break;
                    default: PyErr_SetString(PyExc_ValueError,
                                             "some matrix has no unit stride");
                             %(fail)s;
                };
                //fprintf(stderr, "Calling dgemm %%i %%i %%i %%i took %%f\n",
                //        unit, Nz1, Nz0, Nx1, time_time()- t0);
        s4   
            }
            break;
        }
        c         C` sy   t  t j |  j |  j |  j |  j |  j |  j |  j	 |  j
 |  j |  j |  j |  j |  j |  j |  j |  j |  j f d  S(   Nt    (   R   t   strt   __add__t
   declare_NSt   check_xyz_rank2t   setup_z_Nz_Szt   check_xyz_double_or_floatt   check_ab_double_or_floatt
   check_dimst   check_stridest   encode_strides_in_unitt   compute_stridest   begin_switch_typenumt
   case_floatt   case_float_ab_constantst   case_float_gemmt   case_doublet   case_double_ab_constantst   case_double_gemmt   end_switch_typenum(   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   build_gemm_call  s$    	c         C` s   d t    f S(   Ni   (   R   (   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   build_gemm_version
  s    (    (   R;   RP   RQ   RR   R   R   R   R   R   R   R   R   R   R   R   R   R   R   R   R   R   R   R   R   R   R   (    (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR     s0   						'	+	t   Gemmc           B` s   e  Z d  Z d Z d Z d Z d Z d Z d Z d   Z	 d   Z
 d	   Z d
   Z d   Z d   Z d   Z d Z d Z d Z d Z d   Z d   Z RS(   s   In-place version of matrix-matrix multiplication (with accumulation).

    When a and b are scalars and x, y, and z are matrices, then

        gemm(z,a,x,y,b)

    is similar to

        b*z + a*dot(x,y)

    The difference between the two is that the top form is destructive
    on z, whereas the bottom form is not.  Gemm works in-place on the
    storage associated with z, and the L{Variable} returned by Gemm
    has a storage that will be aliased to the storage of the z
    argument. Because of this in-place computation, an L{Apply} of
    this op will destroy the L{Variable} z on which it operates.  (See
    L{DestructiveOps} for an explanation of what destroying means in
    the context of theano graphs. See L{BlasLapackSupport} for more
    optimized linear algebra operations.)

    s   gemm only works for rank 2s   gemm requires scalar arguments   argument z aliased to x or ys   gemm requires matching dtypess#   gemm requires floating-point dtypesR6   c         C` sD   | |  _  |  j  r4 i d g d 6|  _ |  j |  _ n |  j |  _ d  S(   Ni    (   R6   R7   t   setup_z_Nz_Sz_inplaceR   t   setup_z_Nz_Sz_outplace(   R8   R6   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR9   -  s
    		c         C` s,   |  j  r d } n d } d |  j j | f S(   NR6   t
   no_inplaces   %s{%s}(   R6   R:   R;   (   R8   t   inplace_str(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR<   5  s    		c         C` sf   |  j  j |  |  j r( |  j |  _ n |  j |  _ d |  j  k rb |  j rb i d g d 6|  _ n  d  S(   NR7   i    (   t   __dict__t   updateR6   R   R   R   R7   (   R8   t   dct(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   __setstate__<  s    	c         C` s    |  j  j   } | j d  | S(   NR   (   R   RW   t   pop(   R8   Rz   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   __getstate__H  s    c         G` s  t  t t j |   } t |  d k rI t d |  t |  f   n  | \ } } } } } t | d t  r t j |  } n  g  | | | f D] } t	 t
 |   ^ q \ } }	 }
 | j |	  r t t j | | f   n  | j |
  rt t j | | f   n  | j d k r1t t j |   n  | j d k rUt t j |   n  | j d k ryt t j |   n  | j d k rt t j |   n  | j d k rt t j |   n  | j | j k o| j k o| j k o| j k n s5t t j | j | j | j | j | j f   n  | j j d  rs| j j d  rst t j | j   n  | j   } t |  | | g  S(   Ni   s2   Wrong number of inputs for %s (expected 5, got %s)t   cachedi   i    R(   t   complex(   t   listt   mapR-   R=   RU   R>   t   getattrR'   RW   RV   R	   t   intersectionR   R   t   E_z_uniqR?   t   E_rankt   E_scalarR,   t   E_mixedRo   t   E_floatR@   R   (   R8   RJ   t   zt   aR2   R1   t   bt   it   zrt   xrt   yrt   output(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRC   O  s>    7A	*c   
      C` s  | \ } } } } } | \ }	 | j  d k s3 t  | j  d k sH t  |  j s` | j   } n  | j  d k r | j | | | t j | |   | |	 d <n| d k r| d k r t j | |  | (q| d k r t j | |  | (q| t j | |  | (n | d k r| d k r=| t j | |  7} q| d k rb| t j | |  8} q| | t j | |  7} n$ | | 9} | | t j | |  7} | |	 d <d  S(   Ni    g        g      ?g      (    (    (    (   RD   Rw   R6   RW   t   itemsetR)   RG   (
   R8   RI   RY   RL   R   R   R2   R1   R   t   zout(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRM     s0    		%
c         C` s   | d g S(   Ni    (    (   R8   RI   RN   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRO     s    s&  
        if (%(_zout)s != %(_z)s)
        {
            if (%(_zout)s)
            {
                Py_DECREF(%(_zout)s);
            }
            %(_zout)s = %(_z)s;
            Py_INCREF(%(_zout)s);
        }
        Nz = PyArray_DIMS(%(_z)s);
        Sz = PyArray_STRIDES(%(_z)s);
        s
  
        if ((NULL == %(_zout)s)
            || (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_z)s)[0])
            || (PyArray_DIMS(%(_zout)s)[1] != PyArray_DIMS(%(_z)s)[1])
            || (PyArray_STRIDES(%(_zout)s)[0] <= 0)
            || (PyArray_STRIDES(%(_zout)s)[1] <= 0)
            || (PyArray_STRIDES(%(_zout)s)[0] MOD type_size)
            || (PyArray_STRIDES(%(_zout)s)[1] MOD type_size)
            || ((PyArray_STRIDES(%(_zout)s)[0] != type_size)
                && (PyArray_STRIDES(%(_zout)s)[1] != type_size)))
        {
            Py_XDECREF(%(_zout)s);
            npy_intp dims[2];
            dims[0] = PyArray_DIMS(%(_z)s)[0];
            dims[1] = PyArray_DIMS(%(_z)s)[1];
            %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
                                                          PyArray_TYPE(%(_z)s));
            //fprintf(stderr, "Gemm Allocating %%i %%i\n", dims[0], dims[1]);
            if(!%(_zout)s) {
                PyErr_SetString(PyExc_MemoryError,
                                "failed to alloc gemm_no_inplace output");
                %(fail)s
            }
        }
        Nz = PyArray_DIMS(%(_zout)s);
        Sz = PyArray_STRIDES(%(_zout)s);

        if (PyArray_DESCR(%(_zout)s)->type_num == NPY_FLOAT)
        {
            float * zoutdata = (float*)PyArray_DATA(%(_zout)s);
            int zoi = Sz[0] / sizeof(float);
            int zoj = Sz[1] / sizeof(float);
            const float * zdata = (float*)PyArray_DATA(%(_z)s);
            int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(float);
            int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(float);
            for (int i = 0; i < Nz[0]; ++i)
            {
                for (int j = 0; j < Nz[1]; ++j)
                {
                    zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
                }
            }
        }
        else if (PyArray_DESCR(%(_zout)s)->type_num == NPY_DOUBLE)
        {
            double * zoutdata = (double*) PyArray_DATA(%(_zout)s);
            int zoi = Sz[0] / sizeof(double);
            int zoj = Sz[1] / sizeof(double);
            const double * zdata = (double*)PyArray_DATA(%(_z)s);
            int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(double);
            int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(double);
            for (int i = 0; i < Nz[0]; ++i)
            {
                for (int j = 0; j < Nz[1]; ++j)
                {
                    zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
                }
            }
        }
        else
        {
            PyErr_SetString(PyExc_AssertionError,
                            "neither float nor double dtype");
            %(fail)s
        }
        s  
        #define REAL float
        float a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT)
        ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0])
        : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]);
        float b = (PyArray_DESCR(%(_b)s)->type_num == NPY_FLOAT) ?
        (REAL)(((float*)PyArray_DATA(%(_b)s))[0])
        : (REAL)(((double*)PyArray_DATA(%(_b)s))[0]);
        #undef REAL
        s  
        #define REAL double
        double a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT)
        ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0])
        : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]);
        double b = (PyArray_DESCR(%(_b)s)->type_num == NPY_FLOAT) ?
        (REAL)(((float*)PyArray_DATA(%(_b)s))[0])
        : (REAL)(((double*)PyArray_DATA(%(_b)s))[0]);
        #undef REAL
        c         C` sv   | \ } } } }	 }
 | \ } | j  d j j j d  rV t j d |  j j   n  |  j   t	 t
   |  } | S(   Ni    R   s	   %s.c_code(   RJ   R@   R,   Ro   R   t   MethodNotDefinedR:   R;   R   t   dictt   locals(   R8   RI   t   nameRY   RL   t   subt   _zt   _at   _xt   _yt   _bt   _zoutt	   full_code(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   c_code  s    		c         C` s"   |  j    } | r d | S| Sd  S(   Ni   (   i   (   R   (   R8   t   gv(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   c_code_cache_version  s    (   s   inplace(   R;   RP   RQ   R   R   R   R   R   RR   R9   R<   R   R   RC   RM   RO   R   R   R   R   R   R   (    (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR     s(   					2		C
		t   gemm_inplacet   gemm_no_inplacec         C` sI   | d  k	 r$ t |  j  | k } n t } |  j oH |  j j | k oH | S(   N(   R%   RU   t   clientsR.   t   ownert   op(   RI   R   t
   maxclientst   retval(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   res_is_a   s    	c         C` s   | d k r t j } n  t j |  j j  r x5 |  j rd t |  j j	 t
 j  rd |  j j d }  q0 W|  j j r |  j   } n |  } | j j t j j k r t j j |  j |  | k r t
 j | |  Sd Sn  | Sd S(   sE   Return None or a TensorVariable whose type is in T.float_scalar_typesi    N(   R%   R   t   floatXR)   t   allR@   t   broadcastableR   t
   isinstanceR   R-   t
   DimShuffleRJ   t
   dimshuffleR,   Rd   t   tensort   integer_dtypest   scalart   upcastt   cast(   t   resR,   Rz   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt
   _as_scalar+  s    $c         C` sN   |  j  j d k oM |  j  j d k oM |  j  j d t k oM |  j  j d t k S(   Nt   float16R   R   i   i    i   (   R   s   float32s   float64(   R@   R,   R?   R   R'   (   R   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   _is_real_matrixE  s    c         C` s8   |  j  j d k o7 |  j  j d k o7 |  j  j d t k S(   NR   R   R   i   i    (   s   float16s   float32s   float64(   R@   R,   R?   R   R'   (   R   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   _is_real_vectorL  s    c         C` s"  | j  rR | j  j t k rR | j  j \ } } t | | | | |   g } | | f S| j  rt | j  j t j  r| j  j d j  rt | j  j d j  j t  r| j  j d } | j  j j	 d k r| j  j \ }	 }
 t | j
 d d  | |	 |
 |   } | j
 d  g } | | f S| j  j j	 d k r| j  j \ }	 }
 t | j
 d d  | |	 |
 |   } | j
 d  g } | | f St | j  j j	  d k r| j  j \ }	 }
 t | j
 d d  | |	 |
 |   } | j
   g } | | f Sn  t rt | t d  r| j  j \ } } } } } t | t d  r| j  j \ } } t t | | | | | |   | | | | d  g } | S| | k rt | | | | | | | |   g } | Sd | k r|  | t | | | | | | |  g } | Sn  | rt | | |  | d t St t f Sd  S(   Ni    R2   i   g      ?t   recurse_flip(   i    (   i   (   R   R   t   _dot22RJ   R   R   R-   R   t   Dot22t	   new_orderR   RU   R'   R   t   _beta_L_plus_alpha_M(   RB   Rk   RA   t   MR   t   Mlt   MrRz   t   MMt   MMlt   MMrt   gt   GR   t   ut   vR   R2   R1   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR   R  sT    
!

'+c         ` sL    f d   } y |  j  j Wn t k
 r1 d  SX|  j  j d k sV |  j  j d k rm | j | |    | S| r t t |  d	 g    | k r | j   |  f  | S|  j	 r|  j	 j
 t j k rt |  j	 j d
   | d  t |  j	 j d   | d  nD|  j	 rO|  j	 j
 t j k rOx#|  j	 j D] } t |   | d  q/Wn|  j	 r|  j	 j
 t j k rt |  j	 j d
   | d  n|  j	 r5|  j	 j
 t j k r5g  } g  } g  } x |  j	 j D] } t j | j  j  rPx5 | j	 rt | j	 j
 t j  r| j	 j d
 } qW| j  j r@| j | j    q| j |  qt |  rl| j |  qt |  r| j |  q| j   |  f  | SqWt |  d k rat |  d
 k st  | d
 }	 t |  d
 k rt |	   | d  q2t |  d k r1t |	 | | d
  | d  q2t |	 t j | | d
  | d  | d  qHt |  d k rt |  d
 k st  | d
 }
 t |  d
 k rt |
   | d  q2t |  d k rt |
 | | d
  | d  q2t |
 t j | | d
  | d  | d  qH| j   |  f  n | j   |  f  | S(   Nc         ` s?     d k r |  S  d k r3 |  j  j d k r3 |  S  |  Sd  S(   Ni   it   bool(   R@   R,   (   t   thing(   t   scale(    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   scaled  s
    i   i   R   R   R   R   R   R   i    (   i   i   (   s   float16s   float32s   float64s	   complex64s
   complex128(   R@   R   Rx   R%   R?   R,   Ry   RU   R   R   R   R-   R   t   _gemm_canonicalizeRJ   t   addt   negt   mulR)   R   R   R   R   R   R   Rw   (   t   rR   Rz   R   R   R   t   scalarst   vectorst   matricest   mR   (    (   R   s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR     sv    		$!!$
 #
 c         C` s  t  |   }  d } x | t |   d k  r y |  | \ } } Wn t k
 r_ | d 7} q n X| d } x | t |   k  r y |  | \ } } Wn t k
 r | d 7} qm n X| | k r | | } | | f |  | <|  | =qm | d 7} qm W| d 7} q W|  S(   Ni    i   (   R   RU   Rx   (   t   lstR   t   s_it   M_it   jt   s_jt   M_j(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   _factor_canonicalized  s,    





c         C` s  g  } x |  D] } t  | t  r | \ } } t j |  } t j j | j | j  | j k r | j t j	 | | j  | d f  q q q W| }  d   } x4t
 t |   d  D]} |  | \ } } xt
 | d t |    D] }	 |  |	 \ }
 } | j | j k rq n  t | | |
 |  \ } } | r t |  d k sNt  g  t |   D]* \ } } | | |	 f k r[| |  ^ q[} | j |  t |  d k rt j |   g } n | } | | f Sq Wq Wd S(   s;   
    Returns None, or a list to replace node.outputs.

    i   c         S` sN   y |  \ } } Wn t  k
 r$ |  SX| d k r5 | S| d k rF | S| | S(   Ni   i(   Rx   (   R   t   sR   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   item_to_var+  s    N(   R   t   tupleR-   R=   Rd   R   R   R,   Ry   R   R   RU   R@   R   Rw   t	   enumeratet   extendR   (   R  t   lst2t   sMt   sm0t   sm1R  R   R  R  R  R	  R
  t   gemm_of_sM_listt	   old_dot22t   kt   inputt
   add_inputsRz   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   _gemm_from_factored_list  s4    $0	 	:c         C` s   g  } t  j    } t |  j d d | d  t  j    } t |  d k r t |  } t  j    } t |  } t  j    } | r | d d j |  j d j k r | | | | | | | f Sn  d | | d d f S(   s&  
    :todo: In many expressions, there are many ways to turn it into a
        gemm.  For example dot(a,b) + c + d.  This function should
        return all of them, so that if one version of gemm causes a
        cycle in the graph, then another application of gemm can be
        tried.

    i    g      ?i   N(   t   timeR   t   outputsRU   R  R  R@   R%   (   RI   R  R   R   R   Rz   t   t3(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   _gemm_from_node2S  s    	't   GemmOptimizerc           B` s;   e  Z d  Z d   Z d   Z d   Z e d d   Z RS(   s.   Graph optimizer for inserting Gemm operations.c         C` s   t  j |   t |  _ d  S(   N(   R   R9   R'   t   warned(   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR9   y  s    c         C` s   | j  t j    d  S(   N(   t   attach_featureR   t   ReplaceValidate(   R8   t   fgraph(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   add_requirements}  s    c         ` sv  t  } d } d } d } d } d } d } d }	 d }
 d } | j rl | j j } | j j   } | j } n     f d   } t j j j	 | d  d  d d } | j |  x| r| d 7} t j   } t j j j | j | j   | t j   | 7} t }  j   x D]  t   j t j  oet   j j t j j t j j t j j t j j f  snqn    | j k rqn  y: t    \ } } } } | | 7} |	 | 7}	 |
 | 7}
 Wn t k
 r| d 7} qn X| r| \ } } t  |  t    j  k st!  t" | d j# _$ yE | j% t& t'   j |   | g d d d t t  } | d 7} Wqt k
 r| d 7} qt( k
 r| d 7} t  |  _) qXqqWq W| j* |  | j r9| j j | } | j | } i  } x] t+ | j  D]7 \ } } | | k r(| | | | | <q| | | <qWn d  } d  } i  } |  | | | | | | |	 |
 | | | | f S(   Ni    c         ` s    |    k	 r  j  |   n  d  S(   N(   Ry   (   t   new_node(   RI   t   nodelist(    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt	   on_import  s    R   R  i   t   reasont   warn(,   R.   t   profilet   validate_timet   execute_callbacks_timesRW   t   execute_callbacks_timeRd   t   goft   optt   UpdaterR%   R!  R  t   grapht   io_toposortRJ   R  R'   t   reverseR   R   R-   t   Elemwiset	   scalar_opR   t   Addt   Subt   Negt   Mult   apply_nodesR  R   RU   Rw   R   t   tagt   values_eq_approxt   replace_all_validate_removeR   t   zipR   R   t   remove_featureR   (   R8   R#  t   did_somethingt   nb_itert   nb_replacementt   nb_replacement_didn_t_removet   nb_inconsistency_maket   nb_inconsistency_replacet   time_canonicalizet   time_factor_cant   time_factor_listt   time_toposortt   validate_beforet   callbacks_beforet   callback_beforeR'  R   R   t   new_outputst   time1t   time2t   time3R  R+  t   callback_timet   callbacks_timeR  R   (    (   RI   R&  s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   apply  s    			




!	
	i    c         C` s  d | } t  | d d |  t  | d | d d |  t  | d | d d |  t  | d | d	 d |  t  | d
 | d d |  t  | d | d d |  t  | d | d d |  t  | d | d d |  t  | d | d d |  t  | d | d d |  t  | d | d d |  t  | d | d d |  | d d k rt  | d d |  xG t t | d  d d   D]# } | d d k r~t  |  q~q~Wn  d  S(   Ns       R  t   files    nb_iteri   s    nb_replacementi   s    nb_replacement_didn_t_removei   s    nb_inconsistency_makei   s    nb_inconsistency_replacei   s    time_canonicalizei   s    time_factor_cani   s    time_factor_listi   s    time_toposorti	   s    validate_timei
   s    callback_timei   s    callbacks_timei   t   keyc         S` s   |  d S(   Ni   (    (   R   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   <lambda>  s    i    (   t   printt   sortedR   (   t   streamt   proft   levelt   blancR   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   print_profile  s$    
&(   R;   RP   RQ   R9   R$  RS  t   staticmethodR]  (    (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR  w  s   			[R   c           B` sS   e  Z d  Z d   Z d   Z d   Z d Z d Z d Z d Z	 d   Z
 d	   Z RS(
   s_   Compute a matrix-matrix product.

    This is a specialization of the more general Dot().

    c         C` s   d
 } | j  j d k s* | j  j | k r9 t |   n  | j  j d k s] | j  j | k rl t |   n  | j  j | j  j k r t d   n  | j  j d | j  j d	 f } t j | j  j |  g } t |  | | g |  S(   NR   R   R   R   R   i   s   dtype mismatch to Dot22i    i   (   s   float16s   float32s   float64s	   complex64s
   complex128(   R@   R?   R,   R>   R   R-   R   R   (   R8   R2   R1   t   dtypest   bzR  (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRC     s    $$ c         C` sq   | \ } } | \ } y# t  j t  j | |   | d <Wn2 t k
 rl } | j | j | j f | _   n Xd  S(   Ni    (   R)   RH   RG   RE   t   argsRD   (   R8   RI   RY   RL   R2   R1   R   t   e(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRM     s    	#c         C` s   | d d | d d g g S(   Ni    i   (    (   R8   RI   RN   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRO     s    sM  
        if ((NULL == %(_zout)s)
            || (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_x)s)[0])
            || (PyArray_DIMS(%(_zout)s)[1] != PyArray_DIMS(%(_y)s)[1]))
        {
            if (NULL != %(_zout)s) Py_XDECREF(%(_zout)s);
            npy_intp dims[2];
            dims[0] = PyArray_DIMS(%(_x)s)[0];
            dims[1] = PyArray_DIMS(%(_y)s)[1];
            %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
                            PyArray_TYPE(%(_x)s));
            //fprintf(stderr, "Dot Allocating %%i %%i\n", dims[0], dims[1]);
            if(!%(_zout)s) {
                PyErr_SetString(PyExc_MemoryError,
                                "failed to alloc dot22 output");
                %(fail)s
            }
        }
        Nz = PyArray_DIMS(%(_zout)s);
        Sz = PyArray_STRIDES(%(_zout)s);

        R   sG   
                float a = 1.0;
                float b = 0.0;
        sI   
                double a = 1.0;
                double b = 0.0;
        c   
      C` s   | \ } } | \ } | j  d j j j d  rM t j d |  j j   n  t |  j	    d k r t
 t |   j | | | | f | f |  S|  j   t t   |  }	 |	 S(   Ni    R   s	   %s.c_code(   RJ   R@   R,   Ro   R   R   R:   R;   RU   R   t   superR   R   R   R   R   (
   R8   RI   R   RY   RL   R   R   R   R   R   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR   2  s    		c         C` s"   |  j    } | r d | S| Sd  S(   Ni   (   i   (   R   (   R8   R   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR   >  s    (   R;   RP   RQ   RC   RM   RO   R   R   R   R   R   R   (    (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR     s   				c         C` s  t  |  j t j  s d  S|  j \ } } | j j | j j k rc t j d | | | j | j  d  S| j j d k rx| j	 d k r | j	 d k r t
 |  j   g S| j	 d k r | j	 d k r t
 | | j d	 d
   j d	  g S| j	 d k r)| j	 d k r)t
 | j d
 d	  |  j d  g S| j	 d k rx| j	 d k rxt
 | j d
 d	  | j d	 d
   j   g Sn  t j d | | | j | j  d  S(   Ns*   Not optimizing dot with inputs %s %s %s %sR   R   R   R   R   i   i   i    R2   (   s   float16s   float32s   float64s	   complex64s
   complex128(   R   R   R-   t   DotRJ   R@   R,   Ru   t   infoR?   R   R   (   RI   R2   R1   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   local_dot_to_dot22H  s&    	%%	c         C` s#   |  j  t k r t |  j   g Sd  S(   N(   R   R   R   RJ   (   RI   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   local_inplace_gemmi  s    c         C` s#   |  j  t k r t |  j   g Sd  S(   N(   R   t   gemv_no_inplacet   gemv_inplaceRJ   (   RI   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   local_inplace_gemvo  s    c         C` s#   |  j  t k r t |  j   g Sd  S(   N(   R   t   gert   ger_destructiveRJ   (   RI   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   local_inplace_geru  s    c         C` s   |  j  t k r |  j \ } } } } } | j | j k oJ t t f k n r t | j d  | | j | j d  |  } | j d d  g S| j | j k o t t f k n r t | j d  | | | j d  |  } | j d d  g Sn  d S(   s.   GEMM acting on row or column matrices -> GEMV.i   R2   i    N(	   R   R   RJ   R   R.   R'   Rh  R   R-   (   RI   R   R   R2   R1   R   R   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   local_gemm_to_gemv{  s    (-(*c         C` s  |  j  t k r|  j \ } } } } } | j d r| j d r| j d  } | j d  } y t j |  } Wn t j k
 r d SX| d k r t | | | |  }	 |	 g S| d k rt j	 | j
 d | j
 d g | j  }
 t |
 | | |  }	 |	 g Sd Sqn  d S(   s'   GEMM computing an outer-product -> GER.i   i    N(   R   R   RJ   R   R   R-   t   get_scalar_constant_valuet   NotScalarConstantErrorRk  t   zerosRD   R,   (   RI   R   R   R2   R1   R   t   xvt   yvt   bvalRz   Rq  (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   local_gemm_to_ger  s"    )c         C` sQ  |  j  t k rM|  j \ } } | j } | j } t j t j d d | j  } t j t j d d | j  } | d r | d r | j	 d  } | j	 d  } t j
 | j d | j d g d | j }	 t |	 | | |  }
 |
 g S| d rU| d rU| j	 d  } t j | j  d  }	 t |	 | | j | |  }
 |
 j	 d d  g S| d r| d r| d r| j	 d  } t j | j  | j d  }	 t |	 | | j | |  }
 |
 j	 d d  g S| d rM| d rM| d rM| j	 d  } t j | j  | j d  }	 t |	 | | | |  }
 |
 j	 d d  g Sn  d S(   s(   dot22 computing an outer-product -> GER.i   R,   i    R2   N(   R   R   RJ   R   R-   R=   R)   RH   R,   R   Rq  RD   Rk  t
   AllocEmptyRh  (   RI   R2   R1   t   xbt   ybt   onet   zeroRr  Rs  Rq  Rz   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   local_dot22_to_ger_or_gemv  s6    		!!,  t   BlasOptg333333?t   fast_runt   fast_compileRf  t   gemm_optimizeri
   Rn  t   max_use_ratioi   t   ignore_newtreesi   R   t   blas_opt_inplacet   InplaceBlasOptg     Q@t   Dot22Scalarc           B` sV   e  Z d  Z d   Z d   Z d   Z e j Z d Z d Z	 d Z
 d   Z d   Z RS(	   s   Compute a matrix-matrix product.

    This is a specialization of the more general Dot()
    Used to call optimized gemm implementation.
    Also used to generate a gemm later.
    compute scalar*dot(x,y).

    c         C` sD  | j  d k r$ t t j |   n  | j  d k rH t t j |   n  | j  d k rl t t j |   n  | j | j k o | j k n s t d | j | j | j f   n  | j j d  r | j j d  r t d | j   n  | j j d | j j d g } t	 j
 | j j |  g } t |  | | | g |  S(   Ni    i   s$   Dot22Scalar requires matching dtypesR(   R   s*   Dot22Scalar requires float or complex argsi   (   R?   R>   R   R   R   R,   Ro   R@   R   R-   R   R   (   R8   R2   R1   R   R`  R  (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRC     s     % c   	      C` sx   | \ } } } | \ } y' t  j | t  j | |   | d <Wn2 t k
 rs } | j | j | j f | _   n Xd  S(   Ni    (   R)   RH   RG   RE   Ra  RD   (	   R8   RI   RY   RL   R2   R1   R   R   Rb  (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRM     s    	'c         C` s   | d d | d d g g S(   Ni    i   (    (   R8   RI   RN   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRO     s    s   
        if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE)
            && (PyArray_DESCR(%(_a)s)->type_num != NPY_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError,
                         "type(a) is not double or float"); %(fail)s;}

        s   
        #define REAL float
        float a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT)
        ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0])
        : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]);
        #undef REAL
        float b = 0.0;
        s   
        #define REAL double
        double a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT)
        ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0])
        : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]);
        #undef REAL
        double b = 0.0;
        c         C` s   | \ } } } | \ }	 | j  d j j j d  rP t j d |  j j   n  t |  j	    d k r t
 t |   j | | | | f |	 f |  S|  j   t t   |  }
 |
 S(   Ni    R   s	   %s.c_code(   RJ   R@   R,   Ro   R   R   R:   R;   RU   R   Rc  R  R   R   R   R   (   R8   RI   R   RY   RL   R   R   R   R   R   R   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR   <  s    		c         C` s"   |  j    } | r d | S| Sd  S(   Ni   (   i   (   R   (   R8   R   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR   H  s    (   R;   RP   RQ   RC   RM   RO   R   R   R   R   R   R   R   (    (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR    s   						c      	   C` s  |  j  t j k r t Sg  |  j D]! } | j o> | j j  t k ^ q  } t |  sW t S| j t	  d k ro n  | j
 t	  } |  j | } g  |  j D] } t | d | j ^ q } t |  sg  |  j D]X } | j o!| j j  t j k o!t g  | j j D] } t | d | j ^ q  ^ q } t |  s:t S| j
 t	  } |  j | }	 d }
 xl t |	 j j  D]X \ } } t | d | j rot j j | j j | j j  | j j k ro| }
 PqoqoW|
 d k  r
t j d |  j g  |  j D] } | j ^ q t St j t |	 j j |
 d | j | j j  } | j j sNt  t | j j d | j j d |  } | | k st  g  t |  j  D]$ \ } } | | | f k r| ^ q} g  t |	 j j  D] \ } } | |
 k r| ^ q} t j | | |  g Sd }
 xp t |  j  D]_ \ } } | | k r$| | d k	 r$t j j | j j | j j  | j j k r$| }
 Pq$q$W|
 d k  rt j d |  j g  |  j D] } | j ^ q t S|
 t |  j  k  st  |  j |
 } t j |  j  } | j |  | j |  t j | |
 | j j  } | j j sIt  t |  d k rt | j j d | j j d |  g St j t | j j d | j j d |  |  g Sd S(   s<  
    Notes
    -----
    Previous attempts to alter this optimization to replace dot22 with
    gemm instead of dot22scalar resulted in some Scan nodes being
    duplicated and the ScanSaveMem optimization never running on them,
    resulting in highly increased memory usage. Until this issue is
    resolved, this optimization should keep using dot22scalar instead of
    gemm.

    We upcast the scalar if, after the multiplication with the dot, this
    gives the same type.

    We execute this optimizer after the gemm optimizer. This gives more
    priority to gemm, which yields a larger speed-up than this optimizer,
    while still allowing the gemm optimizer to ignore this op.

    TODO: support the case where we can reorder the mul to generate a
    dot22scalar, or fix the canonizer to merge them (one mul with
    multiple inputs).

    i   R,   ii    sg   Not optimizing dot22 with inputs %s %s, as the type of the scalar cannot be upcasted to the matrix typeN(   R   R-   R   R'   RJ   R   R   R0   t   countR.   t   indexR   R,   R  Rd   R   R   R@   Ru   Re  R   R?   Rw   t   _dot22scalarR%   RU   RW   t   remove(   RI   R2   t   i_dot22t	   dot22_idxR}   t   i_scalart   x_it   i_mult   mul_idxR  t
   scalar_idxR   R   RG   t   inptt   other_factorst   other_m_inputsR  t   o(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   local_dot22_to_dot22scalarR  sz    1+e$	&&	&'R  i   t
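A hand-written illustration (not part of this module) of the rewrite that
local_dot22_to_dot22scalar performs, assuming Theano is importable: a graph of
the form scalar * dot(x, y) is collapsed into a single dot22scalar node.

    import theano
    import theano.tensor as T

    x = T.fmatrix('x')
    y = T.fmatrix('y')
    a = T.fscalar('a')

    # Compile a function computing a * dot(x, y); with the default fast_run
    # optimizations, the resulting graph should contain a Dot22Scalar (or
    # Gemm) node rather than separate Dot and Elemwise{mul} nodes.
    f = theano.function([x, y, a], a * T.dot(x, y))
    theano.printing.debugprint(f)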
   BatchedDotc           B` s   e  Z d  Z d Z d   Z d   Z d   Z d   Z d   Z d   Z	 d   Z
 d   Z d	   Z d
   Z d   Z d   Z d   Z RS(   sl   
    Computes the batched dot product of two variables:

        batched_dot(a, b)[i] = dot(a[i], b[i])
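
    A minimal NumPy sketch of these semantics (the helper name is
    illustrative only, not part of the Op's API):

        import numpy as np

        def batched_dot_reference(a, b):
            # result[i] = dot(a[i], b[i]) along the leading (batch) axis;
            # the batch sizes of a and b must match.
            assert a.shape[0] == b.shape[0]
            return np.array([np.dot(a[i], b[i]) for i in range(a.shape[0])])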
    c         G` s^  t  t t j |   } t |  d k rC t d t |    n  | d j d	 k rp t d | d j   n  | d j d
 k r t d | d j   n  t j j	 g  | D] } | j
 j ^ q   } g  | D] } t j | |  ^ q } | d j
 j d p| d j
 j d f | d j
 j d d !| d j
 j d } t |  | t j | |  g  S(   Ni   s>   theano.tensor.blas.BatchedDot: 2 arguments required, %d given i    i   s   theano.tensor.blas.BatchedDot: input 0 (0-indexed) must have ndim of 2 or 3, %d given. Consider calling theano.tensor.batched_dot instead.i   s   theano.tensor.blas.BatchedDot: input 1 (0-indexed) must have ndim of 2 or 3, %d given. Consider calling theano.tensor.batched_dot instead.i(   i   i   (   i   i   (   R   R   R-   R=   RU   R>   R?   Rd   R   R   R@   R,   R   R   R   R   (   R8   RJ   R  R,   t   upcasted_inputsR   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRC     s    +%=c      	   C` s  | \ } } | \ } | j  d | j  d k r t d d j t t |   d j g  | D] } t | j  d  ^ qW  f   n  |  j | g  | D] } | j  ^ q  d } | j d j }	 t j	 | d |	 }
 | d <x9 t
 |
 j  d  D]$ } t j | | | |  |
 | <q Wd  S(   Ni    sb   theano.tensor.blas.BatchedDot: inputs [%s] must have the same size in axis 0, but have sizes [%s].s   , R,   (   RD   R>   t   joinR   R   RO   R  R,   R)   t   emptyR   RG   (   R8   RI   RY   RL   R2   R1   R   R   RD   R,   t   z0(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRM     s    	9,c         C` s   d } t    | S(   Ns  
        template<typename dtype, typename function>
        bool batch_gemm(function gemm, int type_size,
                        PyArrayObject* xs, PyArrayObject* ys, PyArrayObject* zs) {
            npy_intp *Nx = PyArray_DIMS(xs), *Sx = PyArray_STRIDES(xs);
            npy_intp *Ny = PyArray_DIMS(ys), *Sy = PyArray_STRIDES(ys);
            npy_intp *Nz = PyArray_DIMS(zs), *Sz = PyArray_STRIDES(zs);

            if (Nx[0] != Ny[0]) {
                PyErr_Format(PyExc_ValueError,
                             "Shape mismatch: batch sizes unequal."
                             " x.shape is (%d, %d, %d),"
                             " y.shape is (%d, %d, %d).",
                             Nx[0], Nx[1], Nx[2],
                             Ny[0], Ny[1], Ny[2]);
                return 1;
            }

            if (Nx[2] != Ny[1]) {
                PyErr_Format(PyExc_ValueError,
                             "Shape mismatch: summation axis sizes unequal."
                             " x.shape is (%d, %d, %d),"
                             " y.shape is (%d, %d, %d).",
                             Nx[0], Nx[1], Nx[2],
                             Ny[0], Ny[1], Ny[2]);
                return 1;
            }

            /* encode the stride structure of _x,_y,_z into a single integer. */
            int unit = 0;
            unit |= ((Sx[2] == type_size || Nx[2] == 1) ? 0x0 : (Sx[1] == type_size || Nx[1]==1) ? 0x1 : 0x2) << 8;
            unit |= ((Sy[2] == type_size || Ny[2] == 1) ? 0x0 : (Sy[1] == type_size || Ny[1]==1) ? 0x1 : 0x2) << 4;
            unit |= ((Sz[2] == type_size || Nz[2] == 1) ? 0x0 : (Sz[1] == type_size || Nz[1]==1) ? 0x1 : 0x2) << 0;

            /* create appropriate strides for malformed matrices that are row or column
             * vectors, or empty matrices.
             * In that case, the value of the stride does not really matter, but
             * some versions of BLAS insist that:
             *  - they are not smaller than the number of elements in the array,
             *  - they are not 0.
             */
            int sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[2] + 1);
            int sx_2 = (Nx[2] > 1) ? Sx[2]/type_size : (Nx[1] + 1);
            int sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[2] + 1);
            int sy_2 = (Ny[2] > 1) ? Sy[2]/type_size : (Ny[1] + 1);
            int sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[2] + 1);
            int sz_2 = (Nz[2] > 1) ? Sz[2]/type_size : (Nz[1] + 1);

            dtype* x = (dtype*)PyArray_DATA(xs);
            dtype* y = (dtype*)PyArray_DATA(ys);
            dtype* z = (dtype*)PyArray_DATA(zs);

            dtype a = 1.0;
            dtype b = 0.0;
            char N = 'N';
            char T = 'T';
            int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2];

            // loop over batch axis
            for (int i = 0; i < Nz[0]; i++) {
                switch(unit)
                {
                    case 0x000: gemm(&N, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_1, &b, z, &sz_1); break;
                    case 0x100: gemm(&N, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_2, &b, z, &sz_1); break;
                    case 0x010: gemm(&T, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_1, &b, z, &sz_1); break;
                    case 0x110: gemm(&T, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_2, &b, z, &sz_1); break;
                    case 0x001: gemm(&T, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_1, &b, z, &sz_2); break;
                    case 0x101: gemm(&N, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_1, &b, z, &sz_2); break;
                    case 0x011: gemm(&T, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_2, &b, z, &sz_2); break;
                    case 0x111: gemm(&N, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_2, &b, z, &sz_2); break;
                    default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); return 1;
                };
                x += Sx[0] / type_size;
                y += Sy[0] / type_size;
                z += Sz[0] / type_size;
            }

            return 0;
        }
        (   R   (   R8   t   batch_gemm_defn(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR     s    Pc         C` s   t    S(   N(   Rf   (   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR   N  s    c         C` s   t  d t d t  S(   NR`   Ra   (   Rf   R'   R.   (   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR   Q  s    c         C` s   t  d t d t  S(   NR`   Rb   (   Rf   R'   R.   (   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR   T  s    c         C` s   t  d t d t  S(   NR`   Rc   (   Rf   R'   R.   (   R8   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR   W  s    c         C` s   d S(   Ns   
        // clean up views
        Py_XDECREF(xs); xs = 0;
        Py_XDECREF(ys); ys = 0;
        Py_XDECREF(zs); zs = 0;
        (    (   R8   RI   R   RJ   R  R   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   c_code_cleanupZ  s    c         ` s  | \ } } | \   | d  d   } | j  d j | j  d j | j d j }	 }
 } d | g } |	 d k r | j d |  n  |
 d k r | j d |  n  t |  | k s t  d	 j   f d
   t |  D  } d j |  } |   |  } d t   } g  } xI | |	 f | |
 f g D]/ \ } } | | |  } | j d t    q.Wd j |  }  f d   } g  } |	 d k r| j d  n( |	 d k r| j | d | d   n  |
 d k r| j d  n( |
 d k r| j | d | d   n  | d k r)| j d  nF | j | d   d |	 d k rMd  n d |
 d k rbd  n d f   d j |  t   } d t   S(   Nt   failc         ` s   d |    | d k r& d j  d    Sd j d j   f d   t d |  D  d d j   f d	   t d |  D  g  S(
   Ns   PyArray_STRIDES(%s)i   s   {strides}[0] == type_sizet   stridess    && c         3` s'   |  ] } d  j  d   d |  Vq d S(   s5   {strides}[{i}] > 0 && {strides}[{i}] % type_size == 0R  R   N(   t   format(   t   .0R   (   R  (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pys	   <genexpr>m  s   s   (%s)s    || c         3` s'   |  ] } d  j  d   d |  Vq d S(   s   {strides}[{i}] == type_sizeR  R   N(   R  (   R  R   (   R  (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pys	   <genexpr>o  s   (   R  R  t   range(   t   varR?   (    (   R  s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt
   contiguoush  s    
i    i   s   PyArray_DIMS(%s)[0]i   s   PyArray_DIMS(%s)[1]s   PyArray_DIMS(%s)[2]s    && c         3` s(   |  ] \ } } d    | | f Vq d S(   s   PyArray_DIMS(%s)[%i] == %sN(    (   R  R   t   dim(   R   (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pys	   <genexpr>}  s   s   , s  
            if (NULL == %(_z)s || !(%(z_shape_correct)s)  || !(%(z_contiguous)s))
            {
                npy_intp dims[%(z_ndim)s] = {%(z_shape)s};
                Py_XDECREF(%(_z)s);
                %(_z)s = (PyArrayObject*)PyArray_SimpleNew(
                    %(z_ndim)s, dims, PyArray_TYPE(%(_x)s));
                if(!%(_z)s) {
                    PyErr_SetString(PyExc_MemoryError,
                                    "failed to alloc BatchedDot output");
                    %(fail)s
                }
            }
        s-  
                if (!(%(_contiguous)s)) {
                    PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s);
                    if (!_copy)
                        %(fail)s
                    Py_XDECREF(%(var)s);
                    %(var)s = _copy;
                }
            s   
c         ` s0    } d j    f d   | D  } d t   S(   Ns   , c         3` s1   |  ]' } | d k r d  n d   | f Vq d S(   t   1s   PyArray_DIMS(%s)[%i]N(   R%   (   R  t   axis(   t   oldname(    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pys	   <genexpr>  s   s  {
                npy_intp dims[3] = {%(_shape)s};
                PyArray_Dims newshape = {dims, 3};
                %(newname)s = (PyArrayObject*)PyArray_Newshape(%(oldname)s, &newshape, NPY_ANYORDER);
                if (!%(newname)s)
                    %(_fail)s
                // make sure we didn't accidentally copy
                assert(PyArray_DATA(%(oldname)s) == PyArray_DATA(%(newname)s));
            }(   R  R   (   t   newnameR  RD   t   _failt   _shape(   R  (   R  s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   c_dimshuffle  s    	s   xs = %(_x)s; Py_XINCREF(xs);i   t   xss   ys = %(_y)s; Py_XINCREF(ys);t   yss   zs = %(_z)s; Py_XINCREF(zs);t   zssT
  
        int type_num = PyArray_DESCR(%(_x)s)->type_num;
        int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes

        // xs, ys, zs will point to views onto %(_x)s, %(_y)s, %(_z)s
        PyArrayObject *xs = 0, *ys = 0, *zs = 0;

        if (PyArray_NDIM(%(_x)s) != %(x_ndim)s) {
            PyErr_Format(PyExc_NotImplementedError,
                         "rank(x) != %(x_ndim)s. rank(x) is %%d.",
                         PyArray_NDIM(%(_x)s));
            %(fail)s;
        }
        if (PyArray_NDIM(%(_y)s) != %(y_ndim)s) {
            PyErr_Format(PyExc_NotImplementedError,
                         "rank(y) != %(y_ndim)s. rank(y) is %%d.",
                         PyArray_NDIM(%(_y)s));
            %(fail)s;
        }
        if (%(_z)s && PyArray_NDIM(%(_z)s) != %(z_ndim)s) {
            PyErr_Format(PyExc_NotImplementedError,
                         "rank(z) != %(z_ndim)s. rank(z) is %%d.",
                         PyArray_NDIM(%(_z)s));
            %(fail)s;
        }

        // allocate output
        %(allocate)s
        // reallocate any noncontiguous arrays or arrays with invalid strides
        %(contiguate)s
        // add dims to make sure everything is tensor3
        %(upcast)s
        // from here on, use xs, ys and zs as they are tensor3 and share memory
        // with the original %(_x)s, %(_y)s and %(_z)s arrays.

        if ((PyArray_DESCR(xs)->type_num != NPY_DOUBLE)
            && (PyArray_DESCR(xs)->type_num != NPY_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}

        if ((PyArray_DESCR(ys)->type_num != NPY_DOUBLE)
            && (PyArray_DESCR(ys)->type_num != NPY_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}

        if ((PyArray_DESCR(zs)->type_num != NPY_DOUBLE)
            && (PyArray_DESCR(zs)->type_num != NPY_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}

        if ((PyArray_DESCR(xs)->type_num != PyArray_DESCR(ys)->type_num)
            ||(PyArray_DESCR(xs)->type_num != PyArray_DESCR(zs)->type_num))
        { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }

        switch (type_num)
        {
            case NPY_FLOAT:
            if (batch_gemm<float>(sgemm_, type_size, xs, ys, zs)) {
                %(fail)s;
            }
            break;
            case NPY_DOUBLE:
            if (batch_gemm<double>(dgemm_, type_size, xs, ys, zs)) {
                %(fail)s;
            }
            break;
        }
        (   i    Ni   (   i    i   N(
   RJ   R?   R  Ry   RU   Rw   R  R  R   R%   (   R8   RI   R   RY   RL   R   R   R   R  t   x_ndimt   y_ndimt   z_ndimt   z_dimst   z_shape_correctt   z_shapet   z_contiguoust   allocatet
   contiguateR  R?   t   _contiguousR  R   (    (   R   R  s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR   b  sP    	
	2%		Bc         C` s   d d l  m } d |   f S(   Ni    (   R   i   (   t   theano.tensor.blas_headersR   (   R8   R   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR   	  s    c         C` s  | \ } } | \ } | j  j | j  j | j  j } } } | d k rv | j d d  | }	 | j d d  | }
 n)| d k r | d k r t j | | j d d d   }	 | j d d d  | j d d d  }
 n | d k r>| d k r>| j d d d  | j d d d  }	 t j | j d d d  |  }
 na | | k oUd k n rt j | | j d d d   }	 t j | j d d d  |  }
 n  |	 j | j k rt j |	 | j  }	 n  |
 j | j k rt j |
 | j  }
 n  |	 |
 f S(   Ni   i    R2   i   i   (   R@   R?   R   R-   t   batched_dotR   t   patternbroadcast(   R8   RY   t   gradsR2   R1   t   gzt   xdimt   ydimt   gdimt   xgradt   ygrad(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   grad	  s(    	&!+($!$c         C` s  t  |  d k s t  t  |  d k s0 t  | d d  k rW | d d  k rW d  g St j d k } | ry t j j j | d  } Wn* t	 k
 r t j j j
 d  t } n Xy t j j j | d  } Wn* t	 k
 r t j j j
 d  t } n X| d rWy t j j j | d  } WqWt	 k
 rSt j j j
 d  t } qWXn  | d ry t j j j | d  } Wqt	 k
 rt j j j
 d  t } qXqn  | rb| | g } | | g }	 x t d  D] }
 |	 |
 d  k	 r| |
 j |	 |
 j k rt d	 t |
  d
 t |
  d t | |
 j  t |	 |
 j  f   qqWn  | d r|  | d | d  } n  | d r|  | d | d  } n  | d r| d r| | g S| d r| g S| g Sd  S(   Ni   i    i   t   offs7   first input passed to BatchedDot.R_op has no test values8   second input passed to BatchedDot.R_op has no test values<   first eval point passed to BatchedDot.R_op has no test values=   second eval point passed to BatchedDot.R_op has no test values   input s    and eval_point s\    to BatchedDot.R_op should have the same shape, but their shapes are %s and %s, respectively(   RU   Rw   R%   R   t   compute_test_valueRd   R.  R   t   get_test_valuet   AttributeErrort   missing_test_messageR'   R   RD   RE   R   (   R8   RJ   t   eval_pointst   debugger_availablet   iv0t   iv1t   ev0t   ev1t   input_valuest   eval_point_valuesR   R   R   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   R_op0	  sh     



%


c         C` sN   x, | D]$ } t  |  d k r t    q q W| \ } } | d  | d g S(   Ni   i   i(   i   i   (   RU   t   NotImplementedError(   R8   RI   t   shapest   shape_t   xshpt   yshp(    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyRO   t	  s
    (    (   R;   RP   RQ   RR   RC   RM   R   R   R   R   R   R  R   R   R  R  RO   (    (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyR    s   			S								$	Dc         C` s,   |  j  t j t j f k r( t |   n  d  S(   N(   R   R-   R   R   R   (   RI   (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   local_print_as_we_go_along	  s    (y   RQ   t
   __future__R    R   R   RW   t   loggingRp   R  R)   t   numpy.distutilst   numpy.distutils.__config__t   ImportErrort   sixR   t	   six.movesR   R   Rd   R   t
   theano.gofR   R   R	   R
   R   R   R   R   R   R   R   t   theano.printingR   R   R   t   theano.compile.modeR   t   theano.scalart   theano.tensorR   R-   R  R   R   t   theano.tensor.optR   R   t   theano.tensor.typeR   t	   getLoggerRu   t   scipy.linalg.blast   scipyR.   R&   t   linalgRe   t   fblasR  t   sgemvR,   t   dgemvt   cgemvt   zgemvR+   Rb  R'   Rf   Rv   R   R#   R%   R$   R5   Rh  Ri  R4   RS   Rk  Rl  t   memoizeRg   R   R   R   R   t   gemmt   assignR   R   R   R   R   R   R  R  R  R  R   R   Rd  Rf  Rg  Rj  Rm  Rn  Ru  R{  t
   blas_optdbt   registerR  R  R  R   R  R  R  R   R   R  (    (    (    s2   /tmp/pip-build-X4mzal/theano/theano/tensor/blas.pyt   <module>~   s   L			P> J @ 		K	R	)	;	$zT	!+				
		
		[	q		
 	