diff --git a/dynamic_programming/k_means_clustering_tensorflow.py_tf b/dynamic_programming/k_means_clustering_tensorflow.py
similarity index 99%
rename from dynamic_programming/k_means_clustering_tensorflow.py_tf
rename to dynamic_programming/k_means_clustering_tensorflow.py
index 4fbcedeaa0dc..b19ffb64c5e9 100644
--- a/dynamic_programming/k_means_clustering_tensorflow.py_tf
+++ b/dynamic_programming/k_means_clustering_tensorflow.py
@@ -1,5 +1,6 @@
-import tensorflow as tf
 from random import shuffle
+
+import tensorflow as tf
 from numpy import array
 
 
diff --git a/machine_learning/lstm/lstm_prediction.py_tf b/machine_learning/lstm/lstm_prediction.py
similarity index 100%
rename from machine_learning/lstm/lstm_prediction.py_tf
rename to machine_learning/lstm/lstm_prediction.py
diff --git a/neural_network/gan.py_tf b/neural_network/gan.py
similarity index 99%
rename from neural_network/gan.py_tf
rename to neural_network/gan.py
index deb062c48dc7..6eeb50975ad7 100644
--- a/neural_network/gan.py_tf
+++ b/neural_network/gan.py
@@ -2,7 +2,8 @@
 import matplotlib.pyplot as plt
 import numpy as np
 from sklearn.utils import shuffle
-import input_data
+
+from neural_network import input_data
 
 random_numer = 42
 
diff --git a/neural_network/input_data.py_tf b/neural_network/input_data.py
similarity index 85%
rename from neural_network/input_data.py_tf
rename to neural_network/input_data.py
index 0e22ac0bcda5..4937508d389d 100644
--- a/neural_network/input_data.py_tf
+++ b/neural_network/input_data.py
@@ -23,11 +23,9 @@
 import os
 
 import numpy
-from six.moves import urllib
 from six.moves import xrange  # pylint: disable=redefined-builtin
-
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import random_seed
+from six.moves import urllib
+from tensorflow.python.framework import dtypes, random_seed
 from tensorflow.python.platform import gfile
 from tensorflow.python.util.deprecation import deprecated
 
@@ -46,16 +44,16 @@ def _read32(bytestream):
 def _extract_images(f):
     """Extract the images into a 4D uint8 numpy array [index, y, x, depth].
 
-  Args:
-    f: A file object that can be passed into a gzip reader.
+    Args:
+      f: A file object that can be passed into a gzip reader.
 
-  Returns:
-    data: A 4D uint8 numpy array [index, y, x, depth].
+    Returns:
+      data: A 4D uint8 numpy array [index, y, x, depth].
 
-  Raises:
-    ValueError: If the bytestream does not start with 2051.
+    Raises:
+      ValueError: If the bytestream does not start with 2051.
 
-  """
+    """
     print("Extracting", f.name)
     with gzip.GzipFile(fileobj=f) as bytestream:
         magic = _read32(bytestream)
@@ -86,17 +84,17 @@ def _dense_to_one_hot(labels_dense, num_classes):
 def _extract_labels(f, one_hot=False, num_classes=10):
     """Extract the labels into a 1D uint8 numpy array [index].
 
-  Args:
-    f: A file object that can be passed into a gzip reader.
-    one_hot: Does one hot encoding for the result.
-    num_classes: Number of classes for the one hot encoding.
+    Args:
+      f: A file object that can be passed into a gzip reader.
+      one_hot: Does one hot encoding for the result.
+      num_classes: Number of classes for the one hot encoding.
 
-  Returns:
-    labels: a 1D uint8 numpy array.
+    Returns:
+      labels: a 1D uint8 numpy array.
 
-  Raises:
-    ValueError: If the bystream doesn't start with 2049.
-  """
+    Raises:
+      ValueError: If the bystream doesn't start with 2049.
+    """
     print("Extracting", f.name)
     with gzip.GzipFile(fileobj=f) as bytestream:
         magic = _read32(bytestream)
@@ -115,8 +113,8 @@ def _extract_labels(f, one_hot=False, num_classes=10):
 class _DataSet:
     """Container class for a _DataSet (deprecated).
 
-  THIS CLASS IS DEPRECATED.
-  """
+    THIS CLASS IS DEPRECATED.
+    """
 
     @deprecated(
         None,
@@ -135,21 +133,21 @@ def __init__(
     ):
         """Construct a _DataSet.
 
-    one_hot arg is used only if fake_data is true. `dtype` can be either
-    `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
-    `[0, 1]`. Seed arg provides for convenient deterministic testing.
-
-    Args:
-      images: The images
-      labels: The labels
-      fake_data: Ignore inages and labels, use fake data.
-      one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
-        False).
-      dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
-        range [0,255]. float32 output has range [0,1].
-      reshape: Bool. If True returned images are returned flattened to vectors.
-      seed: The random seed to use.
-    """
+        one_hot arg is used only if fake_data is true. `dtype` can be either
+        `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
+        `[0, 1]`. Seed arg provides for convenient deterministic testing.
+
+        Args:
+          images: The images
+          labels: The labels
+          fake_data: Ignore inages and labels, use fake data.
+          one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
+            False).
+          dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
+            range [0,255]. float32 output has range [0,1].
+          reshape: Bool. If True returned images are returned flattened to vectors.
+          seed: The random seed to use.
+        """
         seed1, seed2 = random_seed.get_seed(seed)
         # If op level seed is not set, use whatever graph level seed is returned
         numpy.random.seed(seed1 if seed is None else seed2)
@@ -250,14 +248,14 @@ def next_batch(self, batch_size, fake_data=False, shuffle=True):
 def _maybe_download(filename, work_directory, source_url):
     """Download the data from source url, unless it's already here.
 
-  Args:
-    filename: string, name of the file in the directory.
-    work_directory: string, path to working directory.
-    source_url: url to download from if file doesn't exist.
+    Args:
+      filename: string, name of the file in the directory.
+      work_directory: string, path to working directory.
+      source_url: url to download from if file doesn't exist.
 
-  Returns:
-    Path to resulting file.
-  """
+    Returns:
+      Path to resulting file.
+    """
     if not gfile.Exists(work_directory):
         gfile.MakeDirs(work_directory)
     filepath = os.path.join(work_directory, filename)
@@ -293,10 +291,8 @@ def fake():
         validation = fake()
         test = fake()
         return _Datasets(train=train, validation=validation, test=test)
-
     if not source_url:  # empty string check
         source_url = DEFAULT_SOURCE_URL
-
     train_images_file = "train-images-idx3-ubyte.gz"
     train_labels_file = "train-labels-idx1-ubyte.gz"
     test_images_file = "t10k-images-idx3-ubyte.gz"
@@ -307,30 +303,26 @@ def fake():
     )
     with gfile.Open(local_file, "rb") as f:
         train_images = _extract_images(f)
-
     local_file = _maybe_download(
         train_labels_file, train_dir, source_url + train_labels_file
     )
     with gfile.Open(local_file, "rb") as f:
         train_labels = _extract_labels(f, one_hot=one_hot)
-
     local_file = _maybe_download(
         test_images_file, train_dir, source_url + test_images_file
     )
     with gfile.Open(local_file, "rb") as f:
         test_images = _extract_images(f)
-
     local_file = _maybe_download(
         test_labels_file, train_dir, source_url + test_labels_file
     )
     with gfile.Open(local_file, "rb") as f:
         test_labels = _extract_labels(f, one_hot=one_hot)
-
     if not 0 <= validation_size <= len(train_images):
         raise ValueError(
-            f"Validation size should be between 0 and {len(train_images)}. Received: {validation_size}."
+            f"Validation size should be between 0 "
+            f"and {len(train_images)}. Received: {validation_size}."
         )
-
     validation_images = train_images[:validation_size]
     validation_labels = train_labels[:validation_size]
     train_images = train_images[validation_size:]