Skip to content

[WIP] restore .py file extensions to the tensorflow files #4331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tensorflow as tf
from random import shuffle

import tensorflow as tf
from numpy import array


Expand Down
3 changes: 2 additions & 1 deletion neural_network/gan.py_tf → neural_network/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import matplotlib.pyplot as plt
import numpy as np
from sklearn.utils import shuffle
import input_data

from neural_network import input_data

random_numer = 42

Expand Down
96 changes: 44 additions & 52 deletions neural_network/input_data.py_tf → neural_network/input_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@
import os

import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from six.moves import urllib
from tensorflow.python.framework import dtypes, random_seed
from tensorflow.python.platform import gfile
from tensorflow.python.util.deprecation import deprecated

Expand All @@ -46,16 +44,16 @@ def _read32(bytestream):
def _extract_images(f):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Single letter variable names make the code look old school. Would it be possible upgrade to more self-documenting variable names and add type hints?

"""Extract the images into a 4D uint8 numpy array [index, y, x, depth].

Args:
f: A file object that can be passed into a gzip reader.
Args:
f: A file object that can be passed into a gzip reader.

Returns:
data: A 4D uint8 numpy array [index, y, x, depth].
Returns:
data: A 4D uint8 numpy array [index, y, x, depth].

Raises:
ValueError: If the bytestream does not start with 2051.
Raises:
ValueError: If the bytestream does not start with 2051.

"""
"""
print("Extracting", f.name)
with gzip.GzipFile(fileobj=f) as bytestream:
magic = _read32(bytestream)
Expand Down Expand Up @@ -86,17 +84,17 @@ def _dense_to_one_hot(labels_dense, num_classes):
def _extract_labels(f, one_hot=False, num_classes=10):
"""Extract the labels into a 1D uint8 numpy array [index].

Args:
f: A file object that can be passed into a gzip reader.
one_hot: Does one hot encoding for the result.
num_classes: Number of classes for the one hot encoding.
Args:
f: A file object that can be passed into a gzip reader.
one_hot: Does one hot encoding for the result.
num_classes: Number of classes for the one hot encoding.

Returns:
labels: a 1D uint8 numpy array.
Returns:
labels: a 1D uint8 numpy array.

Raises:
ValueError: If the bytestream doesn't start with 2049.
"""
Raises:
ValueError: If the bytestream doesn't start with 2049.
"""
print("Extracting", f.name)
with gzip.GzipFile(fileobj=f) as bytestream:
magic = _read32(bytestream)
Expand All @@ -115,8 +113,8 @@ def _extract_labels(f, one_hot=False, num_classes=10):
class _DataSet:
"""Container class for a _DataSet (deprecated).

THIS CLASS IS DEPRECATED.
"""
THIS CLASS IS DEPRECATED.
"""

@deprecated(
None,
Expand All @@ -135,21 +133,21 @@ def __init__(
):
"""Construct a _DataSet.

one_hot arg is used only if fake_data is true. `dtype` can be either
`uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
`[0, 1]`. Seed arg provides for convenient deterministic testing.

Args:
images: The images
labels: The labels
fake_data: Ignore images and labels, use fake data.
one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
False).
dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
range [0,255]. float32 output has range [0,1].
reshape: Bool. If True returned images are returned flattened to vectors.
seed: The random seed to use.
"""
one_hot arg is used only if fake_data is true. `dtype` can be either
`uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
`[0, 1]`. Seed arg provides for convenient deterministic testing.

Args:
images: The images
labels: The labels
fake_data: Ignore images and labels, use fake data.
one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
False).
dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
range [0,255]. float32 output has range [0,1].
reshape: Bool. If True returned images are returned flattened to vectors.
seed: The random seed to use.
"""
seed1, seed2 = random_seed.get_seed(seed)
# If op level seed is not set, use whatever graph level seed is returned
numpy.random.seed(seed1 if seed is None else seed2)
Expand Down Expand Up @@ -250,14 +248,14 @@ def next_batch(self, batch_size, fake_data=False, shuffle=True):
def _maybe_download(filename, work_directory, source_url):
"""Download the data from source url, unless it's already here.

Args:
filename: string, name of the file in the directory.
work_directory: string, path to working directory.
source_url: url to download from if file doesn't exist.
Args:
filename: string, name of the file in the directory.
work_directory: string, path to working directory.
source_url: url to download from if file doesn't exist.

Returns:
Path to resulting file.
"""
Returns:
Path to resulting file.
"""
if not gfile.Exists(work_directory):
gfile.MakeDirs(work_directory)
filepath = os.path.join(work_directory, filename)
Expand Down Expand Up @@ -293,10 +291,8 @@ def fake():
validation = fake()
test = fake()
return _Datasets(train=train, validation=validation, test=test)

if not source_url: # empty string check
source_url = DEFAULT_SOURCE_URL

train_images_file = "train-images-idx3-ubyte.gz"
train_labels_file = "train-labels-idx1-ubyte.gz"
test_images_file = "t10k-images-idx3-ubyte.gz"
Expand All @@ -307,30 +303,26 @@ def fake():
)
with gfile.Open(local_file, "rb") as f:
train_images = _extract_images(f)

local_file = _maybe_download(
train_labels_file, train_dir, source_url + train_labels_file
)
with gfile.Open(local_file, "rb") as f:
train_labels = _extract_labels(f, one_hot=one_hot)

local_file = _maybe_download(
test_images_file, train_dir, source_url + test_images_file
)
with gfile.Open(local_file, "rb") as f:
test_images = _extract_images(f)

local_file = _maybe_download(
test_labels_file, train_dir, source_url + test_labels_file
)
with gfile.Open(local_file, "rb") as f:
test_labels = _extract_labels(f, one_hot=one_hot)

if not 0 <= validation_size <= len(train_images):
raise ValueError(
f"Validation size should be between 0 and {len(train_images)}. Received: {validation_size}."
f"Validation size should be between 0 "
f"and {len(train_images)}. Received: {validation_size}."
Comment on lines +323 to +324
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"Validation size should be between 0 "
f"and {len(train_images)}. Received: {validation_size}."
"Validation size should be between 0 and "
f"{len(train_images)}. Received: {validation_size}."

)

validation_images = train_images[:validation_size]
validation_labels = train_labels[:validation_size]
train_images = train_images[validation_size:]
Expand Down