Skip to content

Commit 39d5831

Browse files
committed
Updates for DataFrame.replace deprecation
See pandas-dev/pandas#57734
1 parent cee518f commit 39d5831

File tree

1 file changed

+88
-33
lines changed

1 file changed

+88
-33
lines changed

utils/encode-results.py

Lines changed: 88 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,46 +2,101 @@
22
import json
33
from pathlib import Path
44
from argparse import ArgumentParser
5+
from dataclasses import dataclass
6+
from multiprocessing import Pool, Queue
57

8+
import numpy as np
69
import pandas as pd
710

8-
def traverse(values, start, rev=False):
9-
for i in enumerate(values, start):
10-
if rev:
11-
i = tuple(reversed(i))
12-
yield i
11+
from mylib import Logger
12+
13+
@dataclass
14+
class Encoding:
15+
group: str
16+
codes: dict
17+
df: pd.DataFrame
18+
19+
def values(self):
20+
yield (self.group, self.codes)
21+
22+
@dataclass
23+
class DataGroup:
24+
group: str
25+
df: pd.DataFrame
26+
27+
def __str__(self):
28+
cols = ', '.join(self.df.columns)
29+
return f'{self.group}: {cols}'
30+
31+
def encode(self, start):
32+
categories = pd.Categorical(self.df.unstack())
33+
34+
codes = np.where(categories.codes < 0, pd.NA, categories.codes + start)
35+
values = np.hsplit(codes, len(self.df.columns))
36+
to_replace = dict(zip(self.df.columns, values))
37+
38+
codes = dict(enumerate(categories.categories, start))
39+
df = self.df.assign(**to_replace)
40+
41+
return Encoding(self.group, codes, df)
42+
43+
class EncodingRecorder:
44+
def __init__(self, output):
45+
self.output = output
46+
self.encodings = {}
47+
48+
def __enter__(self):
49+
self.encodings.clear()
50+
return self
51+
52+
def __exit__(self, exc_type, exc_value, traceback):
53+
with self.output.open('w') as fp:
54+
json.dump(self.encodings, fp, indent=2)
55+
56+
def push(self, encoding):
57+
self.encodings.update(encoding.values())
58+
59+
def func(incoming, outgoing, args):
60+
while True:
61+
data = incoming.get()
62+
Logger.info(data)
63+
outgoing.put(data.encode(args.start))
1364

1465
if __name__ == '__main__':
1566
arguments = ArgumentParser()
1667
arguments.add_argument('--save-encodings', type=Path)
1768
arguments.add_argument('--start', type=int, default=1)
69+
arguments.add_argument('--workers', type=int)
1870
args = arguments.parse_args()
1971

20-
df = pd.read_csv(sys.stdin)
21-
columns = {
22-
'prompt': (
23-
'instruction',
24-
),
25-
'model': (
26-
'generator_1',
27-
'generator_2',
28-
'preference',
29-
),
30-
}
31-
encodings = {}
32-
to_replace = {}
33-
34-
for (k, v) in columns.items():
35-
values = (df
36-
.filter(items=v)
37-
.dropna()
38-
.unstack()
39-
.unique())
40-
encodings[k] = dict(traverse(values, args.start))
41-
factors = dict(traverse(values, args.start, True))
42-
to_replace.update({x: factors for x in v})
43-
df = df.replace(to_replace=to_replace)
44-
df.to_csv(sys.stdout, index=False)
45-
46-
with args.save_encodings.open('w') as fp:
47-
json.dump(encodings, fp, indent=2)
72+
incoming = Queue()
73+
outgoing = Queue()
74+
initargs = (
75+
outgoing,
76+
incoming,
77+
args,
78+
)
79+
80+
with Pool(args.workers, func, initargs):
81+
groups = {
82+
'prompt': (
83+
'instruction',
84+
),
85+
'model': (
86+
'generator_1',
87+
'generator_2',
88+
'preference',
89+
),
90+
}
91+
df = pd.read_csv(sys.stdin)
92+
93+
for (k, v) in groups.items():
94+
data = DataGroup(k, df.filter(items=v))
95+
outgoing.put(data)
96+
97+
with EncodingRecorder(args.save_encodings) as recorder:
98+
for _ in range(len(groups)):
99+
encoding = incoming.get()
100+
recorder.push(encoding)
101+
df.update(encoding.df)
102+
df.to_csv(sys.stdout, index=False)

0 commit comments

Comments
 (0)