|
563 | 563 | " 'num_topics': str(num_topics),\n",
|
564 | 564 | " 'feature_dim': str(vocabulary_size),\n",
|
565 | 565 | " 'mini_batch_size': str(num_documents_training),\n",
|
566 |
| - " 'alpha0': str(0.1),\n", |
| 566 | + " 'alpha0': str(1.0),\n", |
567 | 567 | " },\n",
|
568 | 568 | " 'InputDataConfig': [\n",
|
569 | 569 | " {\n",
|
|
1048 | 1048 | "cell_type": "markdown",
|
1049 | 1049 | "metadata": {},
|
1050 | 1050 | "source": [
|
1051 |
| - "Let's also compute and plot the distribution of L1-errors from **all** of the test documents" |
| 1051 | + "In the eyeball-norm these look quite comparable.\n", |
| 1052 | + "\n", |
| 1053 | + "Let's be more scientific about this. Below we compute and plot the distribution of L1-errors from **all** of the test documents. Note that we send a new payload of test documents to the inference endpoint and apply the appropriate permutation to the output." |
1052 | 1054 | ]
|
1053 | 1055 | },
|
1054 | 1056 | {
|
|
1061 | 1063 | "\n",
|
1062 | 1064 | "# create a payload containing all of the test documents and run inference again\n",
|
1063 | 1065 | "#\n",
|
1064 |
| - "# try switching between the test data set and a subset of the training data set\n", |
| 1066 | + "# TRY THIS:\n", |
| 1067 | + "# try switching between the test data set and a subset of the training\n", |
| 1068 | + "# data set. It is likely that LDA inference will perform better against\n", |
| 1069 | + "# the training set than the holdout test set.\n", |
1065 | 1070 | "#\n",
|
1066 |
| - "payload_documents = documents_test # Example 1\n", |
1067 |
| - "#payload_documents = documents_training[:600] # Example 2\n", |
| 1071 | + "payload_documents = documents_test # Example 1\n", |
| 1072 | + "known_topic_mixtures = topic_mixtures_test # Example 1\n", |
| 1073 | + "#payload_documents = documents_training[:600]; # Example 2\n", |
| 1074 | + "#known_topic_mixtures = topic_mixtures_training[:600] # Example 2\n", |
1068 | 1075 | "\n",
|
1069 | 1076 | "print('Invoking endpoint...\\n')\n",
|
1070 | 1077 | "payload = np2csv(documents_test)\n",
|
|
1078 | 1085 | "inferred_topic_mixtures_permuted = np.array([prediction['topic_mixture'] for prediction in results['predictions']])\n",
|
1079 | 1086 | "inferred_topic_mixtures = inferred_topic_mixtures_permuted[:,permutation]\n",
|
1080 | 1087 | "\n",
|
1081 |
| - "print('topics_mixtures_test.shape = {}'.format(topic_mixtures_test.shape))\n", |
| 1088 | + "print('known_topics_mixtures.shape = {}'.format(known_topic_mixtures.shape))\n", |
1082 | 1089 | "print('inferred_topics_mixtures_test.shape = {}\\n'.format(inferred_topic_mixtures.shape))"
|
1083 | 1090 | ]
|
1084 | 1091 | },
|
|
1088 | 1095 | "metadata": {},
|
1089 | 1096 | "outputs": [],
|
1090 | 1097 | "source": [
|
1091 |
| - "l1_errors = np.linalg.norm((inferred_topic_mixtures - topic_mixtures_test), 1, axis=1)\n", |
| 1098 | + "l1_errors = np.linalg.norm((inferred_topic_mixtures - known_topic_mixtures), 1, axis=1)\n", |
1092 | 1099 | "\n",
|
1093 | 1100 | "# plot the error freqency\n",
|
1094 | 1101 | "fig, ax_frequency = plt.subplots()\n",
|
|
1122 | 1129 | "fig.set_dpi(110)"
|
1123 | 1130 | ]
|
1124 | 1131 | },
|
| 1132 | + { |
| 1133 | + "cell_type": "markdown", |
| 1134 | + "metadata": {}, |
| 1135 | + "source": [ |
| 1136 | + "Machine learning algorithms are not perfect and the data above suggests this is true of SageMaker LDA. With more documents and some hyperparameter tuning we can obtain more accurate results against the known topic-mixtures.\n", |
| 1137 | + "\n", |
| 1138 | + "For now, let's just investigate the documents-topic mixture pairs that seem to do well as well as those that do not. Below we retreive a document and topic mixture corresponding to a small L1-error as well as one with a large L1-error." |
| 1139 | + ] |
| 1140 | + }, |
| 1141 | + { |
| 1142 | + "cell_type": "code", |
| 1143 | + "execution_count": null, |
| 1144 | + "metadata": {}, |
| 1145 | + "outputs": [], |
| 1146 | + "source": [ |
| 1147 | + "N = 6\n", |
| 1148 | + "\n", |
| 1149 | + "good_idx = (l1_errors < 0.1)\n", |
| 1150 | + "good_documents = payload_documents[good_idx][:N]\n", |
| 1151 | + "good_topic_mixtures = inferred_topic_mixtures[good_idx][:N]\n", |
| 1152 | + "\n", |
| 1153 | + "poor_idx = (l1_errors > 0.4)\n", |
| 1154 | + "poor_documents = payload_documents[poor_idx][:N]\n", |
| 1155 | + "poor_topic_mixtures = inferred_topic_mixtures[poor_idx][:N]" |
| 1156 | + ] |
| 1157 | + }, |
| 1158 | + { |
| 1159 | + "cell_type": "code", |
| 1160 | + "execution_count": null, |
| 1161 | + "metadata": {}, |
| 1162 | + "outputs": [], |
| 1163 | + "source": [ |
| 1164 | + "%matplotlib inline\n", |
| 1165 | + "\n", |
| 1166 | + "fig = plot_lda_topics(good_documents, 2, 3, topic_mixtures=good_topic_mixtures)\n", |
| 1167 | + "fig.suptitle('Documents With Accurate Inferred Topic-Mixtures')\n", |
| 1168 | + "fig.set_dpi(120)" |
| 1169 | + ] |
| 1170 | + }, |
| 1171 | + { |
| 1172 | + "cell_type": "code", |
| 1173 | + "execution_count": null, |
| 1174 | + "metadata": {}, |
| 1175 | + "outputs": [], |
| 1176 | + "source": [ |
| 1177 | + "%matplotlib inline\n", |
| 1178 | + "\n", |
| 1179 | + "fig = plot_lda_topics(poor_documents, 2, 3, topic_mixtures=poor_topic_mixtures)\n", |
| 1180 | + "fig.suptitle('Documents With Inaccurate Inferred Topic-Mixtures')\n", |
| 1181 | + "fig.set_dpi(120)" |
| 1182 | + ] |
| 1183 | + }, |
| 1184 | + { |
| 1185 | + "cell_type": "markdown", |
| 1186 | + "metadata": {}, |
| 1187 | + "source": [ |
| 1188 | + "In this example set the documents on which inference was not as accurate tend to have a denser topic-mixture. This makes sense when extrapolated to real-world datasets: it can be difficult to nail down which topics are represented in a document when the document uses words from a large subset of the vocabulary." |
| 1189 | + ] |
| 1190 | + }, |
1125 | 1191 | {
|
1126 | 1192 | "cell_type": "markdown",
|
1127 | 1193 | "metadata": {},
|
|
0 commit comments