Skip to content

Commit

Permalink
[aggr-] add rank aggregator, add cmd addcol-aggregate
Browse files Browse the repository at this point in the history
Also adds a command addcol-sheetrank.
  • Loading branch information
midichef committed Jul 29, 2024
1 parent 6c7e178 commit 8078fb6
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 7 deletions.
7 changes: 7 additions & 0 deletions tests/aggregators-cols.vdj
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!vd -p
{"sheet": "global", "col": null, "row": "disp_date_fmt", "longname": "set-option", "input": "%b %d, %Y", "keystrokes": "", "comment": null}
{"longname": "open-file", "input": "sample_data/test.jsonl", "keystrokes": "o"}
{"sheet": "test", "col": "key2", "row": "", "longname": "key-col", "input": "", "keystrokes": "!", "comment": "toggle current column as a key column"}
{"sheet": "test", "col": "key2", "row": "", "longname": "addcol-aggregate", "input": "count", "comment": "add column(s) with aggregator of rows grouped by key columns"}
{"sheet": "test", "col": "qty", "row": "", "longname": "type-float", "input": "", "keystrokes": "%", "comment": "set type of current column to float"}
{"sheet": "test", "col": "qty", "row": "", "longname": "addcol-aggregate", "input": "rank sum", "comment": "add column(s) with aggregator of rows grouped by key columns"}
11 changes: 11 additions & 0 deletions tests/golden/aggregators-cols.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
key2 key2_count key1 qty qty_rank qty_sum amt
foo 2 2016-01-01 11:00:00 1.00000 1 31.0
0 2016-01-01 1:00 2.00000 1 66.0 3
baz 3 4.00000 1 292.0 43.2
#ERR 0 #ERR #ERR 1 0 #ERR #ERR
bar 2 2017-12-25 8:44 16.00000 2 16.0 .3
baz 3 32.00000 2 292.0 3.3
0 2018-07-27 4:44 64.00000 2 66.0 9.1
bar 2 2018-07-27 16:44 1 16.0
baz 3 2018-07-27 18:44 256.00000 3 292.0 .01
foo 2 2018-10-20 18:44 30.00000 2 31.0 .01
174 changes: 167 additions & 7 deletions visidata/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import functools
import collections
import statistics
import itertools

from visidata import Progress, Sheet, Column, ColumnsSheet, VisiData
from visidata import Progress, Sheet, Column, ColumnsSheet, VisiData, SettableColumn
from visidata import vd, anytype, vlen, asyncthread, wrapply, AttrDict, date, stacktrace, TypedExceptionWrapper

vd.help_aggregators = '''# Choose Aggregators
Expand Down Expand Up @@ -75,7 +76,7 @@ def aggregators_set(col, aggs):


class Aggregator:
def __init__(self, name, type, funcValues=None, helpstr='foo'):
def __init__(self, name, type, funcValues=None, helpstr=''):
'Define aggregator `name` that calls funcValues(values)'
self.type = type
self.funcValues = funcValues # funcValues(values)
Expand All @@ -91,13 +92,49 @@ def aggregate(self, col, rows): # wrap builtins so they can have a .type
return None
raise e

class ListAggregator(Aggregator):
'''A list aggregator is an aggregator that returns a list of values, generally
one value per input row, unlike ordinary aggregators that operate on rows
and return only a single value.
To implement a new list aggregator, subclass ListAggregator,
and override aggregate() and aggregate_list().'''
def __init__(self, name, type, helpstr='', listtype=None):
'''*listtype* determines the type of the column created by addcol_aggregate()
for list aggrs. If it is None, then the new column will match the type of the input column'''
super().__init__(name, type, helpstr=helpstr)
self.listtype = listtype

def aggregate(self, col, rows) -> list:
'''Return a list, which can be shorter than *rows*, because it filters out nulls and errors.
Override in subclass.'''
vals = self.aggregate_list(col, rows)
# filter out nulls and errors
vals = [ v for v in vals if not col.sheet.isNullFunc()(v) ]
return vals

def aggregate_list(self, col, row_group) -> list:
'''Return a list of results, which will be one result per input row.
*row_group* is an iterable that holds a "group" of rows to run the aggregator on.
rows in *row_group* are not necessarily in the same order they are in the sheet.
Override in subclass.'''
vals = [ col.getTypedValue(r) for r in row_group ]
return vals


@VisiData.api
def aggregator(vd, name, funcValues, helpstr='', *, type=None):
'''Define simple aggregator *name* that calls ``funcValues(values)`` to aggregate *values*.
Use *type* to force type of aggregated column (default to use type of source column).'''
vd.aggregators[name] = Aggregator(name, type, funcValues=funcValues, helpstr=helpstr)

@VisiData.api
def aggregator_list(vd, name, helpstr='', type=anytype, listtype=anytype):
'''Define simple aggregator *name* that calls ``funcValues(values)`` to aggregate *values*.
Use *type* to force type of aggregated column (default to use type of source column).
Use *listtype* to force the type of the new column created by addcol-aggregate.
If *listtype* is None, it will match the type of the source column.'''
vd.aggregators[name] = ListAggregator(name, type, helpstr=helpstr, listtype=listtype)

## specific aggregator implementations

def mean(vals):
Expand Down Expand Up @@ -146,10 +183,92 @@ def __init__(self, pct, helpstr=''):
def aggregate(self, col, rows):
return _percentile(sorted(col.getValues(rows)), self.pct/100, key=float)


def quantiles(q, helpstr):
return [PercentileAggregator(round(100*i/q), helpstr) for i in range(1, q)]

class RankAggregator(ListAggregator):
'''
Ranks start at 1, and each group's rank is 1 higher than the previous group.
When elements are tied in ranking, each of them gets the same rank.
'''
def aggregate(self, col, rows) -> [int]:
return self.aggregate_list(col, rows)

def aggregate_list(self, col, rows) -> [int]:
if not col.sheet.keyCols:
vd.error('ranking requires one or more key columns')
return None
return self.rank(col, rows)

def rank(self, col, rows):
# compile row data, for each row a list of tuples: (group_key, rank_key, rownum)
rowdata = [(col.sheet.rowkey(r), col.getTypedValue(r), rownum) for rownum, r in enumerate(rows)]
# sort by row key and column value to prepare for grouping
try:
rowdata.sort()
except TypeError as e:
vd.fail(f'elements in a ranking column must be comparable: {e.args[0]}')
rowvals = []
#group by row key
for _, group in itertools.groupby(rowdata, key=lambda v: v[0]):
# within a group, the rows have already been sorted by col_val
group = list(group)
# rank each group individually
group_ranks = rank_sorted_iterable([col_val for _, col_val, rownum in group])
rowvals += [(rownum, rank) for (_, _, rownum), rank in zip(group, group_ranks)]
# sort by unique rownum, to make rank results match the original row order
rowvals.sort()
rowvals = [ rank for rownum, rank in rowvals ]
return rowvals

def rank_sorted_iterable(vals_sorted) -> list[int]:
'''*vals_sorted* is an iterable whose elements form one group.
The iterable must already be sorted.'''

ranks = []
val_groups = itertools.groupby(vals_sorted)
for rank, (_, val_group) in enumerate(val_groups, 1):
for _ in val_group:
ranks.append(rank)
return ranks

def aggregate_groups(sheet, col, rows, aggr) -> list:
'''Returns a list, containing the result of the aggregator applied to each row.
*col* is a column whose values determine each rows rank within a group.
*rows* is a list of visidata rows.
*aggr* is an Aggregator object.
'''
def _key_progress(prog):
def identity(val):
prog.addProgress(1)
return val
return identity

with Progress(gerund='ranking', total=4*sheet.nRows) as prog:
p = _key_progress(prog) # increment progress every time p() is called
# compile row data, for each row a list of tuples: (group_key, rank_key, rownum)
rowdata = [(sheet.rowkey(r), col.getTypedValue(r), p(rownum)) for rownum, r in enumerate(rows)]
# sort by row key and column value to prepare for grouping
try:
rowdata.sort(key=p)
except TypeError as e:
vd.fail(f'elements in a ranking column must be comparable: {e.args[0]}')
rowvals = []
#group by row key
for _, group in itertools.groupby(rowdata, key=lambda v: v[0]):
# within a group, the rows have already been sorted by col_val
group = list(group)
if isinstance(aggr, ListAggregator): # for list aggregators, each row gets its own value
aggr_vals = aggr.aggregate_list(col, [rows[rownum] for _, _, rownum in group])
rowvals += [(rownum, v) for (_, _, rownum), v in zip(group, aggr_vals)]
else: # for normal aggregators, each row in the group gets the same value
aggr_val = aggr.aggregate(col, [rows[rownum] for _, _, rownum in group])
rowvals += [(rownum, aggr_val) for _, _, rownum in group]
prog.addProgress(len(group))
# sort by unique rownum, to make rank results match the original row order
rowvals.sort(key=p)
rowvals = [ v for rownum, v in rowvals ]
return rowvals

vd.aggregator('min', min, 'minimum value')
vd.aggregator('max', max, 'maximum value')
Expand All @@ -160,8 +279,10 @@ def quantiles(q, helpstr):
vd.aggregator('sum', vsum, 'sum of values')
vd.aggregator('distinct', set, 'distinct values', type=vlen)
vd.aggregator('count', lambda values: sum(1 for v in values), 'number of values', type=int)
vd.aggregator('list', list, 'list of values', type=anytype)
vd.aggregator('stdev', stdev, 'standard deviation of values', type=float)
vd.aggregator_list('list', 'list of values', type=anytype, listtype=None)
vd.aggregator('stdev', statistics.stdev, 'standard deviation of values', type=float)

vd.aggregators['rank'] = RankAggregator('rank', anytype, helpstr='list of ranks, when grouping by key columns', listtype=int)

vd.aggregators['q3'] = quantiles(3, 'tertiles (33/66th pctile)')
vd.aggregators['q4'] = quantiles(4, 'quartiles (25/50/75th pctile)')
Expand Down Expand Up @@ -252,9 +373,8 @@ def aggregator_choices(vd):


@VisiData.api
def chooseAggregators(vd):
def chooseAggregators(vd, prompt = 'choose aggregators: '):
'''Return a list of aggregator name strings chosen or entered by the user. User-entered names may be invalid.'''
prompt = 'choose aggregators: '
def _fmt_aggr_summary(match, row, trigger_key):
formatted_aggrname = match.formatted.get('key', row.key) if match else row.key
r = ' '*(len(prompt)-3)
Expand All @@ -281,10 +401,50 @@ def _fmt_aggr_summary(match, row, trigger_key):
vd.warning(f'aggregator does not exist: {aggr}')
return aggrs

@Sheet.api
@asyncthread
def addcol_aggregate(sheet, col, aggrnames):
for aggrname in aggrnames:
aggrs = vd.aggregators.get(aggrname)
aggrs = aggrs if isinstance(aggrs, list) else [aggrs]
if not aggrs: continue
for aggr in aggrs:
rows = aggregate_groups(sheet, col, sheet.rows, aggr)
if isinstance(aggr, ListAggregator):
t = aggr.listtype or col.type
else:
t = aggr.type or col.type
c = SettableColumn(name=f'{col.name}_{aggr.name}', type=t)
sheet.addColumnAtCursor(c)
c.setValues(sheet.rows, *rows)

@Sheet.api
@asyncthread
def addcol_sheetrank(sheet, rows):
'''
Each row is ranked within its sheet. Rows are ordered by the
value of their key columns.
'''
colname = f'{sheet.name}_sheetrank'
c = SettableColumn(name=colname, type=int)
sheet.addColumnAtCursor(c)
if not sheet.keyCols:
vd.error('ranking requires one or more key columns')
return None
rowkeys = [(sheet.rowkey(r), rownum) for rownum, r in enumerate(rows)]
rowkeys.sort()
ranks = rank_sorted_iterable([rowkey for rowkey, rownum in rowkeys])
row_ranks = sorted(zip((rownum for _, rownum in rowkeys), ranks))
row_ranks = [rank for rownum, rank in row_ranks]
c.setValues(sheet.rows, *row_ranks)

Sheet.addCommand('+', 'aggregate-col', 'addAggregators([cursorCol], chooseAggregators())', 'Add aggregator to current column')
Sheet.addCommand('z+', 'memo-aggregate', 'cursorCol.memo_aggregate(chooseAggregators(), selectedRows or rows)', 'memo result of aggregator over values in selected rows for current column')
ColumnsSheet.addCommand('g+', 'aggregate-cols', 'addAggregators(selectedRows or source[0].nonKeyVisibleCols, chooseAggregators())', 'add aggregators to selected source columns')
Sheet.addCommand('', 'addcol-aggregate', 'addcol_aggregate(cursorCol, chooseAggregators(prompt="aggregator for groups: "))', 'add column(s) with aggregator of rows grouped by key columns')
Sheet.addCommand('', 'addcol-sheetrank', 'sheet.addcol_sheetrank(rows)', 'add column with the rank of each row based on its key columns')

vd.addMenuItems('''
Column > Add aggregator > aggregate-col
Column > Add column > aggregate > addcol-aggregate
''')
1 change: 1 addition & 0 deletions visidata/tests/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def isTestableCommand(longname, cmdlist):
'sheet': '',
'col': 'Units',
'row': '5',
'addcol-aggregate': 'max',
}

@pytest.mark.usefixtures('curses_setup')
Expand Down

0 comments on commit 8078fb6

Please sign in to comment.