
Xc           @` s   d  d l  m Z m Z m Z d  d l Z d  d l m Z d  d l Z d  d l	 m
 Z d  d l m Z m Z d  d l m Z d  d l m Z d  d l m Z d	 e j f d
     YZ e   Z d   Z d   Z d S(   i    (   t   absolute_importt   print_functiont   divisionN(   t   xrange(   t   basic(   t   blas_header_textt   blas_header_version(   t   ldflags(   t   strutil(   t   grad_undefinedt   Conv3Dc           B` sw   e  Z d  Z d Z d   Z d   Z d   Z d   Z d   Z d   Z	 d   Z
 d   Z d	   Z d
   Z d   Z RS(   s   
    3D `convolution` of multiple filters on a minibatch.

    Notes
    -----
    Does not flip the kernel, moves kernel with a user specified stride.

    c         C` s   d t    f S(   Ni   (   R   (   t   self(    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt   c_code_cache_version9   s    c         C` s   t  j |  } t  j |  } t  j |  } t  j |  } | j d t t t | j d f }	 t j |  d | | | | g d t  j | j |	    g }
 |
 S(   sK  
        Parameters
        ----------
        V
            Visible unit, input(batch,row,column,time,in channel)
        W
            Weights, filter(out channel,row,column,time,in channel)
        b
            bias, shape == (W.shape[0],)
        d
            strides when moving the filter over the input(dx,dy,dt)

        i    t   inputst   outputs(   t   Tt   as_tensor_variablet   broadcastablet   Falset   theanot   Applyt
   TensorTypet   dtype(   R   t   Vt   Wt   bt   dt   V_t   W_t   b_t   d_t   bcastt   node(    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt	   make_node<   s    #c      
   C` s;  | \ } } } } | \ } t  j j j | t j | d d d d d  d   f  | | | j d d ! } t j | | j  } | j }	 t  j j j	 | | |	 |  }
 t j |
 | j  }
 t j
 | d d } t j | | j  } t |  d | d d  } d t |  k r)| j d  k	 r)| j } n d	 } d t |  k r\| j d  k	 r\| j } n d
 } d t |  k r| j d  k	 r| j } n d } d t |  k r| j d  k	 r| j } n d } d | d | d | _ d | d | d | d |
 _ d | d | d | d | d | _ | |
 | | g S(   Ni    i   i   t   axisi   i   s~   The gradient of Conv3D with respect to the convolution stride is undefined because Conv3D is only defined for integer strides.t   namet	   anon_dCdHt   anon_Vt   anon_Wt   anon_bs   Conv3D_dCdV(dCdH=s   ,V=t   )s   Conv3D_dCdW(dCdH=s   ,W=s   Conv3D_dCdb(dCdH=s   ,b=(   i    i   i   i   (   R   t   tensort   nnett   convTransp3DR   t
   zeros_liket   shapet   patternbroadcastR   t
   convGrad3Dt   sumR	   t   dirR#   t   None(   R   R   t   output_gradientsR   R   R   R   t   dCdHt   dCdVt   WShapet   dCdWt   dCdbt   dCddt	   dCdH_namet   V_namet   W_namet   b_name(    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt   gradW   s:    		A		!!!!!)c         C` s3   | \ } } } } t  | | | |  | d d <d  S(   Ni    (   t   computeH(   R   R    R   t   output_storageR   R   R   R   (    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt   perform   s    c         C` s   | j  \ } } } } | \ } } }	 }
 | d } | d } | d } | d } | d } | d } | d } | d } | d } | d } | d } | | | d } | | | d } | | | d } | | | | | f } | g S(   Ni    i   i   i   (   R   (   R   R    t   input_shapesR   R   R   R   t   V_shapet   W_shapet   b_shapet   d_shapet   drt   dct   dtt
   batch_sizet   output_channelst	   vidHeightt   filterHeightt   vidWidtht   filterWidtht   vidDurt	   filterDurt   output_heightt   output_widtht
   output_durt   rval(    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt   infer_shape   s$    










c         C` s   t    S(   N(   R   (   R   (    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt   c_support_code   s    c         C` s   t    S(   N(   R   (   R   (    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt   c_libraries   s    c         C` s   t  d t d t  } | S(   Nt   libst   flags(   R   R   t   True(   R   RZ   (    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt   c_compile_args   s    c         C` s   t  d t d t  S(   NRY   t   libs_dir(   R   R   R[   (   R   (    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt
   c_lib_dirs   s    c         C` s   t  d t d t  S(   NRY   t   include_dir(   R   R   R[   (   R   (    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt   c_header_dirs   s    c         C` s   | \ } } } }	 | d }
 | d } d } | j  \ } } } } | j d } t j j j r | j | j k r | j | j k r | j d k r d } n. | j d k r d } n t d | j j   | d	 7} n  | d
 7} t	 j
 | t    S(   Nt   faili    s,  
            ///////////// < code generated by Conv3D >

            //printf("				Conv3D c code\n");

            //Check dimensionality of inputs
            if (PyArray_NDIM(%(W)s) != 5)
            {
                PyErr_Format(PyExc_ValueError, "Conv3D: W must be a 5 dimensional tensor");
                            %(fail)s

            }

            if (PyArray_NDIM(%(V)s) != 5)
            {
                PyErr_Format(PyExc_ValueError, "Conv3D: V must be a 5 dimensional tensor");
                            %(fail)s
            }

            if (PyArray_NDIM(%(b)s) != 1)
            {
                PyErr_Format(PyExc_ValueError,"Conv3D: b must be a vector.");
                %(fail)s
            }

            if (PyArray_NDIM(%(d)s) != 1)
            {
                PyErr_Format(PyExc_ValueError,"Conv3D: d must be a vector.");
                %(fail)s
            }

            if (PyArray_DIMS(%(d)s)[0] != 3)
            {
                PyErr_Format(PyExc_ValueError,"Conv3D: 3 stride length arguments expected (row, col, time) but %%li were given", (long)PyArray_DIMS(%(d)s)[0]);
                %(fail)s
            }

            //Read and check sizes of inputs
{ // exta scope so error handler jumps don't cause errors
            const int batchSize = PyArray_DIMS(%(V)s)[0];
            const int outputChannels =  PyArray_DIMS(%(W)s)[0];
            const int inputChannels = PyArray_DIMS(%(V)s)[4];

            if (PyArray_DIMS(%(W)s)[4] != inputChannels)
            {
                PyErr_Format(PyExc_ValueError, "Conv3D: W operates on a %%ld channel image but the image has %%d channels. Overall shape of input: (%%ld,%%ld,%%ld,%%ld,%%ld)", (long)PyArray_DIMS(%(W)s)[4], inputChannels, (long)PyArray_DIMS(%(V)s)[0], (long)PyArray_DIMS(%(V)s)[1], (long)PyArray_DIMS(%(V)s)[2], (long)PyArray_DIMS(%(V)s)[3], (long)PyArray_DIMS(%(V)s)[4]);
                %(fail)s
            }

            if (PyArray_DIMS(%(b)s)[0] != outputChannels)
            {
                PyErr_Format(PyExc_ValueError, "Conv3D: b adds to a(n) %%ld channel output image but the output has %%d channels", (long)PyArray_DIMS(%(b)s)[0], outputChannels);
                %(fail)s
            }

{  //extra scope so error handler jumps don't cause errors
            const int filterHeight = PyArray_DIMS(%(W)s)[1];
            const int filterWidth = PyArray_DIMS(%(W)s)[2];
            const int filterDur = PyArray_DIMS(%(W)s)[3];
            const int vidHeight = PyArray_DIMS(%(V)s)[1];
            const int vidWidth = PyArray_DIMS(%(V)s)[2];
            const int vidDur = PyArray_DIMS(%(V)s)[3];
            if (vidHeight < filterHeight)
            {
                PyErr_Format(PyExc_ValueError, "W has a height of %%i but V is only %%i pixels tall",filterHeight,vidHeight);
                %(fail)s
            }

{ // extra scope so fail works

            if (vidWidth < filterWidth)
            {
                PyErr_Format(PyExc_ValueError, "W has a width of %%i but V is only %%i pixels wide",filterWidth,vidWidth);
                %(fail)s
            }

{ // extra scope so fail works

            if (vidDur < filterDur)
            {
                PyErr_Format(PyExc_ValueError, "W has a duration of %%i but V is only %%i pixels long",filterDur,vidDur);
                %(fail)s
            }

{ // extra scope so fail works

            //Read and check stride arguments
            const int dr = *(dtype_%(d)s*) PyArray_GETPTR1(%(d)s,0);
            const int dc = *(dtype_%(d)s*) PyArray_GETPTR1(%(d)s,1);
            const int dt = *(dtype_%(d)s*) PyArray_GETPTR1(%(d)s,2);

            if (dr <= 0 || dc <= 0 || dt <= 0)
            {
                PyErr_Format(PyExc_ValueError,"Conv3D: Strides must all be positive but are %%i, %%i, %%i",dr,dc,dt);
                %(fail)s
            }
{ // extra scope so fail works

            //Make correctly sized output
            const long long outputHeight = int( (vidHeight - filterHeight) / dr )+1;
            const long long outputWidth = int( (vidWidth - filterWidth) / dc )+1;
            const long long outputDur = int( (vidDur - filterDur) / dt ) +1;

            npy_intp dims[5];
            dims[0] = batchSize;
            dims[4] = outputChannels;
            dims[1] = outputHeight;
            dims[2] = outputWidth;
            dims[3] = outputDur;

            if(!(%(H)s) || PyArray_DIMS(%(H)s)[0]!=dims[0] ||
            PyArray_DIMS(%(H)s)[1]!=dims[1] ||
            PyArray_DIMS(%(H)s)[2]!=dims[2] ||
            PyArray_DIMS(%(H)s)[3]!=dims[3] ||
            PyArray_DIMS(%(H)s)[4]!=dims[4]){
                Py_XDECREF(%(H)s);
                %(H)s = (PyArrayObject *) PyArray_SimpleNew(5, dims, PyArray_DESCR(%(V)s)->type_num);
                if (!(%(H)s)) {
                    PyErr_Format(PyExc_MemoryError,"Conv3D: Could not allocate output.");
                    %(fail)s
                }
            }
{ // extra scope so fail works

            #define ELEM_AT(x, i) * ( dtype_ ## x *) ( PyArray_BYTES(x) + (i) )

            const int ws0 = PyArray_STRIDES(%(W)s)[0];
            const int ws1 = PyArray_STRIDES(%(W)s)[1];
            const int ws2 = PyArray_STRIDES(%(W)s)[2];
            const int vs1 = PyArray_STRIDES(%(V)s)[1];
            const int ws4 = PyArray_STRIDES(%(W)s)[4];
            const int vs4 = PyArray_STRIDES(%(V)s)[4];
            const int ws3 = PyArray_STRIDES(%(W)s)[3];
            const int vs3 = PyArray_STRIDES(%(V)s)[3];
            const int vs2 = PyArray_STRIDES(%(V)s)[2];
            const int bs  = PyArray_STRIDES(%(b)s)[0];
            const int hs4 = PyArray_STRIDES(%(H)s)[4];

            // Compute H
            //H[i,j,x,y,t] = b_j + sum_k sum_l sum_m sum_z W[j,z,k,l,m] V[i,z, dr*r+k,dc*c+l,dt*t+m]
            //TODO: add special cases
            // ex: filterDur == 1 && batchSize == 1 && dt = 1  (for SFA)
            // ex: inputChannels == 1 t   float64t   dgemv_t   float32t   sgemv_s#   Unrecognized dtype for convolution s  
            if (inputChannels > 20 && outputChannels > 20 && ws4 == sizeof(ELEM_AT(%(W)s,0)))
            {
              //std::cout << "lots of channels special case code" << std::endl;
              #define blas_type dtype_ ## %(V)s
              const blas_type  constant_one = 1.0;
              char N = 'T';
              int ws0e = ws0 / sizeof(ELEM_AT(%(W)s,0));
              int vs4e = vs4 / sizeof(ELEM_AT(%(V)s,4));
              int hs4e = hs4 / sizeof(ELEM_AT(%(H)s,4));

                //special case code for the "lots of channels" case
                //uses a BLAS matrix vector multiply to compute the contribute for
                //all channels of an input pixel to all channels of an output pixel
                //simultaneously
              long long Hpos = 0;
              long long Vpos = 0;
              for (int i = 0; i < batchSize; i++) {
                    long long Hposi = Hpos;
                    long long Vposi = Vpos;


                    for (int r = 0;  r < outputHeight; r++) {
                      long long Hposr = Hpos;
                      long long Vposr = Vpos;
                      for (int c = 0; c < outputWidth; c++) {
                       long long Hposc = Hpos;
                       long long Vposc = Vpos;
                       for (int t = 0; t < outputDur; t++) {
                            long long Hpost = Hpos;
                            long long Vpost = Vpos;
                            //of the loops so far, j should be the innermost, because
                            //each loop through j visits the same elements of V
                            //this implies that the last index of H should be the j index
                            //since V and H should have the same format, this means
                            //z should be the last index in v, and therefore the innermost
                            //of the next set of for loops

                            int Wpos = 0;
                            int bPos = 0;


                            long long Hposj = Hpos;
                            for (int j = 0; j < outputChannels; j++) {
                                // H[i,r,c,t,j] = b[j]
                                ELEM_AT(%(H)s,Hposj) = ELEM_AT(%(b)s,bPos);
                                Hposj += hs4;
                                bPos += bs;
                            }

                            dtype_%(H)s * writePos = & ELEM_AT(%(H)s,Hpos);


                            for (int k =0; k < filterHeight; k++) {
                                  int Wposk = Wpos;
                                  long long Vposk = Vpos;
                                  for (int l = 0; l < filterWidth; l++) {
                                    int Wposl = Wpos;
                                    long long Vposl = Vpos;
                                    for (int m = 0; m < filterDur; m++) {

                                      //H[i,r,c,t,:] += N.dot(W[:,k,l,m,:],V[i,dr*r+k,dc*c+l,dt*t+m,:])


                                      //note: changing the weights so that outputChannels and inputChannels were the last two rather than
                                      //the first and last elements did not speed this up, even for extremely large input sizes

                                      %(gemv)s(&N, & inputChannels, & outputChannels,
                     &constant_one, & ELEM_AT( %(W)s , Wpos),& ws0e,
                     & ELEM_AT(%(V)s, Vpos),& vs4e, &constant_one,
                     writePos,& hs4e);

                                      Wpos  += ws3;
                                      Vpos  += vs3;
                                    } // close m
                                    Wpos = Wposl + ws2;
                                    Vpos = Vposl + vs2;
                                  } //close l
                                  Wpos = Wposk + PyArray_STRIDES(%(W)s)[1];
                                  Vpos = Vposk + PyArray_STRIDES(%(V)s)[1];
                                } //close k
                             Hpos = Hpost + PyArray_STRIDES(%(H)s)[3];
                             Vpos = Vpost + vs3 * dt;
                         } //close t
                         Hpos = Hposc + PyArray_STRIDES(%(H)s)[2];
                         Vpos = Vposc + vs2 * dc;
                       } //close c
                       Hpos = Hposr + PyArray_STRIDES(%(H)s)[1];
                       Vpos = Vposr + PyArray_STRIDES(%(V)s)[1] * dr;
                   } //closes r
                   Hpos = Hposi + PyArray_STRIDES(%(H)s)[0];
                   Vpos = Vposi + PyArray_STRIDES(%(V)s)[0];
              } //closes i


            } //closes "lots of channels" special case code
            else
s  
            {
              //General case code
              //std::cout << "general case code" << std::endl;
              long long Hpos = 0;
              long long Vpos = 0;
              for (int i = 0; i < batchSize; i++) {
                    long long Hposi = Hpos;
                    long long Vposi = Vpos;


                    for (int r = 0;  r < outputHeight; r++) {
                      long long Hposr = Hpos;
                      long long Vposr = Vpos;
                      for (int c = 0; c < outputWidth; c++) {
                       long long Hposc = Hpos;
                       long long Vposc = Vpos;
                       for (int t = 0; t < outputDur; t++) {
                            long long Hpost = Hpos;
                            long long Vpost = Vpos;
                            //of the loops so far, j should be the innermost, because
                            //each loop through j visits the same elements of V
                            //this implies that the last index of H should be the j index
                            //since V and H should have the same format, this means
                            //z should be the last index in v, and therefore the innermost
                            //of the next set of for loops

                            int Wpos = 0;
                            int bPos = 0;


                            for (int j = 0; j < outputChannels; j++) {


                                long long Hposj = Hpos;
                                long long Vposj = Vpos;
                                int Wposj = Wpos;

                                // H[i,r,c,t,j] = b[j]

                                dtype_%(H)s & writePos = ELEM_AT(%(H)s,Hpos);


                                writePos = ELEM_AT(%(b)s,bPos);


                                for (int k =0; k < filterHeight; k++) {
                                  int Wposk = Wpos;
                                  long long Vposk = Vpos;
                                  for (int l = 0; l < filterWidth; l++) {
                                    int Wposl = Wpos;
                                    long long Vposl = Vpos;
                                    for (int m = 0; m < filterDur; m++) {
                                      int Wposm = Wpos;
                                      long long Vposm = Vpos;
                                      for (int z = 0; z < inputChannels; z++) {
                                        //H[i,r,c,t,j] += W[j,z,k,l,m] * V[i,dr*r+k, dc*c+l, dt*t+m,z]


                                        writePos += ELEM_AT(%(W)s,Wpos) * ELEM_AT(%(V)s,Vpos);

                                        Wpos += ws4;
                                        Vpos += vs4;
                                      } // close z
                                      Wpos = Wposm + ws3;
                                      Vpos = Vposm + vs3;
                                    } // close m
                                    Wpos = Wposl + ws2;
                                    Vpos = Vposl + vs2;
                                  } //close l
                                  Wpos = Wposk + PyArray_STRIDES(%(W)s)[1];
                                  Vpos = Vposk + PyArray_STRIDES(%(V)s)[1];
                                } //close k


                              bPos += bs;
                              Wpos = Wposj + ws0;
                              Hpos = Hposj +  hs4;
                              Vpos = Vposj;
                              //std::cout << "incremented Wpos by " << ws0 << std::endl;
                              //std::cout << "incremented Hpos by " << hs4 << std::endl;
                             } //close j
                             Hpos = Hpost + PyArray_STRIDES(%(H)s)[3];
                             Vpos = Vpost + vs3 * dt;
                         } //close t
                         Hpos = Hposc + PyArray_STRIDES(%(H)s)[2];
                         Vpos = Vposc + vs2 * dc;
                       } //close c
                       Hpos = Hposr + PyArray_STRIDES(%(H)s)[1];
                       Vpos = Vposr + PyArray_STRIDES(%(V)s)[1] * dr;
                   } //closes r
                   Hpos = Hposi + PyArray_STRIDES(%(H)s)[0];
                   Vpos = Vposi + PyArray_STRIDES(%(V)s)[0];
              } //closes i
            } //closes general case code
}}}}}}} //extra scope so error handler jumps don't cross declarations
            ///////////// < /code generated by Conv3D >
        (   R   R   R   t   configt   blasR   R   t	   Exceptiont   valueR   t   render_stringt   locals(   R   R    t   nodenameR   R   t   subR   R   R   R   Ra   t   Ht
   codeSourcet   VVt   WVt   bvt   dvt   HVt   gemv(    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt   c_code   s$    

$		a
a(    (   t   __name__t
   __module__t   __doc__t	   __props__R   R!   R>   RA   RV   RW   RX   R\   R^   R`   Rv   (    (    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyR
   .   s   			5							c         C` s   t  |  | | |  S(   s  
    3D "convolution" of multiple filters on a minibatch.

    (does not flip the kernel, moves kernel with a user specified stride)

    Parameters
    ----------
    V
        Visible unit, input.
        Dimensions: (batch, row, column, time, in channel).
    W
        Weights, filter.
        Dimensions: (out channel, row, column, time ,in channel).
    b
        Bias, shape == (W.shape[0],).
    d
        Strides when moving the filter over the input(dx, dy, dt).

    Notes
    -----
    The order of dimensions does not correspond to the one in `conv2d`.
    This is for optimization.

    The GPU implementation is very slow. You should use
    :func:`conv3d2d <theano.tensor.nnet.conv3d2d.conv3d>` or
    :func:`conv3d_fft <theano.sandbox.cuda.fftconv.conv3d_fft>` for a
    GPU graph instead.

    See Also
    --------
    Someone made a script that shows how to swap the axes
    between both 3d convolution implementations in Theano. See
    the last `attachment <https://groups.google.com/d/msg/theano-users/1S9_bZgHxVw/0cQR9a4riFUJ>`_

(   t   _conv3D(   R   R   R   R   (    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt   conv3D(  s    $c         C` s  t  | j  d k s t  t  |  j  d k s6 t  t  | j  d k rg t | j  t sg t  n  t  |  d k s t  |  j d } | j d } |  j d } | j d | k r t d t | j d  d t |    n  | j d } | j d } | j d }	 |  j d }
 |  j d } |  j d } |
 | k sGt  | | k sYt  | |	 k skt  | \ } } } | d k st  | d k st  | d k st  t |
 | |  d } t | | |  d } t | |	 |  d } t j	 | | | | | f d	 |  j
 } xt d | j d  D]} xt d | j d  D]g} x^t d | j d  D]F} x=t d | j d  D]%} xt d | j d  D]} | | | | | | | | f <x t d |  D] } x t d |  D] } x t d |	  D] } x t d |  D] } | | | | | | f } |  | | d | | | d | | | d | | | f } | | | | | | f c | | 7<qWq	WqWqWqWqWqpWqSWq6W| S(
   Ni   i   i   i    i   s   W.shape[4] = s    but inputChannels = i   R   (   t   lenR-   t   AssertionErrort   printR   Rh   t   strt   intt   Nt   zerosR   R   (   R   R   R   R   t	   batchSizet   outputChannelst   inputChannelsRM   RO   RQ   RL   RN   RP   t   dxt   dyRI   t   outputHeightt   outputWidtht	   outputDurRn   t   it   jt   xt   yt   tt   kt   lt   mt   zt   wt   v(    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyR?   O  sV    .=G(   t
   __future__R    R   R   t   numpyR   t	   six.movesR   R   t   theano.tensorR   R   t   theano.tensor.blas_headersR   R   t   theano.tensor.blasR   t   theano.miscR   t   theano.gradientR	   t   OpR
   R{   R|   R?   (    (    (    s9   /tmp/pip-build-X4mzal/theano/theano/tensor/nnet/Conv3D.pyt   <module>   s   " 		'