diff --git a/tests/aggregators-cols.vdj b/tests/aggregators-cols.vdj new file mode 100644 index 000000000..fbd17d043 --- /dev/null +++ b/tests/aggregators-cols.vdj @@ -0,0 +1,7 @@ +#!vd -p +{"sheet": "global", "col": null, "row": "disp_date_fmt", "longname": "set-option", "input": "%b %d, %Y", "keystrokes": "", "comment": null} +{"longname": "open-file", "input": "sample_data/test.jsonl", "keystrokes": "o"} +{"sheet": "test", "col": "key2", "row": "", "longname": "key-col", "input": "", "keystrokes": "!", "comment": "toggle current column as a key column"} +{"sheet": "test", "col": "key2", "row": "", "longname": "addcol-aggregate", "input": "count", "comment": "add column(s) with aggregator of rows grouped by key columns"} +{"sheet": "test", "col": "qty", "row": "", "longname": "type-float", "input": "", "keystrokes": "%", "comment": "set type of current column to float"} +{"sheet": "test", "col": "qty", "row": "", "longname": "addcol-aggregate", "input": "rank sum", "comment": "add column(s) with aggregator of rows grouped by key columns"} diff --git a/tests/golden/aggregators-cols.tsv b/tests/golden/aggregators-cols.tsv new file mode 100644 index 000000000..6c3f83e14 --- /dev/null +++ b/tests/golden/aggregators-cols.tsv @@ -0,0 +1,11 @@ +key2 key2_count key1 qty qty_rank qty_sum amt +foo 2 2016-01-01 11:00:00 1.00 1 31.00 + 0 2016-01-01 1:00 2.00 1 66.00 3 +baz 3 4.00 1 292.00 43.2 +#ERR 0 #ERR #ERR 1 0.00 #ERR #ERR +bar 2 2017-12-25 8:44 16.00 2 16.00 .3 +baz 3 32.00 2 292.00 3.3 + 0 2018-07-27 4:44 64.00 2 66.00 9.1 +bar 2 2018-07-27 16:44 1 16.00 +baz 3 2018-07-27 18:44 256.00 3 292.00 .01 +foo 2 2018-10-20 18:44 30.00 2 31.00 .01 diff --git a/visidata/aggregators.py b/visidata/aggregators.py index b34e9bcc4..cc730038f 100644 --- a/visidata/aggregators.py +++ b/visidata/aggregators.py @@ -3,8 +3,9 @@ import functools import collections import statistics +import itertools -from visidata import Progress, Sheet, Column, ColumnsSheet, VisiData +from visidata import Progress, Sheet, Column, ColumnsSheet, VisiData, SettableColumn from visidata import vd, anytype, vlen, asyncthread, wrapply, AttrDict, date, stacktrace, TypedExceptionWrapper vd.help_aggregators = '''# Choose Aggregators @@ -75,7 +76,7 @@ def aggregators_set(col, aggs): class Aggregator: - def __init__(self, name, type, funcValues=None, helpstr='foo'): + def __init__(self, name, type, funcValues=None, helpstr=''): 'Define aggregator `name` that calls funcValues(values)' self.type = type self.funcValues = funcValues # funcValues(values) @@ -91,6 +92,34 @@ def aggregate(self, col, rows): # wrap builtins so they can have a .type return None raise e +class ListAggregator(Aggregator): + '''A list aggregator is an aggregator that returns a list of values, generally + one value per input row, unlike ordinary aggregators that operate on rows + and return only a single value. + To implement a new list aggregator, subclass ListAggregator, + and override aggregate() and aggregate_list().''' + def __init__(self, name, type, helpstr='', listtype=None): + '''*listtype* determines the type of the column created by addcol_aggregate() + for list aggrs. If it is None, then the new column will match the type of the input column''' + super().__init__(name, type, helpstr=helpstr) + self.listtype = listtype + + def aggregate(self, col, rows) -> list: + '''Return a list, which can be shorter than *rows*, because it filters out nulls and errors. + Override in subclass.''' + vals = self.aggregate_list(col, rows) + # filter out nulls and errors + vals = [ v for v in vals if not col.sheet.isNullFunc()(v) ] + return vals + + def aggregate_list(self, col, row_group) -> list: + '''Return a list of results, which will be one result per input row. + *row_group* is an iterable that holds a "group" of rows to run the aggregator on. + rows in *row_group* are not necessarily in the same order they are in the sheet. + Override in subclass.''' + vals = [ col.getTypedValue(r) for r in row_group ] + return vals + @VisiData.api def aggregator(vd, name, funcValues, helpstr='', *, type=None): @@ -98,6 +127,14 @@ def aggregator(vd, name, funcValues, helpstr='', *, type=None): Use *type* to force type of aggregated column (default to use type of source column).''' vd.aggregators[name] = Aggregator(name, type, funcValues=funcValues, helpstr=helpstr) +@VisiData.api +def aggregator_list(vd, name, helpstr='', type=anytype, listtype=anytype): + '''Define simple aggregator *name* that calls ``funcValues(values)`` to aggregate *values*. + Use *type* to force type of aggregated column (default to use type of source column). + Use *listtype* to force the type of the new column created by addcol-aggregate. + If *listtype* is None, it will match the type of the source column.''' + vd.aggregators[name] = ListAggregator(name, type, helpstr=helpstr, listtype=listtype) + ## specific aggregator implementations def mean(vals): @@ -146,10 +183,92 @@ def __init__(self, pct, helpstr=''): def aggregate(self, col, rows): return _percentile(sorted(col.getValues(rows)), self.pct/100, key=float) - def quantiles(q, helpstr): return [PercentileAggregator(round(100*i/q), helpstr) for i in range(1, q)] +class RankAggregator(ListAggregator): + ''' + Ranks start at 1, and each group's rank is 1 higher than the previous group. + When elements are tied in ranking, each of them gets the same rank. + ''' + def aggregate(self, col, rows) -> [int]: + return self.aggregate_list(col, rows) + + def aggregate_list(self, col, rows) -> [int]: + if not col.sheet.keyCols: + vd.error('ranking requires one or more key columns') + return None + return self.rank(col, rows) + + def rank(self, col, rows): + # compile row data, for each row a list of tuples: (group_key, rank_key, rownum) + rowdata = [(col.sheet.rowkey(r), col.getTypedValue(r), rownum) for rownum, r in enumerate(rows)] + # sort by row key and column value to prepare for grouping + try: + rowdata.sort() + except TypeError as e: + vd.fail(f'elements in a ranking column must be comparable: {e.args[0]}') + rowvals = [] + #group by row key + for _, group in itertools.groupby(rowdata, key=lambda v: v[0]): + # within a group, the rows have already been sorted by col_val + group = list(group) + # rank each group individually + group_ranks = rank_sorted_iterable([col_val for _, col_val, rownum in group]) + rowvals += [(rownum, rank) for (_, _, rownum), rank in zip(group, group_ranks)] + # sort by unique rownum, to make rank results match the original row order + rowvals.sort() + rowvals = [ rank for rownum, rank in rowvals ] + return rowvals + +def rank_sorted_iterable(vals_sorted) -> [int]: + '''*vals_sorted* is an iterable whose elements form one group. + The iterable must already be sorted.''' + + ranks = [] + val_groups = itertools.groupby(vals_sorted) + for rank, (_, val_group) in enumerate(val_groups, 1): + for _ in val_group: + ranks.append(rank) + return ranks + +def aggregate_groups(sheet, col, rows, aggr) -> list: + '''Returns a list, containing the result of the aggregator applied to each row. + *col* is a column whose values determine each rows rank within a group. + *rows* is a list of visidata rows. + *aggr* is an Aggregator object. + ''' + def _key_progress(prog): + def identity(val): + prog.addProgress(1) + return val + return identity + + with Progress(gerund='ranking', total=4*sheet.nRows) as prog: + p = _key_progress(prog) # increment progress every time p() is called + # compile row data, for each row a list of tuples: (group_key, rank_key, rownum) + rowdata = [(sheet.rowkey(r), col.getTypedValue(r), p(rownum)) for rownum, r in enumerate(rows)] + # sort by row key and column value to prepare for grouping + try: + rowdata.sort(key=p) + except TypeError as e: + vd.fail(f'elements in a ranking column must be comparable: {e.args[0]}') + rowvals = [] + #group by row key + for _, group in itertools.groupby(rowdata, key=lambda v: v[0]): + # within a group, the rows have already been sorted by col_val + group = list(group) + if isinstance(aggr, ListAggregator): # for list aggregators, each row gets its own value + aggr_vals = aggr.aggregate_list(col, [rows[rownum] for _, _, rownum in group]) + rowvals += [(rownum, v) for (_, _, rownum), v in zip(group, aggr_vals)] + else: # for normal aggregators, each row in the group gets the same value + aggr_val = aggr.aggregate(col, [rows[rownum] for _, _, rownum in group]) + rowvals += [(rownum, aggr_val) for _, _, rownum in group] + prog.addProgress(len(group)) + # sort by unique rownum, to make rank results match the original row order + rowvals.sort(key=p) + rowvals = [ v for rownum, v in rowvals ] + return rowvals vd.aggregator('min', min, 'minimum value') vd.aggregator('max', max, 'maximum value') @@ -160,8 +279,10 @@ def quantiles(q, helpstr): vd.aggregator('sum', vsum, 'sum of values') vd.aggregator('distinct', set, 'distinct values', type=vlen) vd.aggregator('count', lambda values: sum(1 for v in values), 'number of values', type=int) -vd.aggregator('list', list, 'list of values', type=anytype) -vd.aggregator('stdev', stdev, 'standard deviation of values', type=float) +vd.aggregator_list('list', 'list of values', type=anytype, listtype=None) +vd.aggregator('stdev', statistics.stdev, 'standard deviation of values', type=float) + +vd.aggregators['rank'] = RankAggregator('rank', anytype, helpstr='list of ranks, when grouping by key columns', listtype=int) vd.aggregators['q3'] = quantiles(3, 'tertiles (33/66th pctile)') vd.aggregators['q4'] = quantiles(4, 'quartiles (25/50/75th pctile)') @@ -252,9 +373,8 @@ def aggregator_choices(vd): @VisiData.api -def chooseAggregators(vd): +def chooseAggregators(vd, prompt = 'choose aggregators: '): '''Return a list of aggregator name strings chosen or entered by the user. User-entered names may be invalid.''' - prompt = 'choose aggregators: ' def _fmt_aggr_summary(match, row, trigger_key): formatted_aggrname = match.formatted.get('key', row.key) if match else row.key r = ' '*(len(prompt)-3) @@ -281,10 +401,50 @@ def _fmt_aggr_summary(match, row, trigger_key): vd.warning(f'aggregator does not exist: {aggr}') return aggrs +@Sheet.api +@asyncthread +def addcol_aggregate(sheet, col, aggrnames): + for aggrname in aggrnames: + aggrs = vd.aggregators.get(aggrname) + aggrs = aggrs if isinstance(aggrs, list) else [aggrs] + if not aggrs: continue + for aggr in aggrs: + rows = aggregate_groups(sheet, col, sheet.rows, aggr) + if isinstance(aggr, ListAggregator): + t = aggr.listtype or col.type + else: + t = aggr.type or col.type + c = SettableColumn(name=f'{col.name}_{aggr.name}', type=t) + sheet.addColumnAtCursor(c) + c.setValues(sheet.rows, *rows) + +@Sheet.api +@asyncthread +def addcol_sheetrank(sheet, rows): + ''' + Each row is ranked within its sheet. Rows are ordered by the + value of their key columns. + ''' + colname = f'{sheet.name}_sheetrank' + c = SettableColumn(name=colname, type=int) + sheet.addColumnAtCursor(c) + if not sheet.keyCols: + vd.error('ranking requires one or more key columns') + return None + rowkeys = [(sheet.rowkey(r), rownum) for rownum, r in enumerate(rows)] + rowkeys.sort() + ranks = rank_sorted_iterable([rowkey for rowkey, rownum in rowkeys]) + row_ranks = sorted(zip((rownum for _, rownum in rowkeys), ranks)) + row_ranks = [rank for rownum, rank in row_ranks] + c.setValues(sheet.rows, *row_ranks) + Sheet.addCommand('+', 'aggregate-col', 'addAggregators([cursorCol], chooseAggregators())', 'Add aggregator to current column') Sheet.addCommand('z+', 'memo-aggregate', 'cursorCol.memo_aggregate(chooseAggregators(), selectedRows or rows)', 'memo result of aggregator over values in selected rows for current column') ColumnsSheet.addCommand('g+', 'aggregate-cols', 'addAggregators(selectedRows or source[0].nonKeyVisibleCols, chooseAggregators())', 'add aggregators to selected source columns') +Sheet.addCommand('', 'addcol-aggregate', 'addcol_aggregate(cursorCol, chooseAggregators(prompt="aggregator for groups: "))', 'add column(s) with aggregator of rows grouped by key columns') +Sheet.addCommand('', 'addcol-sheetrank', 'sheet.addcol_sheetrank(rows)', 'add column with the rank of each row based on its key columns') vd.addMenuItems(''' Column > Add aggregator > aggregate-col + Column > Add column > aggregate > addcol-aggregate ''') diff --git a/visidata/tests/test_commands.py b/visidata/tests/test_commands.py index d47a26050..5403f8ce5 100644 --- a/visidata/tests/test_commands.py +++ b/visidata/tests/test_commands.py @@ -116,6 +116,7 @@ def isTestableCommand(longname, cmdlist): 'sheet': '', 'col': 'Units', 'row': '5', + 'addcol-aggregate': 'max', } @pytest.mark.usefixtures('curses_setup')