Skip to content

Commit a568fdf

Browse files
committed
added regression tests for variable names functionality within bert module
1 parent 895e48c commit a568fdf

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

text_extensions_for_pandas/io/bert.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,12 @@ def align_bert_tokens_to_corpus_tokens(
216216
"""
217217
if len(spans_df.index) == 0:
218218
return spans_df.copy()
219+
219220
overlaps_df = (
220221
spanner
221222
.overlap_join(spans_df[spans_df_token_col], corpus_toks_df[corpus_df_token_col],
222223
"span", "corpus_token")
223-
.merge(spans_df)
224+
.merge(spans_df,left_on='span',right_on=spans_df_token_col)
224225
)
225226
agg_df = (
226227
overlaps_df

text_extensions_for_pandas/io/test_bert.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,31 @@ def test_conll_to_bert(self):
198198
3 [15, 22): 'Failure' PER
199199
4 [23, 24): '(' <NA>"""))
200200

201+
## test with renamed fields
202+
without_embeddings_alt = without_embeddings.rename(columns={
203+
'span':'span-1',
204+
'ent_type':'ent_type-1'})
205+
first_df_alt = first_df.rename(columns={'span':'span-2'})
206+
aligned_toks_alt = align_bert_tokens_to_corpus_tokens(without_embeddings_alt,
207+
first_df_alt,
208+
spans_df_token_col='span-1',
209+
corpus_df_token_col='span-2',
210+
entity_type_col='ent_type-1'
211+
)
212+
print(str(aligned_toks_alt.iloc[:num_rows]))
213+
self.assertEqual(
214+
str(aligned_toks_alt.iloc[:num_rows]),
215+
# NOTE: Don't forget to add both sets of double-backslashes back in if you
216+
# copy-and-paste an updated version of the output below!
217+
textwrap.dedent("""\
218+
span ent_type-1
219+
0 [0, 3): 'Who' <NA>
220+
1 [4, 6): 'is' <NA>
221+
2 [7, 14): 'General' PER
222+
3 [15, 22): 'Failure' PER
223+
4 [23, 24): '(' <NA>"""))
224+
225+
201226
def test_seq_to_windows(self):
202227
for seqlen in range(1, 20):
203228
seq = np.arange(1, seqlen)

0 commit comments

Comments
 (0)