Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update snippets to use multiple language variants #1718

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion core/snippets/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ It's also possible to set configuration that applies to a specific tab stop (`$0

## Formatting and syntax highlighting

To get formatting and syntax highlighting for `.snippet` files install [andreas-talon](https://marketplace.visualstudio.com/items?itemName=AndreasArvidsson.andreas-talon)
To get formatting, code completion and syntax highlighting for `.snippet` files: install [andreas-talon](https://marketplace.visualstudio.com/items?itemName=AndreasArvidsson.andreas-talon)

## Examples

Expand Down
40 changes: 30 additions & 10 deletions core/snippets/snippet_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,24 @@
from dataclasses import dataclass

from talon import Context


class SnippetLists:
insertion: dict[str, str]
with_phrase: dict[str, str]
wrapper: dict[str, str]

def __init__(self):
self.insertion = {}
self.with_phrase = {}
self.wrapper = {}


@dataclass
class SnippetLanguageState:
ctx: Context
lists: SnippetLists


@dataclass
class SnippetVariable:
Expand All @@ -14,16 +33,15 @@ class Snippet:
name: str
body: str
description: str | None
phrases: list[str] | None = None
insertion_scopes: list[str] | None = None
languages: list[str] | None = None
variables: list[SnippetVariable] | None = None
phrases: list[str] | None
insertion_scopes: list[str] | None
languages: list[str] | None
variables: list[SnippetVariable]

def get_variable(self, name: str):
if self.variables:
for var in self.variables:
if var.name == name:
return var
for var in self.variables:
if var.name == name:
return var
return None

def get_variable_strict(self, name: str):
Expand All @@ -36,11 +54,13 @@ def get_variable_strict(self, name: str):
@dataclass
class InsertionSnippet:
body: str
scopes: list[str] | None = None
scopes: list[str] | None
languages: list[str] | None


@dataclass
class WrapperSnippet:
body: str
variable_name: str
scope: str | None = None
scope: str | None
languages: list[str] | None
240 changes: 139 additions & 101 deletions core/snippets/snippets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import glob
from collections import defaultdict
from pathlib import Path
from typing import Union

from talon import Context, Module, actions, app, fs, settings
from talon import Context, Module, actions, app, fs

from ..modes.code_languages import code_languages
from .snippet_types import InsertionSnippet, Snippet, WrapperSnippet
from .snippet_types import (
InsertionSnippet,
Snippet,
SnippetLanguageState,
SnippetLists,
WrapperSnippet,
)
from .snippets_parser import create_snippets_from_file

SNIPPETS_DIR = Path(__file__).parent / "snippets"
Expand All @@ -23,17 +29,22 @@
desc="Directory (relative to Talon user) containing additional snippets",
)

context_map = {
# `_` represents the global context, ie snippets available regardless of language
"_": Context(),
# `_` represents the global context, ie snippets available regardless of language
GLOBAL_ID = "_"

# { SNIPPET_NAME: Snippet[] }
snippets_map: dict[str, list[Snippet]] = {}

# { LANGUAGE_ID: SnippetLanguageState }
languages_state_map: dict[str, SnippetLanguageState] = {
GLOBAL_ID: SnippetLanguageState(Context(), SnippetLists())
}
snippets_map = {}

# Create a context for each defined language
for lang in code_languages:
ctx = Context()
ctx.matches = f"code.language: {lang.id}"
context_map[lang.id] = ctx
languages_state_map[lang.id] = SnippetLanguageState(ctx, SnippetLists())


def get_setting_dir():
Expand All @@ -52,128 +63,155 @@ def get_setting_dir():

@mod.action_class
class Actions:
def get_snippet(name: str) -> Snippet:
"""Get snippet named <name>"""
# Add current code language if not specified
if "." not in name:
lang = actions.code.language() or "_"
name = f"{lang}.{name}"

def get_snippets(name: str) -> list[Snippet]:
"""Get snippets named <name>"""
if name not in snippets_map:
raise ValueError(f"Unknown snippet '{name}'")

return snippets_map[name]

def get_snippet(name: str) -> Snippet:
"""Get snippet named <name> for the active language"""
snippets: list[Snippet] = actions.user.get_snippets(name)
return get_preferred_snippet(snippets)

def get_insertion_snippets(name: str) -> list[InsertionSnippet]:
"""Get insertion snippets named <name>"""
snippets: list[Snippet] = actions.user.get_snippets(name)
return [
InsertionSnippet(s.body, s.insertion_scopes, s.languages) for s in snippets
]

def get_insertion_snippet(name: str) -> InsertionSnippet:
"""Get insertion snippet named <name>"""
"""Get insertion snippet named <name> for the active language"""
snippet: Snippet = actions.user.get_snippet(name)
return InsertionSnippet(snippet.body, snippet.insertion_scopes)
return InsertionSnippet(
snippet.body,
snippet.insertion_scopes,
snippet.languages,
)

def get_wrapper_snippets(name: str) -> list[WrapperSnippet]:
"""Get wrapper snippets named <name>"""
snippet_name, variable_name = split_wrapper_snippet_name(name)
snippets: list[Snippet] = actions.user.get_snippets(snippet_name)
return [to_wrapper_snippet(s, variable_name) for s in snippets]

def get_wrapper_snippet(name: str) -> WrapperSnippet:
"""Get wrapper snippet named <name>"""
index = name.rindex(".")
snippet_name = name[:index]
variable_name = name[index + 1]
"""Get wrapper snippet named <name> for the active language"""
snippet_name, variable_name = split_wrapper_snippet_name(name)
snippet: Snippet = actions.user.get_snippet(snippet_name)
variable = snippet.get_variable_strict(variable_name)
return WrapperSnippet(snippet.body, variable.name, variable.wrapper_scope)
return to_wrapper_snippet(snippet, variable_name)


def update_snippets():
language_to_snippets = group_by_language(get_snippets())

snippets_map.clear()

for lang, ctx in context_map.items():
insertion_map = {}
insertions_phrase_map = {}
wrapper_map = {}

# Assign global snippets to all languages
for lang_super in ["_", lang]:
snippets, insertions, insertions_phrase, wrappers = create_lists(
lang,
lang_super,
language_to_snippets.get(lang_super, []),
)
snippets_map.update(snippets)
insertion_map.update(insertions)
insertions_phrase_map.update(insertions_phrase)
wrapper_map.update(wrappers)

ctx.lists.update(
{
"user.snippet": insertion_map,
"user.snippet_with_phrase": insertions_phrase_map,
"user.snippet_wrapper": wrapper_map,
}
)
def get_preferred_snippet(snippets: list[Snippet]) -> Snippet:
lang: Union[str, set[str]] = actions.code.language()
languages = [lang] if isinstance(lang, str) else lang

# First try to find a snippet matching the active language
for snippet in snippets:
if snippet.languages:
for snippet_lang in snippet.languages:
if snippet_lang in languages:
return snippet

def get_snippets() -> list[Snippet]:
files = glob.glob(f"{SNIPPETS_DIR}/**/*.snippet", recursive=True)
# Then look for a global snippet
for snippet in snippets:
if not snippet.languages:
return snippet

if get_setting_dir():
files.extend(glob.glob(f"{get_setting_dir()}/**/*.snippet", recursive=True))
raise ValueError(f"Snippet not available for language '{lang}'")

result = []

for file in files:
result.extend(create_snippets_from_file(file))
def split_wrapper_snippet_name(name: str) -> tuple[str, str]:
index = name.rindex(".")
return name[:index], name[index + 1]

return result

def to_wrapper_snippet(snippet: Snippet, variable_name) -> WrapperSnippet:
"""Get wrapper snippet named <name>"""
var = snippet.get_variable_strict(variable_name)
return WrapperSnippet(
snippet.body,
var.name,
var.wrapper_scope,
snippet.languages,
)


def group_by_language(snippets: list[Snippet]) -> dict[str, list[Snippet]]:
result = defaultdict(list)
def update_snippets():
global snippets_map

snippets = get_snippets_from_files()
name_to_snippets: dict[str, list[Snippet]] = {}
language_to_lists: dict[str, SnippetLists] = {}

for snippet in snippets:
if snippet.languages is not None:
for lang in snippet.languages:
result[lang].append(snippet)
else:
result["_"].append(snippet)
return result
# Map snippet names to actual snippets
name_to_snippets.setdefault(snippet.name, [])
name_to_snippets[snippet.name].append(snippet)
AndreasArvidsson marked this conversation as resolved.
Show resolved Hide resolved

# Map languages to phrase / name dicts
for language in snippet.languages or [GLOBAL_ID]:
language_to_lists.setdefault(language, SnippetLists())
lists = language_to_lists[language]
AndreasArvidsson marked this conversation as resolved.
Show resolved Hide resolved

def create_lists(
lang_ctx: str,
lang_snippets: str,
snippets: list[Snippet],
) -> tuple[dict[str, list[Snippet]], dict[str, str], dict[str, str], dict[str, str]]:
"""Creates the lists for the given language, and returns them as a tuple of (snippets, insertions, insertions_phrase, wrappers)
for phrase in snippet.phrases or []:
lists.insertion[phrase] = snippet.name

Args:
lang_ctx (str): The language of the context match
lang_snippets (str): The language of the snippets
snippets (list[Snippet]): The list of snippets for the given language
"""
snippets_map = {}
insertions = {}
insertions_phrase = {}
wrappers = {}
for var in snippet.variables:
if var.insertion_formatters:
lists.with_phrase[phrase] = snippet.name

for snippet in snippets:
id_ctx = f"{lang_ctx}.{snippet.name}"
id_lang = f"{lang_snippets}.{snippet.name}"
if var.wrapper_phrases:
lists.wrapper[phrase] = f"{snippet.name}.{var.name}"

snippets_map = name_to_snippets
update_contexts(language_to_lists)


def update_contexts(language_to_lists: dict[str, SnippetLists]):
global_lists = language_to_lists[GLOBAL_ID] or SnippetLists()

for lang, lists in language_to_lists.items():
if lang not in languages_state_map:
print(f"Found snippets for unknown language: {lang}")
AndreasArvidsson marked this conversation as resolved.
Show resolved Hide resolved
actions.app.notify(f"Found snippets for unknown language: {lang}")
continue

state = languages_state_map[lang]
insertion = {**global_lists.insertion, **lists.insertion}
with_phrase = {**global_lists.with_phrase, **lists.with_phrase}
wrapper = {**global_lists.wrapper, **lists.wrapper}
updated_lists: dict[str, dict[str, str]] = {}

if state.lists.insertion != insertion:
state.lists.insertion = insertion
updated_lists["user.snippet"] = insertion

# Make sure that the snippet is added to the map for the context language
snippets_map[id_ctx] = snippet
if state.lists.with_phrase != with_phrase:
state.lists.with_phrase = with_phrase
updated_lists["user.snippet_with_phrase"] = with_phrase

if snippet.phrases is not None:
for phrase in snippet.phrases:
insertions[phrase] = id_lang
if state.lists.wrapper != wrapper:
state.lists.wrapper = wrapper
updated_lists["user.snippet_wrapper"] = wrapper

if snippet.variables is not None:
for var in snippet.variables:
if var.insertion_formatters is not None and snippet.phrases is not None:
for phrase in snippet.phrases:
insertions_phrase[phrase] = id_lang
if updated_lists:
state.ctx.lists.update(updated_lists)

if var.wrapper_phrases is not None:
for phrase in var.wrapper_phrases:
wrappers[phrase] = f"{id_lang}.{var.name}"

return snippets_map, insertions, insertions_phrase, wrappers
def get_snippets_from_files() -> list[Snippet]:
files = glob.glob(f"{SNIPPETS_DIR}/**/*.snippet", recursive=True)

if get_setting_dir():
files.extend(glob.glob(f"{get_setting_dir()}/**/*.snippet", recursive=True))
nriley marked this conversation as resolved.
Show resolved Hide resolved

result = []

for file in files:
result.extend(create_snippets_from_file(file))

return result


def on_ready():
Expand Down
5 changes: 4 additions & 1 deletion core/snippets/snippets_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ def insert_snippet(body: str):
"""Insert snippet"""
insert_snippet_raw_text(body)

def insert_snippet_by_name(name: str, substitutions: dict[str, str] = None):
def insert_snippet_by_name(
name: str,
substitutions: dict[str, str] = None,
):
"""Insert snippet <name>"""
snippet: Snippet = actions.user.get_snippet(name)
body = snippet.body
Expand Down