-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Implemented write_spmatrix_to_sparse_tensor #28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
da9a2b5
a7aed45
572b0f2
00da404
a97f7a6
24e16b2
41e0ad4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
import sys | ||
|
||
import numpy as np | ||
from scipy.sparse import issparse | ||
|
||
from sagemaker.amazon.record_pb2 import Record | ||
|
||
|
@@ -64,6 +65,24 @@ def _write_label_tensor(resolved_type, record, scalar): | |
record.label["values"].float32_tensor.values.extend([scalar]) | ||
|
||
|
||
def _write_keys_tensor(resolved_type, record, vector): | ||
if resolved_type == "Int32": | ||
record.features["values"].int32_tensor.keys.extend(vector) | ||
elif resolved_type == "Float64": | ||
record.features["values"].float64_tensor.keys.extend(vector) | ||
elif resolved_type == "Float32": | ||
record.features["values"].float32_tensor.keys.extend(vector) | ||
|
||
|
||
def _write_shape(resolved_type, record, scalar): | ||
if resolved_type == "Int32": | ||
record.features["values"].int32_tensor.shape.extend([scalar]) | ||
elif resolved_type == "Float64": | ||
record.features["values"].float64_tensor.shape.extend([scalar]) | ||
elif resolved_type == "Float32": | ||
record.features["values"].float32_tensor.shape.extend([scalar]) | ||
|
||
|
||
def write_numpy_to_dense_tensor(file, array, labels=None): | ||
"""Writes a numpy array to a dense tensor""" | ||
|
||
|
@@ -89,6 +108,46 @@ def write_numpy_to_dense_tensor(file, array, labels=None): | |
_write_recordio(file, record.SerializeToString()) | ||
|
||
|
||
def write_numpy_to_sparse_tensor(file, array, labels=None): | ||
"""Writes a numpy array to a dense tensor""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. spmatrix not numpy array |
||
|
||
if not issparse(array): | ||
raise TypeError("Array must be sparse") | ||
|
||
# Validate shape of array and labels, resolve array and label types | ||
if not len(array.shape) == 2: | ||
raise ValueError("Array must be a Matrix") | ||
if labels is not None: | ||
if not len(labels.shape) == 1: | ||
raise ValueError("Labels must be a Vector") | ||
if labels.shape[0] not in array.shape: | ||
raise ValueError("Label shape {} not compatible with array shape {}".format( | ||
labels.shape, array.shape)) | ||
resolved_label_type = _resolve_type(labels.dtype) | ||
resolved_type = _resolve_type(array.dtype) | ||
|
||
csr_array = array.tocsr() | ||
n_rows, n_cols = csr_array.shape | ||
|
||
record = Record() | ||
for row_idx in range(n_rows): | ||
record.Clear() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason this isn't done at the end of the for-loop? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Full analogy to the https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/amazon/common.py#L84 |
||
row = csr_array.getrow(row_idx) | ||
# Write values | ||
_write_feature_tensor(resolved_type, record, row.data) | ||
# Write keys | ||
_write_keys_tensor(resolved_type, record, row.indices.astype(np.uint64)) | ||
|
||
# Write labels | ||
if labels is not None: | ||
_write_label_tensor(resolved_label_type, record, labels[row_idx]) | ||
|
||
# Write shape | ||
_write_shape(resolved_type, record, n_cols) | ||
|
||
_write_recordio(file, record.SerializeToString()) | ||
|
||
|
||
def read_records(file): | ||
"""Eagerly read a collection of amazon Record protobuf objects from file.""" | ||
records = [] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be write_spmatrix_to_sparse_tensor ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, fixed.