|
5 | 5 | import pandas as pd
|
6 | 6 | import numpy as np
|
7 | 7 | from . ingredient import BaseIngredient, Ingredient, ProcedureResult
|
8 |
| -from .helpers import read_opt |
| 8 | +from .helpers import read_opt, mkfunc |
9 | 9 | from .. import config
|
10 | 10 | from .. import transformer
|
11 | 11 | import time
|
@@ -420,33 +420,71 @@ def align(to_align: BaseIngredient, base: Ingredient, *, result=None, **options)
|
420 | 420 | return ProcedureResult(result, to_replace, data=ing_data)
|
421 | 421 |
|
422 | 422 |
|
423 |
def groupby(ingredient: BaseIngredient, *, result, **options) -> ProcedureResult:
    """group ingredient data by column(s) and run aggregate/transform/filter function

    available options:
        groupby: the column(s) to group, can be a list or a string
        aggregate/transform/filter: the function to run. exactly one of them
            must be supplied.

    The function block should have below format:

        aggregate:
            column1: funcname1
            column2: funcname2

    or

        aggregate:
            column:
                function: funcname
                param1: foo
                param2: bar

    other columns not mentioned will be dropped.

    Only ``aggregate`` changes the key of the resulting ingredient (it becomes
    the groupby column(s)); ``transform`` and ``filter`` keep the key of the
    input ingredient.

    Raises KeyError when the ``groupby`` option is missing, and ValueError
    when the remaining options are not exactly one of
    aggregate/transform/filter.
    """

    logger.info("groupby: " + ingredient.ingred_id)

    data = ingredient.get_data()
    by = options.pop('groupby')

    # only one of aggregate/transform/filter should be left in options.
    # Raise explicitly instead of using assert, which is stripped under -O.
    comp_types = ('aggregate', 'transform', 'filter')
    if len(options) != 1:
        raise ValueError(
            "groupby: expected exactly one of aggregate/transform/filter, "
            "got: " + str(list(options)))
    comp_type = next(iter(options))
    if comp_type not in comp_types:
        raise ValueError("groupby: unknown option: " + str(comp_type))

    # Normalize the groupby columns to a list once, for every computation
    # type. Wrapping unconditionally would turn a list into a nested list and
    # break the index-level lookup used by transform/filter.
    by_cols = by if isinstance(by, list) else [by]

    if comp_type == 'aggregate':  # only aggregate should change the key of ingredient
        newkey = ','.join(by) if isinstance(by, list) else by
        logger.debug("changing the key to: " + str(newkey))
    else:
        newkey = ingredient.key

    newdata = dict()

    for k, func in options[comp_type].items():
        # mkfunc turns the function spec from the options into whatever is
        # handed to pandas below.
        func = mkfunc(func)
        if comp_type == 'aggregate':
            newdata[k] = data[k].groupby(by=by_cols).agg({k: func}).reset_index()
            continue
        # transform/filter: group on the key index levels so that the
        # original key columns come back via reset_index().
        df = data[k].set_index(ingredient.key_to_list())
        levels = [df.index.names.index(x) for x in by_cols]
        grouped = df.groupby(level=levels)[k]
        if comp_type == 'transform':
            computed = grouped.transform(func)
        else:
            computed = grouped.filter(func)
        newdata[k] = computed.reset_index()

    return ProcedureResult(result, newkey, data=newdata)
450 | 488 |
|
451 | 489 |
|
452 | 490 | def accumulate(ingredient: BaseIngredient, *, result=None, **options) -> ProcedureResult:
|
@@ -563,3 +601,9 @@ def extract_concepts(*ingredients: List[BaseIngredient],
|
563 | 601 | return ProcedureResult(result, 'concept', data=concepts.reset_index())
|
564 | 602 |
|
565 | 603 |
|
def trend_bridge(ingredient: BaseIngredient, result, **options) -> ProcedureResult:
    """run trend bridge on ingredients

    This procedure is a placeholder: it accepts the ingredient, the result id
    and the options but does not use them, and always raises
    NotImplementedError.
    """
    # The unused local import of the transformer helper was dropped so that
    # callers always get NotImplementedError, with a message saying what is
    # missing, rather than an unrelated import failure.
    raise NotImplementedError('the trend_bridge procedure is not implemented yet')
0 commit comments