Skip to content

Commit 631feb6

Browse files
Merge pull request #728 from nextstrain/branch-labels
Allow users to specify arbitrary branch & clade labels
2 parents 0050882 + dd318ba commit 631feb6

21 files changed

+716
-104
lines changed

CHANGES.md

+4
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@
1515
* curate: Allow custom metadata delimiters with the new `--metadata-delimiters` flag. [#1196][] (@victorlin)
1616
* Bump the default recursion limit to 10,000. Users can continue to override this limit with the environment variable `AUGUR_RECURSION_LIMIT`. [#1200][] (@joverlee521)
1717

18+
* clades, export v2: Clade labels and coloring keys are now definable via arguments to `augur clades`, allowing pipelines to use multiple invocations of `augur clades` to produce multiple sets of colors and branch labels. How labels are stored in the (intermediate) node-data JSON files has changed. This should be fully backwards compatible for pipelines using augur commands; however, custom scripts may need updating. PR [#728][] (@jameshadfield)
19+
20+
1821
### Bug fixes
1922

2023
* filter, frequencies, refine, parse: Previously, ambiguous dates in the future had a limit of today's date imposed on the upper value but not the lower value. It is now imposed on the lower value as well. [#1171][] (@victorlin)
2124
* refine: `--year-bounds` was ignored in versions 9.0.0 through 20.0.0. It now works. [#1136][] (@victorlin)
2225

26+
[#728]: https://github.com/nextstrain/augur/pull/728
2327
[#812]: https://github.com/nextstrain/augur/pull/812
2428
[#1136]: https://github.com/nextstrain/augur/issues/1136
2529
[#1152]: https://github.com/nextstrain/augur/pull/1152

augur/clades.py

+137-45
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
"""
22
Assign clades to nodes in a tree based on amino-acid or nucleotide signatures.
3+
4+
Nodes which are members of a clade are stored via
5+
<OUTPUT_NODE_DATA> → nodes → <node_name> → clade_membership
6+
and if this file is used in `augur export v2` these will automatically become a coloring.
7+
8+
The basal nodes of each clade are also given a branch label which is stored via
9+
<OUTPUT_NODE_DATA> → branches → <node_name> → labels → clade.
10+
11+
The keys "clade_membership" and "clade" are customisable via command line arguments.
312
"""
413

514
import sys
@@ -9,8 +18,12 @@
918
from collections import defaultdict
1019
import networkx as nx
1120
from itertools import islice
21+
from .errors import AugurError
22+
from argparse import SUPPRESS
1223
from .utils import get_parent_name_by_child_name_for_tree, read_node_data, write_json, get_json_name
1324

25+
UNASSIGNED = 'unassigned'
26+
1427
def read_in_clade_definitions(clade_file):
1528
'''
1629
Reads in tab-separated file that defines clades by amino acid or nucleotide mutations
@@ -111,10 +124,13 @@ def read_in_clade_definitions(clade_file):
111124
if clade != root
112125
}
113126

127+
if not len(clades.keys()):
128+
raise AugurError(f"No clades were defined in {clade_file}")
129+
114130
return clades
115131

116132

117-
def is_node_in_clade(clade_alleles, node, ref):
133+
def is_node_in_clade(clade_alleles, node, root_sequence):
118134
'''
119135
Determines whether a node matches the clade definition based on sequence
120136
For any condition, will first look in mutations stored in node.sequences,
@@ -123,11 +139,13 @@ def is_node_in_clade(clade_alleles, node, ref):
123139
Parameters
124140
----------
125141
clade_alleles : list
126-
list of clade defining alleles
142+
list of clade defining alleles (typically supplied from the input TSV)
127143
node : Bio.Phylo.BaseTree.Clade
128144
node to check, assuming sequences (as mutations) are attached to node
129-
ref : str or list
130-
positions
145+
node.sequences specifies nucleotides/codons which are newly observed on this node
146+
i.e. they are the result of a mutation observed on the branch leading to this node
147+
root_sequence : dict
148+
{geneName: observed root sequence (list)}
131149
132150
Returns
133151
-------
@@ -139,21 +157,39 @@ def is_node_in_clade(clade_alleles, node, ref):
139157
for gene, pos, clade_state in clade_alleles:
140158
if gene in node.sequences and pos in node.sequences[gene]:
141159
state = node.sequences[gene][pos]
142-
elif ref and gene in ref:
143-
state = ref[gene][pos]
160+
elif root_sequence and gene in root_sequence:
161+
try:
162+
state = root_sequence[gene][pos]
163+
except IndexError:
164+
raise AugurError(f"A clade definition specifies {{{gene},{pos+1},{clade_state}}} which \
165+
is beyond the bounds of the supplied root sequence for {gene} (length {len(root_sequence[gene])})")
144166
else:
145167
state = ''
146168

147169
conditions.append(state==clade_state)
148170

149171
return all(conditions)
150172

173+
def ensure_no_multiple_mutations(all_muts):
    """
    Check that no branch carries more than one mutation at the same position.

    Parameters
    ----------
    all_muts : dict
        mapping of node name to that node's data. The optional 'muts' list
        (reported as "nuc") and the optional 'aa_muts' dict of
        {gene: list of mutations} are inspected. Each mutation is a string
        whose first and last characters surround an integer position.

    Raises
    ------
    AugurError
        if any node has two or more mutations sharing a position, either in
        'muts' or within a single gene of 'aa_muts'. The message lists every
        offending node together with "nuc" or the gene name.
    """
    def has_repeated_position(mutations):
        # the position is everything between the first and last characters
        positions = [int(mutation[1:-1]) - 1 for mutation in mutations]
        return len(positions) != len(set(positions))

    offenders = []
    for node_name, node_info in all_muts.items():
        if has_repeated_position(node_info.get('muts', [])):
            offenders.append(f"Node {node_name} (nuc)")
        for gene in node_info.get('aa_muts', {}):
            if has_repeated_position(node_info['aa_muts'][gene]):
                offenders.append(f"Node {node_name} ({gene})")

    if offenders:
        raise AugurError(f"Multiple mutations at the same position on a single branch were found: {', '.join(offenders)}")
151187

152188
def assign_clades(clade_designations, all_muts, tree, ref=None):
153189
'''
154190
Ensures all nodes have an entry (or auspice doesn't display nicely), tests each node
155-
to see if it's the first member of a clade (assigns 'clade_annotation'), and sets
156-
all nodes's clade_membership to the value of their parent. This will change if later found to be
191+
to see if it's the first member of a clade (this is the label), and sets the membership of each
192+
node to the value of their parent. This will change if later found to be
157193
the first member of a clade.
158194
159195
Parameters
@@ -169,16 +205,18 @@ def assign_clades(clade_designations, all_muts, tree, ref=None):
169205
170206
Returns
171207
-------
172-
dict
173-
mapping of node to clades
208+
(dict, dict)
209+
[0]: mapping of node to clade membership (where applicable)
210+
[1]: mapping of node to clade label (where applicable)
174211
'''
175212

176213
clade_membership = {}
214+
clade_labels = {}
177215
parents = get_parent_name_by_child_name_for_tree(tree)
178216

179217
# first pass to set all nodes to unassigned as precaution to ensure attribute is set
180218
for node in tree.find_clades(order = 'preorder'):
181-
clade_membership[node.name] = {'clade_membership': 'unassigned'}
219+
clade_membership[node.name] = UNASSIGNED
182220

183221
# count leaves
184222
for node in tree.find_clades(order = 'postorder'):
@@ -189,16 +227,16 @@ def assign_clades(clade_designations, all_muts, tree, ref=None):
189227
c.up=node
190228
tree.root.up = None
191229
tree.root.sequences = {'nuc':{}}
192-
tree.root.sequences.update({gene:{} for gene in all_muts[tree.root.name]['aa_muts']})
230+
tree.root.sequences.update({gene:{} for gene in all_muts.get(tree.root.name, {}).get('aa_muts', {})})
193231

194232
# attach sequences to all nodes
195233
for node in tree.find_clades(order='preorder'):
196234
if node.up:
197235
node.sequences = {gene:muts.copy() for gene, muts in node.up.sequences.items()}
198-
for mut in all_muts[node.name]['muts']:
236+
for mut in all_muts.get(node.name, {}).get('muts', []):
199237
a, pos, d = mut[0], int(mut[1:-1])-1, mut[-1]
200238
node.sequences['nuc'][pos] = d
201-
if 'aa_muts' in all_muts[node.name]:
239+
if 'aa_muts' in all_muts.get(node.name, {}):
202240
for gene in all_muts[node.name]['aa_muts']:
203241
for mut in all_muts[node.name]['aa_muts'][gene]:
204242
a, pos, d = mut[0], int(mut[1:-1])-1, mut[-1]
@@ -208,7 +246,7 @@ def assign_clades(clade_designations, all_muts, tree, ref=None):
208246
node.sequences[gene][pos] = d
209247

210248

211-
# second pass to assign 'clade_annotation' to basal nodes within each clade
249+
# second pass to assign basal nodes within each clade to the clade_labels dict
212250
# if multiple nodes match, assign annotation to largest
213251
# otherwise occasional unwanted cousin nodes get assigned the annotation
214252
for clade_name, clade_alleles in clade_designations.items():
@@ -219,58 +257,100 @@ def assign_clades(clade_designations, all_muts, tree, ref=None):
219257
sorted_nodes = sorted(node_counts, key=lambda x: x.leaf_count, reverse=True)
220258
if len(sorted_nodes) > 0:
221259
target_node = sorted_nodes[0]
222-
clade_membership[target_node.name] = {'clade_annotation': clade_name, 'clade_membership': clade_name}
260+
clade_membership[target_node.name] = clade_name
261+
clade_labels[target_node.name] = clade_name
223262

224-
# third pass to propagate 'clade_membership'
225-
# don't propagate if encountering 'clade_annotation'
263+
# third pass to propagate clade_membership to descendant nodes
264+
# (until we encounter a node with its own clade_membership)
226265
for node in tree.find_clades(order = 'preorder'):
227266
for child in node:
228-
if 'clade_annotation' not in clade_membership[child.name]:
229-
clade_membership[child.name]['clade_membership'] = clade_membership[node.name]['clade_membership']
267+
if child.name not in clade_labels:
268+
clade_membership[child.name] = clade_membership[node.name]
230269

231-
return clade_membership
270+
return (clade_membership, clade_labels)
271+
272+
def warn_if_clades_not_found(membership, clade_designations):
    """
    Print a warning for each defined clade which no node was assigned to.

    Parameters
    ----------
    membership : dict
        mapping of node name to clade name, where nodes outside any clade
        carry the value UNASSIGNED
    clade_designations : dict
        clade definitions keyed by clade name

    If no node belongs to any clade a single summary warning is printed.
    Otherwise one line is printed per missing clade, following the order of
    `clade_designations` so that the output is reproducible between runs
    (iterating a set difference of strings is not).
    """
    found = {clade for clade in membership.values() if clade != UNASSIGNED}
    if not found:
        print("WARNING in augur.clades: no clades found in tree!")
        return
    for clade in clade_designations:
        if clade not in found:
            # warn loudly - one line per unfound clade
            print(f"WARNING in augur.clades: clade '{clade}' not found in tree!")
232281

233282

234283
def get_reference_sequence_from_root_node(all_muts, root_name):
    """
    Extracts the (nuc) sequence from the root node, if set, as well as
    the (aa) sequences. Returns a dictionary of {geneName: rootSequence}
    where rootSequence is a list and geneName may be 'nuc'.

    Parameters
    ----------
    all_muts : dict
        mapping of node name to node data. The keys 'muts', 'aa_muts',
        'sequence' and 'aa_sequences' are read where present.
    root_name : str
        name of the root node; it does not have to be present in `all_muts`

    Returns
    -------
    dict
        {geneName: list of states}. A gene (or 'nuc') is only included if
        mutations were reported for it on some node and the root node
        supplies the corresponding sequence.

    A single warning is printed (to stdout) naming every gene for which
    mutations exist but no root sequence was supplied.
    """
    ref = {}
    missing = []
    root_data = all_muts.get(root_name, {})

    # the presence of a single mutation will imply that the corresponding reference
    # sequence should be present, and we will warn if it is not
    nt_present = False
    genes_present = set()
    for node_data in all_muts.values():
        if "muts" in node_data:
            nt_present = True
        genes_present.update(node_data.get('aa_muts', {}))

    if nt_present:
        try:
            ref['nuc'] = list(root_data["sequence"])
        except KeyError:
            missing.append("nuc")

    # sorted so that the returned key order and the warning text are
    # reproducible (set iteration order of strings varies between runs)
    root_aa_sequences = root_data.get("aa_sequences", {})
    for gene in sorted(genes_present):
        try:
            ref[gene] = list(root_aa_sequences[gene])
        except KeyError:
            missing.append(gene)

    if missing:
        # implicit concatenation keeps the message independent of source indentation
        print(
            "WARNING in augur.clades: sequences at the root node have not been specified for "
            f"{{{', '.join(missing)}}}, "
            "even though mutations were observed. Clades which are annotated using bases/codons present at the root "
            "of the tree may not be correctly inferred."
        )

    return ref
249319

320+
def parse_nodes(tree_file, node_data_files):
    """
    Read the tree and the node-data JSONs and check they are consistent.

    Parameters
    ----------
    tree_file : str
        path to a Newick tree
    node_data_files : list of str
        paths to node-data JSON files

    Returns
    -------
    (tree, dict)
        [0]: the tree as returned by `Phylo.read`
        [1]: the 'nodes' section of the combined node data

    Raises
    ------
    AugurError
        if the node data has no (or an empty) 'nodes' section, or if it names
        nodes which are not in the tree. Any error raised by
        `ensure_no_multiple_mutations` is propagated as well.
    """
    tree = Phylo.read(tree_file, 'newick')
    # don't supply tree to read_node_data as we don't want to require that every node is present in the node_data JSONs
    node_data = read_node_data(node_data_files)
    # node_data files can be parsed without 'nodes' (if they have 'branches')
    if "nodes" not in node_data or not node_data['nodes']:
        raise AugurError(f"No nodes found in the supplied node data files. Please check {', '.join(node_data_files)}")
    json_nodes = set(node_data["nodes"])
    tree_nodes = {clade.name for clade in tree.find_clades()}
    unknown_nodes = json_nodes - tree_nodes
    if unknown_nodes:
        # sorted so the message is reproducible; set iteration order is not
        raise AugurError(f"The following nodes in the node_data files ({', '.join(node_data_files)}) are not found in the tree ({tree_file}): {', '.join(sorted(unknown_nodes))}")
    ensure_no_multiple_mutations(node_data['nodes'])
    return (tree, node_data['nodes'])
250333

251334
def register_parser(parent_subparsers):
    """
    Add the "clades" subcommand to `parent_subparsers`, declare its command
    line arguments and return the newly created parser.
    """
    parser = parent_subparsers.add_parser("clades", help=__doc__)
    add = parser.add_argument

    # inputs
    add(
        '--tree',
        required=True,
        help="prebuilt Newick -- no tree will be built if provided",
    )
    add(
        '--mutations',
        required=True,
        metavar="NODE_DATA_JSON",
        nargs='+',
        help='JSON(s) containing ancestral and tip nucleotide and/or amino-acid mutations ',
    )
    # still accepted, but hidden from the help output
    add('--reference', nargs='+', help=SUPPRESS)
    add(
        '--clades',
        required=True,
        metavar="TSV",
        type=str,
        help='TSV file containing clade definitions by amino-acid',
    )

    # outputs
    add(
        '--output-node-data',
        type=str,
        metavar="NODE_DATA_JSON",
        help='name of JSON file to save clade assignments to',
    )
    add(
        '--membership-name',
        type=str,
        default="clade_membership",
        help='Key to store clade membership under; use "None" to not export this',
    )
    add(
        '--label-name',
        type=str,
        default="clade",
        help='Key to store clade labels under; use "None" to not export this',
    )
    return parser
259344

260345

261346
def run(args):
262-
## read tree and data, if reading data fails, return with error code
263-
tree = Phylo.read(args.tree, 'newick')
264-
node_data = read_node_data(args.mutations, args.tree)
265-
if node_data is None:
266-
print("ERROR: could not read node data (incl sequences)")
267-
return 1
268-
all_muts = node_data['nodes']
347+
(tree, all_muts) = parse_nodes(args.tree, args.mutations)
269348

270349
if args.reference:
271350
# PLACE HOLDER FOR vcf WORKFLOW.
272351
# Works without a reference for now but can be added if clade defs contain positions
273352
# that are monomorphic across reference and sequence sample.
353+
print(f"WARNING in augur.clades: You have provided a --reference file(s) ({args.reference}) however this is functionality is not yet supported.")
274354
ref = None
275355
else:
276356
# extract reference sequences from the root node entry in the mutation json
@@ -279,8 +359,20 @@ def run(args):
279359

280360
clade_designations = read_in_clade_definitions(args.clades)
281361

282-
clade_membership = assign_clades(clade_designations, all_muts, tree, ref)
362+
membership, labels = assign_clades(clade_designations, all_muts, tree, ref)
363+
warn_if_clades_not_found(membership, clade_designations)
364+
365+
membership_key= args.membership_name if args.membership_name.upper() != "NONE" else None
366+
label_key= args.label_name if args.label_name.upper() != "NONE" else None
367+
368+
node_data_json = {}
369+
if membership_key:
370+
node_data_json['nodes'] = {node: {membership_key: clade} for node,clade in membership.items()}
371+
print(f"Clade membership stored on nodes → <node_name> → {membership_key}", file=sys.stdout)
372+
if label_key:
373+
node_data_json['branches'] = {node: {'labels': {label_key: label}} for node,label in labels.items()}
374+
print(f"Clade labels stored on branches → <node_name> → labels → {label_key}", file=sys.stdout)
283375

284376
out_name = get_json_name(args)
285-
write_json({'nodes': clade_membership}, out_name)
286-
print("clades written to", out_name, file=sys.stdout)
377+
write_json(node_data_json, out_name)
378+
print(f"Clades written to {out_name}", file=sys.stdout)

0 commit comments

Comments
 (0)