Skip to content

Commit 3469a47

Browse files
lucyleeowadrinjalali
authored andcommitted
DOC Update plot_column_transformer to notebook style (scikit-learn#17028)
1 parent b5e8495 commit 3469a47

File tree

1 file changed

+100
-71
lines changed

1 file changed

+100
-71
lines changed

examples/compose/plot_column_transformer.py

+100-71
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,18 @@
33
Column Transformer with Heterogeneous Data Sources
44
==================================================
55
6-
Datasets can often contain components of that require different feature
7-
extraction and processing pipelines. This scenario might occur when:
6+
Datasets can often contain components that require different feature
7+
extraction and processing pipelines. This scenario might occur when:
88
9-
1. Your dataset consists of heterogeneous data types (e.g. raster images and
10-
text captions)
11-
2. Your dataset is stored in a Pandas DataFrame and different columns
9+
1. your dataset consists of heterogeneous data types (e.g. raster images and
10+
text captions),
11+
2. your dataset is stored in a :class:`pandas.DataFrame` and different columns
1212
require different processing pipelines.
1313
1414
This example demonstrates how to use
15-
:class:`sklearn.compose.ColumnTransformer` on a dataset containing
16-
different types of features. We use the 20-newsgroups dataset and compute
17-
standard bag-of-words features for the subject line and body in separate
18-
pipelines as well as ad hoc features on the body. We combine them (with
19-
weights) using a ColumnTransformer and finally train a classifier on the
20-
combined set of features.
21-
22-
The choice of features is not particularly helpful, but serves to illustrate
23-
the technique.
15+
:class:`~sklearn.compose.ColumnTransformer` on a dataset containing
16+
different types of features. The choice of features is not particularly
17+
helpful, but serves to illustrate the technique.
2418
"""
2519

2620
# Author: Matt Terry <[email protected]>
@@ -29,7 +23,7 @@
2923

3024
import numpy as np
3125

32-
from sklearn.base import BaseEstimator, TransformerMixin
26+
from sklearn.preprocessing import FunctionTransformer
3327
from sklearn.datasets import fetch_20newsgroups
3428
from sklearn.decomposition import TruncatedSVD
3529
from sklearn.feature_extraction import DictVectorizer
@@ -39,95 +33,130 @@
3933
from sklearn.compose import ColumnTransformer
4034
from sklearn.svm import LinearSVC
4135

36+
##############################################################################
37+
# 20 newsgroups dataset
38+
# ---------------------
39+
#
40+
# We will use the :ref:`20 newsgroups dataset <20newsgroups_dataset>`, which
41+
# comprises posts from newsgroups on 20 topics. This dataset is split
42+
# into train and test subsets based on messages posted before and after
43+
# a specific date. We will only use posts from 2 categories to speed up running
44+
# time.
4245

43-
class TextStats(TransformerMixin, BaseEstimator):
44-
"""Extract features from each document for DictVectorizer"""
46+
categories = ['sci.med', 'sci.space']
47+
X_train, y_train = fetch_20newsgroups(random_state=1,
48+
subset='train',
49+
categories=categories,
50+
remove=('footers', 'quotes'),
51+
return_X_y=True)
52+
X_test, y_test = fetch_20newsgroups(random_state=1,
53+
subset='test',
54+
categories=categories,
55+
remove=('footers', 'quotes'),
56+
return_X_y=True)
4557

46-
def fit(self, x, y=None):
47-
return self
58+
##############################################################################
59+
# Each feature comprises meta information about that post, such as the subject,
60+
# and the body of the news post.
4861

49-
def transform(self, posts):
50-
return [{'length': len(text),
51-
'num_sentences': text.count('.')}
52-
for text in posts]
62+
print(X_train[0])
5363

64+
##############################################################################
65+
# Creating transformers
66+
# ---------------------
67+
#
68+
# First, we would like a transformer that extracts the subject and
69+
# body of each post. Since this is a stateless transformation (does not
70+
# require state information from training data), we can define a function that
71+
# performs the data transformation then use
72+
# :class:`~sklearn.preprocessing.FunctionTransformer` to create a scikit-learn
73+
# transformer.
5474

55-
class SubjectBodyExtractor(TransformerMixin, BaseEstimator):
56-
"""Extract the subject & body from a usenet post in a single pass.
5775

58-
Takes a sequence of strings and produces a dict of sequences. Keys are
59-
`subject` and `body`.
60-
"""
61-
def fit(self, x, y=None):
62-
return self
76+
def subject_body_extractor(posts):
77+
# construct object dtype array with two columns
78+
# first column = 'subject' and second column = 'body'
79+
features = np.empty(shape=(len(posts), 2), dtype=object)
80+
for i, text in enumerate(posts):
81+
# temporary variable `_` stores '\n\n'
82+
headers, _, body = text.partition('\n\n')
83+
# store body text in second column
84+
features[i, 1] = body
6385

64-
def transform(self, posts):
65-
# construct object dtype array with two columns
66-
# first column = 'subject' and second column = 'body'
67-
features = np.empty(shape=(len(posts), 2), dtype=object)
68-
for i, text in enumerate(posts):
69-
headers, _, bod = text.partition('\n\n')
70-
features[i, 1] = bod
86+
prefix = 'Subject:'
87+
sub = ''
88+
# save text after 'Subject:' in first column
89+
for line in headers.split('\n'):
90+
if line.startswith(prefix):
91+
sub = line[len(prefix):]
92+
break
93+
features[i, 0] = sub
7194

72-
prefix = 'Subject:'
73-
sub = ''
74-
for line in headers.split('\n'):
75-
if line.startswith(prefix):
76-
sub = line[len(prefix):]
77-
break
78-
features[i, 0] = sub
95+
return features
7996

80-
return features
8197

98+
subject_body_transformer = FunctionTransformer(subject_body_extractor)
8299

83-
pipeline = Pipeline([
84-
# Extract the subject & body
85-
('subjectbody', SubjectBodyExtractor()),
100+
##############################################################################
101+
# We will also create a transformer that extracts the
102+
# length of the text and the number of sentences.
103+
104+
105+
def text_stats(posts):
106+
return [{'length': len(text),
107+
'num_sentences': text.count('.')}
108+
for text in posts]
109+
110+
111+
text_stats_transformer = FunctionTransformer(text_stats)
86112

87-
# Use ColumnTransformer to combine the features from subject and body
113+
##############################################################################
114+
# Classification pipeline
115+
# -----------------------
116+
#
117+
# The pipeline below extracts the subject and body from each post using
118+
# ``SubjectBodyExtractor``, producing a (n_samples, 2) array. This array is
119+
# then used to compute standard bag-of-words features for the subject and body
120+
# as well as text length and number of sentences on the body, using
121+
# ``ColumnTransformer``. We combine them, with weights, then train a
122+
# classifier on the combined set of features.
123+
124+
pipeline = Pipeline([
125+
# Extract subject & body
126+
('subjectbody', subject_body_transformer),
127+
# Use ColumnTransformer to combine the subject and body features
88128
('union', ColumnTransformer(
89129
[
90-
# Pulling features from the post's subject line (first column)
130+
# bag-of-words for subject (col 0)
91131
('subject', TfidfVectorizer(min_df=50), 0),
92-
93-
# Pipeline for standard bag-of-words model for body (second column)
132+
# bag-of-words with decomposition for body (col 1)
94133
('body_bow', Pipeline([
95134
('tfidf', TfidfVectorizer()),
96135
('best', TruncatedSVD(n_components=50)),
97136
]), 1),
98-
99-
# Pipeline for pulling ad hoc features from post's body
137+
# Pipeline for pulling text stats from post's body
100138
('body_stats', Pipeline([
101-
('stats', TextStats()), # returns a list of dicts
139+
('stats', text_stats_transformer), # returns a list of dicts
102140
('vect', DictVectorizer()), # list of dicts -> feature matrix
103141
]), 1),
104142
],
105-
106-
# weight components in ColumnTransformer
143+
# weight above ColumnTransformer features
107144
transformer_weights={
108145
'subject': 0.8,
109146
'body_bow': 0.5,
110147
'body_stats': 1.0,
111148
}
112149
)),
113-
114150
# Use a SVC classifier on the combined features
115151
('svc', LinearSVC(dual=False)),
116152
], verbose=True)
117153

118-
# limit the list of categories to make running this example faster.
119-
categories = ['alt.atheism', 'talk.religion.misc']
120-
X_train, y_train = fetch_20newsgroups(random_state=1,
121-
subset='train',
122-
categories=categories,
123-
remove=('footers', 'quotes'),
124-
return_X_y=True)
125-
X_test, y_test = fetch_20newsgroups(random_state=1,
126-
subset='test',
127-
categories=categories,
128-
remove=('footers', 'quotes'),
129-
return_X_y=True)
154+
##############################################################################
155+
# Finally, we fit our pipeline on the training data and use it to predict
156+
# topics for ``X_test``. Performance metrics of our pipeline are then printed.
130157

131158
pipeline.fit(X_train, y_train)
132159
y_pred = pipeline.predict(X_test)
133-
print(classification_report(y_test, y_pred))
160+
print('Classification report:\n\n{}'.format(
161+
classification_report(y_test, y_pred))
162+
)

0 commit comments

Comments
 (0)