
"""Multi-threaded word2vec mini-batched skip-gram model.

Trains the model described in:
(Mikolov, et al.) Efficient Estimation of Word Representations in Vector Space
ICLR 2013.
http://arxiv.org/abs/1301.3781
This model does traditional minibatching.

The key ops used are:
* placeholder for feeding in tensors for each example.
* embedding_lookup for fetching rows from the embedding matrix.
* sigmoid_cross_entropy_with_logits to calculate the loss.
* GradientDescentOptimizer for optimizing the loss.
* skipgram custom op that does input processing.
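
A typical invocation looks like the following (the file names are
placeholders for data you supply, e.g. the unzipped text8 corpus and the
questions-words.txt analogy set referenced in the flag help below):

  python word2vec.py --train_data=text8 \
      --eval_data=questions-words.txt --save_path=/tmp/word2vec/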
i    (   t   absolute_import(   t   division(   t   print_functionN(   t   xrange(   t   gen_word2vect	   save_paths4   Directory to write the model and training summaries.t
   train_datasL   Training text file. E.g., unzipped file http://mattmahoney.net/dc/text8.zip.t	   eval_datas   File consisting of analogies of four tokens.embedding 2 - embedding 1 + embedding 3 should be close to embedding 4.E.g. https://word2vec.googlecode.com/svn/trunk/questions-words.txt.t   embedding_sizei   s   The embedding dimension size.t   epochs_to_traini   sR   Number of epochs to train. Each epoch processes the training data once completely.t   learning_rateg?s   Initial learning rate.t   num_neg_samplesid   s&   Negative samples per training example.t
   batch_sizei   sE   Number of training examples processed per step (size of a minibatch).t   concurrent_stepsi   s(   The number of concurrent training steps.t   window_sizei   sH   The number of words to predict to the left and right of the target word.t	   min_countsO   The minimum number of word occurrences for it to be included in the vocabulary.t	   subsamplegMbP?s   Subsample threshold for word occurrence. Words that appear with higher frequency will be randomly down-sampled. Set to 0 to disable.t   interactives   If true, enters an IPython interactive session to play with the trained model. E.g., try model.analogy('france', 'paris', 'russia') and model.nearby(['proton', 'elephant', 'maxwell']t   statistics_intervals!   Print statistics every n seconds.t   summary_intervalsQ   Save training summary to file every n seconds (rounded up to statistics interval.t   checkpoint_intervaliX  sc   Checkpoint the model (i.e. save the parameters) every n seconds (rounded up to statistics interval.t   Optionsc           B` s   e  Z d  Z d   Z RS(   s#   Options used by our word2vec model.c         C` s   t  j |  _ t  j |  _ t  j |  _ t  j |  _ t  j |  _ t  j |  _ t  j	 |  _	 t  j
 |  _
 t  j |  _ t  j |  _ t  j |  _ t  j |  _ t  j |  _ t  j |  _ t  j |  _ d  S(   N(   t   FLAGSR   t   emb_dimR   R   t   num_samplesR
   R	   R   R   R   R   R   R   R   R   R   R   (   t   self(    (    sf   /tmp/pip-build-UG86a1/tensorflow/tensorflow-0.6.0.data/purelib/tensorflow/models/embedding/word2vec.pyt   __init__g   s    (   t   __name__t
   __module__t   __doc__R   (    (    (    sf   /tmp/pip-build-UG86a1/tensorflow/tensorflow-0.6.0.data/purelib/tensorflow/models/embedding/word2vec.pyR   d   s   t   Word2Vecc           B` s   e  Z d  Z d   Z d   Z d   Z d   Z d   Z d   Z d   Z	 d   Z
 d	   Z d
   Z d   Z d   Z d   Z d d  Z RS(   s   Word2Vec model (Skipgram).c         C` sP   | |  _  | |  _ i  |  _ g  |  _ |  j   |  j   |  j   |  j   d  S(   N(   t   _optionst   _sessiont   _word2idt   _id2wordt   build_grapht   build_eval_grapht
   save_vocabt   _read_analogies(   R   t   optionst   session(    (    sf   /tmp/pip-build-UG86a1/tensorflow/tensorflow-0.6.0.data/purelib/tensorflow/models/embedding/word2vec.pyR      s    				


  def _read_analogies(self):
    """Reads through the analogy question file.

    Returns:
      questions: a [n, 4] numpy array containing the analogy question's
                 word ids.
      questions_skipped: questions skipped due to unknown words.
    """
    questions = []
    questions_skipped = 0
    with open(self._options.eval_data, "rb") as analogy_f:
      for line in analogy_f:
        if line.startswith(":"):  # Skip comments.
          continue
        words = line.strip().lower().split(" ")
        ids = [self._word2id.get(w.strip()) for w in words]
        if None in ids or len(ids) != 4:
          questions_skipped += 1
        else:
          questions.append(np.array(ids))
    print("Eval analogy file: ", self._options.eval_data)
    print("Questions: ", len(questions))
    print("Skipped: ", questions_skipped)
    self._analogy_questions = np.array(questions, dtype=np.int32)
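  # For reference, the eval file parsed above follows the standard word2vec
  # questions-words.txt layout (illustrative lines, not data shipped with
  # this module): ":"-prefixed lines name a section and are skipped; every
  # other line holds four space-separated tokens, e.g.
  #
  #   : capital-common-countries
  #   Athens Greece Baghdad Iraq
  #
  # A question counts as answered when embedding(Greece) - embedding(Athens)
  # + embedding(Baghdad) is closest to embedding(Iraq).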
 | d t j | j d	 g  } t j j d
 | d d	 d | j d t d | j d d d | j j    \ }	 }
 }
 t j j | |  } t j j | |  } t j j | |  } t j j | |	  } t j j | |	  } t j t j | |  d	  | } t j	 | | j g  } t j | | d t | } | | f S(   s%   Build the graph for the forward pass.g      ?t   namet   embt   sm_w_tt   sm_bi    t   global_stepR,   i   t   true_classest   num_truet   num_sampledt   uniquet	   range_maxt
   distortiong      ?t   unigramst   transpose_b(   R   R   t   tft   Variablet   random_uniformt
   vocab_sizet   _embt   zerosRF   t   reshapet   castt   int64R   t   nnt   fixed_unigram_candidate_samplerR   t   Truet   vocab_countst   tolistt   embedding_lookupt
   reduce_sumt   mult   matmul(   R   t   examplest   labelst   optst
   init_widthRC   RD   RE   t   labels_matrixt   sampled_idst   _t   example_embt   true_wt   true_bt	   sampled_wt	   sampled_bt   true_logitst   sampled_b_vect   sampled_logits(    (    sf   /tmp/pip-build-UG86a1/tensorflow/tensorflow-0.6.0.data/purelib/tensorflow/models/embedding/word2vec.pyt   forward   sF    				$			"	c         C` sl   |  j  } t j j | t j |   } t j j | t j |   } t j |  t j |  | j } | S(   s!   Build the graph for the NCE loss.(   R   RO   RX   t!   sigmoid_cross_entropy_with_logitst	   ones_liket
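  # A sketch of what nce_loss() above computes (intuition only, no extra
  # graph nodes): for a true logit t and sampled logits s_1..s_k, the two
  # cross-entropy terms amount to
  #
  #   -log(sigmoid(t)) - sum_j log(1 - sigmoid(s_j))
  #
  # summed over the batch and divided by batch_size, i.e. true
  # (word, context) pairs are pushed toward label 1 and sampled noise
  # pairs toward label 0.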
  def optimize(self, loss):
    """Build the graph to optimize the loss function."""

    # Linear learning rate decay, floored at a small fraction of the
    # initial rate.
    opts = self._options
    words_to_train = float(opts.words_per_epoch * opts.epochs_to_train)
    lr = opts.learning_rate * tf.maximum(
        0.0001, 1.0 - tf.cast(self._words, tf.float32) / words_to_train)
    self._lr = lr
    optimizer = tf.train.GradientDescentOptimizer(lr)
    train = optimizer.minimize(loss,
                               global_step=self.global_step,
                               gate_gradients=optimizer.GATE_NONE)
    self._train = train

  def build_eval_graph(self):
    """Build the eval graph."""
    # Each analogy task is to predict the 4th word (d) given three words:
    # a, b, c.  E.g., a=italy, b=rome, c=france, we should predict d=paris.
    # The eval feeds three vectors of word ids for a, b, c, each of which
    # is of size N, where N is the number of analogies we want to evaluate
    # in one batch.
    analogy_a = tf.placeholder(dtype=tf.int32)  # [N]
    analogy_b = tf.placeholder(dtype=tf.int32)  # [N]
    analogy_c = tf.placeholder(dtype=tf.int32)  # [N]

    # Normalized word embeddings of shape [vocab_size, emb_dim].
    nemb = tf.nn.l2_normalize(self._emb, 1)

    # Each row of a_emb, b_emb, c_emb is a word's embedding vector.
    # They all have the shape [N, emb_dim].
    a_emb = tf.gather(nemb, analogy_a)  # a's embs
    b_emb = tf.gather(nemb, analogy_b)  # b's embs
    c_emb = tf.gather(nemb, analogy_c)  # c's embs

    # We expect that d's embedding vector on the unit hyper-sphere is
    # near: c_emb + (b_emb - a_emb), which has the shape [N, emb_dim].
    target = c_emb + (b_emb - a_emb)

    # Compute cosine distance between each pair of target and vocab.
    # dist has shape [N, vocab_size].
    dist = tf.matmul(target, nemb, transpose_b=True)

    # For each question (row in dist), find the top 4 words.
    _, pred_idx = tf.nn.top_k(dist, 4)

    # Nodes for computing neighbors for a given word according to
    # their cosine distance.
    nearby_word = tf.placeholder(dtype=tf.int32)  # word id
    nearby_emb = tf.gather(nemb, nearby_word)
    nearby_dist = tf.matmul(nearby_emb, nemb, transpose_b=True)
    nearby_val, nearby_idx = tf.nn.top_k(nearby_dist,
                                         min(1000, self._options.vocab_size))

    # Nodes used by training and evaluation to run/feed/fetch.
    self._analogy_a = analogy_a
    self._analogy_b = analogy_b
    self._analogy_c = analogy_c
    self._analogy_pred_idx = pred_idx
    self._nearby_word = nearby_word
    self._nearby_val = nearby_val
    self._nearby_idx = nearby_idx
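  # Note on the eval graph above: because the rows of nemb are
  # l2-normalized, the matmuls against nemb compute cosine similarity
  # directly, so top_k over those similarities returns the words nearest
  # (in cosine distance) to each analogy target c_emb + (b_emb - a_emb).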
 j | | | g  \ | _ | _ | _ t | j  | _ t d | j  t d | j d d	  t d
 | j  | |  _ | |  _ | j |  _ x* t |  j  D] \ } } | |  j | <qW|  j | |  \ }	 }
 |  j |	 |
  } t j d |  | |  _ |  j |  t j   j   t j j   |  _  d S(   s#   Build the graph for the full model.t   filenameR   R   R   R   s   Data file: s   Vocab size: i   s    + UNKs   Words per epoch: s   NCE lossN(!   R   t   word2vect   skipgramR   R   R   R   R   t   _epochR|   R    t   runt   vocab_wordsR[   Rz   R4   RR   R8   t	   _examplest   _labelsR"   t	   enumerateR!   Rp   Rw   RO   t   scalar_summaryt   _lossR   t   initialize_all_variablesR   t   Savert   saver(   R   Rc   R?   t   countsRz   Ra   Rb   t   iR@   Rm   Ro   R   (    (    sf   /tmp/pip-build-UG86a1/tensorflow/tensorflow-0.6.0.data/purelib/tensorflow/models/embedding/word2vec.pyR#   Z  s.    				'-			c      
  def save_vocab(self):
    """Save the vocabulary to a file so the model can be reloaded."""
    opts = self._options
    with open(os.path.join(opts.save_path, "vocab.txt"), "w") as f:
      for i in xrange(opts.vocab_size):
        f.write("%s %d\n" % (tf.compat.as_text(opts.vocab_words[i]),
                             opts.vocab_counts[i]))

  def _train_thread_body(self):
    initial_epoch, = self._session.run([self._epoch])
    while True:
      _, epoch = self._session.run([self._train, self._epoch])
      if epoch != initial_epoch:
        break
 } g  } xC t | j  D]2 } t j d |  j  } | j   | j |  qp W| t j   d }	 }
 } d } xKt rt j | j  |  j j |  j |  j |  j |  j |  j g  \ } } } } } t j   } | | | |	 | |
 }	 }
 } t d | | | | | f d d t j j   | | | j k r|  j j |  } | j | |  | } n  | | | j k r |  j  j! |  j | j	 d d | j" t#  | } n  | | k r Pq q Wx | D] } | j$   qW| S(	   s   Train the model.t	   graph_defR   i    s>   Epoch %4d Step %8d: lr = %5.3f loss = %6.2f words/sec = %8.0ft   endt    t   modelRF   (%   R   R    R   R   R|   RO   t   merge_all_summariesR   t   SummaryWriterR   R   R   R   t	   threadingt   ThreadR   t   startR5   t   timeRZ   t   sleepR   RF   R   R~   R8   t   syst   stdoutt   flushR   t   add_summaryR   R   t   savet   astypet   intR   (   R   Rc   R   t   initial_wordst
   summary_opt   summary_writert   workersRg   t   tt
   last_wordst	   last_timet   last_summary_timet   last_checkpoint_timeR   t   stepR   R?   R   t   nowt   ratet   summary_str(    (    sf   /tmp/pip-build-UG86a1/tensorflow/tensorflow-0.6.0.data/purelib/tensorflow/models/embedding/word2vec.pyR     sH    	$
		6		c         C` sp   |  j  j |  j g i | d d  d f |  j 6| d d  d f |  j 6| d d  d f |  j 6 \ } | S(   s0   Predict the top 4 answers for analogy questions.Ni    i   i   (   R    R   R   R   R   R   (   R   t   analogyt   idx(    (    sf   /tmp/pip-build-UG86a1/tensorflow/tensorflow-0.6.0.data/purelib/tensorflow/models/embedding/word2vec.pyt   _predict  s
    #c   	      C` s'  d } |  j  j d } d } x | | k  r | d } |  j  | |  d d  f } |  j |  } | } x t | j d  D]x } xo t d  D]a } | | | f | | d f k r | d 7} Pq | | | f | | d d  f k r q q Pq Wq} Wq Wt   t d | | | d | f  d S(	   s0   Evaluate analogy questions and reports accuracy.i    i	  Ni   i   i   s   Eval %4d/%d accuracy = %4.1f%%g      Y@(   R:   t   shapeR   R   R8   (	   R   t   correctt   totalR   t   limitt   subR   t   questiont   j(    (    sf   /tmp/pip-build-UG86a1/tensorflow/tensorflow-0.6.0.data/purelib/tensorflow/models/embedding/word2vec.pyt   eval  s&    
 
&c   	      C` s   t  j g  | | | g D] } |  j j | d  ^ q g  } |  j |  } xQ g  | d d d  f D] } |  j | ^ qf D] } | | | | g k r} | Sq} Wd S(   s%   Predict word w3 as in w0:w1 vs w2:w3.i    Nt   unknown(   R6   R7   R!   R2   R   R"   (	   R   t   w0t   w1t   w2R@   t   widR   R   t   c(    (    sf   /tmp/pip-build-UG86a1/tensorflow/tensorflow-0.6.0.data/purelib/tensorflow/models/embedding/word2vec.pyR     s    =7i   c   
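  # Example queries against a trained model, mirroring the suggestions in
  # the --interactive flag's help text (the answer shown is indicative,
  # not guaranteed):
  #
  #   model.analogy('france', 'paris', 'russia')       # -> e.g. 'moscow'
  #   model.nearby(['proton', 'elephant', 'maxwell'])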
  def nearby(self, words, num=20):
    """Prints out nearby words given a list of words."""
    ids = np.array([self._word2id.get(x, 0) for x in words])
    vals, idx = self._session.run(
        [self._nearby_val, self._nearby_idx], {self._nearby_word: ids})
    for i in xrange(len(words)):
      print("\n%s\n=====================================" % (words[i]))
      for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]):
        print("%-20s %6.4f" % (self._id2word[neighbor], distance))


def _start_shell(local_ns=None):
  # An interactive shell is useful for debugging/development.
  import IPython
  user_ns = {}
  if local_ns:
    user_ns.update(local_ns)
  user_ns.update(globals())
  IPython.start_ipython(argv=[], user_ns=user_ns)
def main(_):
  """Train a word2vec model."""
  if not FLAGS.train_data or not FLAGS.eval_data or not FLAGS.save_path:
    print("--train_data --eval_data and --save_path must be specified.")
    sys.exit(1)
  opts = Options()
  with tf.Graph().as_default(), tf.Session() as session:
    model = Word2Vec(opts, session)
    for _ in xrange(opts.epochs_to_train):
      model.train()  # Process one epoch
      model.eval()  # Eval analogies.
    # Perform a final save.
    model.saver.save(session, os.path.join(opts.save_path, "model.ckpt"),
                     global_step=model.global_step)
    if FLAGS.interactive:
      # E.g.,
      # [0]: model.analogy('france', 'paris', 'russia')
      # [1]: model.nearby(['proton', 'elephant', 'maxwell'])
      _start_shell(locals())


if __name__ == "__main__":
  tf.app.run()