diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index bec6658..73e429d 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -10,7 +10,7 @@ Please: - [ ] Provide a complete example of how to reproduce the bug, wrapped in triple backticks like this: ```python -import brainscale +import braintrace ``` - [ ] If applicable, include full error messages/tracebacks. \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index d48ac9b..5cf95c4 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -1,6 +1,6 @@ --- name: 'Feature Request' -about: 'Suggest a new idea or improvement for BrainScale' +about: 'Suggest a new idea or improvement for BrainTrace' labels: 'enhancement' --- diff --git a/.github/workflows/CI-daily.yml b/.github/workflows/CI-daily.yml index de4b13d..83408cf 100644 --- a/.github/workflows/CI-daily.yml +++ b/.github/workflows/CI-daily.yml @@ -69,5 +69,5 @@ jobs: python -m pip install . || exit 1 - name: Test with pytest run: | - pytest brainscale/ + pytest braintrace/ diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index fcc983d..bd4e567 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -50,7 +50,7 @@ jobs: pip install . - name: Test with pytest run: | - pytest brainscale/ + pytest braintrace/ test_macos: @@ -80,7 +80,7 @@ jobs: pip install . - name: Test with pytest run: | - pytest brainscale/ + pytest braintrace/ test_windows: @@ -110,5 +110,5 @@ jobs: pip install . - name: Test with pytest run: | - pytest brainscale/ -p no:faulthandler + pytest braintrace/ -p no:faulthandler diff --git a/.github/workflows/Publish.yml b/.github/workflows/Publish.yml index aeb4e23..022cdfe 100644 --- a/.github/workflows/Publish.yml +++ b/.github/workflows/Publish.yml @@ -39,7 +39,7 @@ jobs: echo "Release tag: $TAG" PKG_VERSION=$(python - <<'PY' import re, pathlib - init = pathlib.Path('brainscale/__init__.py').read_text(encoding='utf-8') + init = pathlib.Path('braintrace/__init__.py').read_text(encoding='utf-8') m = re.search(r"__version__\s*=\s*\"([^\"]+)\"", init) print(m.group(1) if m else '') PY @@ -47,7 +47,7 @@ jobs: echo "Package version: ${PKG_VERSION}" TAG_STRIPPED=${TAG#v} if [[ -z "$PKG_VERSION" ]]; then - echo "Could not determine package version from brainscale/__init__.py" >&2 + echo "Could not determine package version from braintrace/__init__.py" >&2 exit 1 fi if [[ "$TAG_STRIPPED" != "$PKG_VERSION" ]]; then diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 09b4ed7..fcf269f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ -# Contributing to BrainScale +# Contributing to BrainTrace -We welcome contributions to BrainScale! +We welcome contributions to BrainTrace! Please fork this repository and submit a pull request with your changes. We will review your changes and merge them if they are appropriate. diff --git a/README.md b/README.md index 6006946..02b88c6 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,30 @@ -

BrainScale

-

Scalable Online Learning for Brain Dynamics

+

BrainTrace

+

Eligibility Trace-based Online Learning for Brain Dynamics

- Header image of brainscale. + Header image of braintrace.

- Supported Python Version - LICENSE - Documentation - PyPI version - Continuous Integration + Supported Python Version + LICENSE + Documentation + PyPI version + Continuous Integration

-[``brainscale``](https://github.com/chaobrain/brainscale) provides online learning algorithms for biological neural networks. +[``braintrace``](https://github.com/chaobrain/braintrace) provides online learning algorithms for biological neural networks. It has been integrated into our establishing [brain modeling ecosystem](https://brainmodeling.readthedocs.io/). ## Installation -``brainscale`` can run on Python 3.10+ installed on Linux, MacOS, and Windows. You can install ``brainscale`` via pip: +``braintrace`` can run on Python 3.10+ installed on Linux, MacOS, and Windows. You can install ``braintrace`` via pip: ```bash -pip install brainscale --upgrade +pip install braintrace --upgrade ``` -Alternatively, you can install `BrainX`, which bundles `brainscale` with other compatible packages for a comprehensive brain modeling ecosystem: +Alternatively, you can install `BrainX`, which bundles `braintrace` with other compatible packages for a comprehensive brain modeling ecosystem: ```bash pip install BrainX -U @@ -32,26 +32,40 @@ pip install BrainX -U ## Documentation -The official documentation is hosted on Read the Docs: [https://brainscale.readthedocs.io](https://brainscale.readthedocs.io) +The official documentation is hosted on Read the Docs: [https://braintrace.readthedocs.io](https://braintrace.readthedocs.io) -## Citation +[//]: # (## Citation) -If you use this package in your research, please cite: +[//]: # () +[//]: # (If you use this package in your research, please cite:) -```bibtex -@article {Wang2024.09.24.614728, - author = {Wang, Chaoming and Dong, Xingsi and Ji, Zilong and Jiang, Jiedong and Liu, Xiao and Wu, Si}, - title = {BrainScale: Enabling Scalable Online Learning in Spiking Neural Networks}, - elocation-id = {2024.09.24.614728}, - year = {2025}, - doi = {10.1101/2024.09.24.614728}, - publisher = {Cold Spring Harbor Laboratory}, - URL = {https://www.biorxiv.org/content/early/2025/07/27/2024.09.24.614728}, - eprint = {https://www.biorxiv.org/content/early/2025/07/27/2024.09.24.614728.full.pdf}, - journal = {bioRxiv} -} -``` +[//]: # () +[//]: # (```bibtex) + +[//]: # (@article {Wang2024.09.24.614728,) + +[//]: # ( author = {Wang, Chaoming and Dong, Xingsi and Ji, Zilong and Jiang, Jiedong and Liu, Xiao and Wu, Si},) + +[//]: # ( title = {Enabling Scalable Online Learning in Spiking Neural Networks},) + +[//]: # ( elocation-id = {2024.09.24.614728},) + +[//]: # ( year = {2025},) + +[//]: # ( doi = {10.1101/2024.09.24.614728},) + +[//]: # ( publisher = {Cold Spring Harbor Laboratory},) + +[//]: # ( URL = {https://www.biorxiv.org/content/early/2025/07/27/2024.09.24.614728},) + +[//]: # ( eprint = {https://www.biorxiv.org/content/early/2025/07/27/2024.09.24.614728.full.pdf},) + +[//]: # ( journal = {bioRxiv}) + +[//]: # (}) + +[//]: # (```) ## See also the ecosystem -``brainscale`` is one part of our brain simulation ecosystem: https://brainmodeling.readthedocs.io/ +``braintrace`` is one part of our brain simulation ecosystem: https://brainmodeling.readthedocs.io/ diff --git a/brainscale/__init__.py b/braintrace/__init__.py similarity index 54% rename from brainscale/__init__.py rename to braintrace/__init__.py index 5b32ec7..801bf3f 100644 --- a/brainscale/__init__.py +++ b/braintrace/__init__.py @@ -19,32 +19,32 @@ __version__ = "0.1.0" __version_info__ = (0, 1, 0) -from brainscale._etrace_algorithms import * -from brainscale._etrace_algorithms import __all__ as _alg_all -from brainscale._etrace_compiler_graph import * -from brainscale._etrace_compiler_graph import 
__all__ as _compiler_all -from brainscale._etrace_compiler_hid_param_op import * -from brainscale._etrace_compiler_hid_param_op import __all__ as _hid_param_all -from brainscale._etrace_compiler_hidden_group import * -from brainscale._etrace_compiler_hidden_group import __all__ as _hid_group_all -from brainscale._etrace_compiler_hidden_pertubation import * -from brainscale._etrace_compiler_hidden_pertubation import __all__ as _hid_pertub_all -from brainscale._etrace_compiler_module_info import * -from brainscale._etrace_compiler_module_info import __all__ as _mod_info_all -from brainscale._etrace_concepts import * -from brainscale._etrace_concepts import __all__ as _con_all -from brainscale._etrace_graph_executor import * -from brainscale._etrace_graph_executor import __all__ as _exec_all -from brainscale._etrace_input_data import * -from brainscale._etrace_input_data import __all__ as _data_all -from brainscale._etrace_operators import * -from brainscale._etrace_operators import __all__ as _op_all -from brainscale._etrace_vjp import * -from brainscale._etrace_vjp import __all__ as _vjp_all -from brainscale._grad_exponential import * -from brainscale._grad_exponential import __all__ as _grad_exp_all -from brainscale._misc import * -from brainscale._misc import __all__ as _misc_all +from braintrace._etrace_algorithms import * +from braintrace._etrace_algorithms import __all__ as _alg_all +from braintrace._etrace_compiler_graph import * +from braintrace._etrace_compiler_graph import __all__ as _compiler_all +from braintrace._etrace_compiler_hid_param_op import * +from braintrace._etrace_compiler_hid_param_op import __all__ as _hid_param_all +from braintrace._etrace_compiler_hidden_group import * +from braintrace._etrace_compiler_hidden_group import __all__ as _hid_group_all +from braintrace._etrace_compiler_hidden_pertubation import * +from braintrace._etrace_compiler_hidden_pertubation import __all__ as _hid_pertub_all +from braintrace._etrace_compiler_module_info import * +from braintrace._etrace_compiler_module_info import __all__ as _mod_info_all +from braintrace._etrace_concepts import * +from braintrace._etrace_concepts import __all__ as _con_all +from braintrace._etrace_graph_executor import * +from braintrace._etrace_graph_executor import __all__ as _exec_all +from braintrace._etrace_input_data import * +from braintrace._etrace_input_data import __all__ as _data_all +from braintrace._etrace_operators import * +from braintrace._etrace_operators import __all__ as _op_all +from braintrace._etrace_vjp import * +from braintrace._etrace_vjp import __all__ as _vjp_all +from braintrace._grad_exponential import * +from braintrace._grad_exponential import __all__ as _grad_exp_all +from braintrace._misc import * +from braintrace._misc import __all__ as _misc_all from . import nn __all__ = ['nn'] + _alg_all + _compiler_all + _hid_param_all + _hid_group_all + _hid_pertub_all @@ -69,7 +69,7 @@ def __getattr__(name): import brainstate warnings.warn( - f"brainscale.{name} is deprecated and will be removed in a future release. " + f"braintrace.{name} is deprecated and will be removed in a future release. 
" f"Please use brainstate.{mapping[name]} instead.", DeprecationWarning, stacklevel=2, diff --git a/brainscale/_compatible_imports.py b/braintrace/_compatible_imports.py similarity index 100% rename from brainscale/_compatible_imports.py rename to braintrace/_compatible_imports.py diff --git a/brainscale/_compatible_imports_test.py b/braintrace/_compatible_imports_test.py similarity index 98% rename from brainscale/_compatible_imports_test.py rename to braintrace/_compatible_imports_test.py index 74f0bae..74eb9d6 100644 --- a/brainscale/_compatible_imports_test.py +++ b/braintrace/_compatible_imports_test.py @@ -16,7 +16,7 @@ import jax.numpy as jnp from jax import jit, make_jaxpr, lax -from brainscale._compatible_imports import ( +from braintrace._compatible_imports import ( is_jit_primitive, is_scan_primitive, is_while_primitive, is_cond_primitive ) diff --git a/brainscale/_etrace_algorithms.py b/braintrace/_etrace_algorithms.py similarity index 98% rename from brainscale/_etrace_algorithms.py rename to braintrace/_etrace_algorithms.py index 231b79c..83d029a 100644 --- a/brainscale/_etrace_algorithms.py +++ b/braintrace/_etrace_algorithms.py @@ -40,7 +40,7 @@ class EligibilityTrace(brainstate.ShortTermState): Examples -------- - When you are using :class:`brainscale.IODimVjpAlgorithm`, you can get + When you are using :class:`braintrace.IODimVjpAlgorithm`, you can get the eligibility trace of the weight by calling: .. code-block:: python @@ -48,7 +48,7 @@ class EligibilityTrace(brainstate.ShortTermState): >>> etrace = etrace_algorithm.etrace_of(weight) """ - __module__ = 'brainscale' + __module__ = 'braintrace' class ETraceAlgorithm(brainstate.nn.Module): @@ -77,7 +77,7 @@ class ETraceAlgorithm(brainstate.nn.Module): running_index : brainstate.ParamState[int] The running index. """ - __module__ = 'brainscale' + __module__ = 'braintrace' def __init__( self, diff --git a/brainscale/_etrace_compiler_base.py b/braintrace/_etrace_compiler_base.py similarity index 99% rename from brainscale/_etrace_compiler_base.py rename to braintrace/_etrace_compiler_base.py index c60979f..3318b28 100644 --- a/brainscale/_etrace_compiler_base.py +++ b/braintrace/_etrace_compiler_base.py @@ -168,7 +168,7 @@ class JaxprEvaluation(object): outvar_to_hidden_path : Dict[Var, Path] Stored mapping from output variables to hidden paths. 
""" - __module__ = 'brainscale' + __module__ = 'braintrace' def __init__( self, diff --git a/brainscale/_etrace_compiler_graph.py b/braintrace/_etrace_compiler_graph.py similarity index 98% rename from brainscale/_etrace_compiler_graph.py rename to braintrace/_etrace_compiler_graph.py index 9211aee..eb0b0fb 100644 --- a/brainscale/_etrace_compiler_graph.py +++ b/braintrace/_etrace_compiler_graph.py @@ -86,12 +86,12 @@ class ETraceGraph(NamedTuple): Example:: - >>> import brainscale + >>> import braintrace >>> import brainstate - >>> gru = brainscale.nn.GRUCell(10, 20) + >>> gru = braintrace.nn.GRUCell(10, 20) >>> gru.init_state() >>> inputs = brainstate.random.randn(10) - >>> compiled_graph = brainscale.compile_etrace_graph(gru, inputs) + >>> compiled_graph = braintrace.compile_etrace_graph(gru, inputs) >>> compiled_graph.dict().keys() """ @@ -127,7 +127,7 @@ def __repr__(self) -> str: return repr(brainstate.util.PrettyMapping(self._asdict(), type_name=self.__class__.__name__)) -ETraceGraph.__module__ = 'brainscale' +ETraceGraph.__module__ = 'braintrace' class CONTEXT(threading.local): diff --git a/brainscale/_etrace_compiler_graph_test.py b/braintrace/_etrace_compiler_graph_test.py similarity index 84% rename from brainscale/_etrace_compiler_graph_test.py rename to braintrace/_etrace_compiler_graph_test.py index f80617f..ee7716c 100644 --- a/brainscale/_etrace_compiler_graph_test.py +++ b/braintrace/_etrace_compiler_graph_test.py @@ -20,8 +20,8 @@ import brainstate import brainunit as u -import brainscale -from brainscale._etrace_model_test import ( +import braintrace +from braintrace._etrace_model_test import ( IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, ALIF_ExpCo_Dense_Layer, @@ -40,13 +40,13 @@ def test_gru_one_layer(self): n_in = 3 n_out = 4 - gru = brainscale.nn.GRUCell(n_in, n_out) + gru = braintrace.nn.GRUCell(n_in, n_out) brainstate.nn.init_all_states(gru) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(gru, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(gru, input, include_hidden_perturb=False) - self.assertTrue(isinstance(graph, brainscale.ETraceGraph)) + self.assertTrue(isinstance(graph, braintrace.ETraceGraph)) self.assertTrue(graph.module_info.num_var_out == 1) self.assertTrue(len(graph.module_info.compiled_model_states) == 4) self.assertTrue(len(graph.hidden_groups) == 1) @@ -61,11 +61,11 @@ def test_lru_one_layer(self): n_in = 3 n_out = 4 - lru = brainscale.nn.LRUCell(n_in, n_out) + lru = braintrace.nn.LRUCell(n_in, n_out) brainstate.nn.init_all_states(lru) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(lru, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(lru, input, include_hidden_perturb=False) self.assertTrue(len(graph.hidden_groups) == 1) self.assertTrue(len(graph.hidden_groups[0].hidden_paths) == 2) @@ -80,13 +80,13 @@ def test_lstm_one_layer(self): n_in = 3 n_out = 4 - lstm = brainscale.nn.LSTMCell(n_in, n_out) + lstm = braintrace.nn.LSTMCell(n_in, n_out) brainstate.nn.init_all_states(lstm) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(lstm, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(lstm, input, include_hidden_perturb=False) - self.assertTrue(isinstance(graph, brainscale.ETraceGraph)) + self.assertTrue(isinstance(graph, braintrace.ETraceGraph)) self.assertTrue(graph.module_info.num_var_out == 1) self.assertTrue(len(graph.hidden_groups) == 1) 
self.assertTrue(len(graph.hidden_groups[0].hidden_paths) == 2) @@ -112,16 +112,16 @@ def test_lstm_two_layers(self): n_out = 4 net = brainstate.nn.Sequential( - brainscale.nn.LSTMCell(n_in, n_out), + braintrace.nn.LSTMCell(n_in, n_out), brainstate.nn.ReLU(), - brainscale.nn.LSTMCell(n_out, n_in), + braintrace.nn.LSTMCell(n_out, n_in), ) brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) - self.assertTrue(isinstance(graph, brainscale.ETraceGraph)) + self.assertTrue(isinstance(graph, braintrace.ETraceGraph)) self.assertTrue(graph.module_info.num_var_out == 1) self.assertTrue(len(graph.hidden_groups) == 2) self.assertTrue(len(graph.hidden_groups[0].hidden_paths) == 2) @@ -145,14 +145,14 @@ def test_lru_two_layers(self): n_out = 4 net = brainstate.nn.Sequential( - brainscale.nn.LRUCell(n_in, n_out), + braintrace.nn.LRUCell(n_in, n_out), brainstate.nn.ReLU(), - brainscale.nn.LRUCell(n_in, n_out), + braintrace.nn.LRUCell(n_in, n_out), ) brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) self.assertTrue(len(graph.hidden_groups) == 2) self.assertTrue(len(graph.hidden_groups[0].hidden_paths) == 2) @@ -173,14 +173,14 @@ def test_lru_two_layers_v2(self): n_out = 4 net = brainstate.nn.Sequential( - brainscale.nn.LRUCell(n_in, n_out), + braintrace.nn.LRUCell(n_in, n_out), brainstate.nn.ReLU(), - brainscale.nn.LRUCell(n_in, n_out), + braintrace.nn.LRUCell(n_in, n_out), ) brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) self.assertTrue(len(graph.hidden_groups) == 2) self.assertTrue(len(graph.hidden_groups[0].hidden_paths) == 2) @@ -211,7 +211,7 @@ def test_if_delta_dense(self): brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) pprint(graph) pass @@ -224,7 +224,7 @@ def test_lif_expco_dense_layer(self): brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) pprint(graph) @@ -236,7 +236,7 @@ def test_alif_expco_dense_layer(self): brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) pprint(graph) @@ -248,7 +248,7 @@ def test_lif_expcu_dense_layer(self): brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) pprint(graph) @@ -260,7 +260,7 @@ def test_lif_std_expcu_dense_layer(self): brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, 
include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) pprint(graph) @@ -272,7 +272,7 @@ def test_lif_stp_expcu_dense_layer(self): brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) pprint(graph) @@ -284,7 +284,7 @@ def test_alif_expcu_dense_layer(self): brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) pprint(graph) @@ -296,7 +296,7 @@ def test_alif_delta_dense_layer(self): brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) pprint(graph) @@ -308,7 +308,7 @@ def test_alif_std_expcu_dense_layer(self): brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) pprint(graph) @@ -320,7 +320,7 @@ def test_alif_stp_expcu_dense_layer(self): brainstate.nn.init_all_states(net) input = brainstate.random.rand(n_in) - graph = brainscale.compile_etrace_graph(net, input, include_hidden_perturb=False) + graph = braintrace.compile_etrace_graph(net, input, include_hidden_perturb=False) pprint(graph) diff --git a/brainscale/_etrace_compiler_hid_param_op.py b/braintrace/_etrace_compiler_hid_param_op.py similarity index 99% rename from brainscale/_etrace_compiler_hid_param_op.py rename to braintrace/_etrace_compiler_hid_param_op.py index c42b06c..92892c0 100644 --- a/brainscale/_etrace_compiler_hid_param_op.py +++ b/braintrace/_etrace_compiler_hid_param_op.py @@ -116,12 +116,12 @@ class HiddenParamOpRelation(NamedTuple): Example:: - >>> import brainscale + >>> import braintrace >>> import brainstate - >>> gru = brainscale.nn.GRUCell(10, 20) + >>> gru = braintrace.nn.GRUCell(10, 20) >>> gru.init_state() >>> inputs = brainstate.random.randn(10) - >>> hpo_relations = brainscale.find_hidden_param_op_relations_from_module(gru, inputs) + >>> hpo_relations = braintrace.find_hidden_param_op_relations_from_module(gru, inputs) >>> for relation in hpo_relations: ... print(relation) """ @@ -172,7 +172,7 @@ def __repr__(self) -> str: return repr(brainstate.util.PrettyMapping(self._asdict(), type_name=self.__class__.__name__)) -HiddenParamOpRelation.__module__ = 'brainscale' +HiddenParamOpRelation.__module__ = 'braintrace' def _trace_simplify( @@ -441,7 +441,7 @@ class JaxprEvalForWeightOpHiddenRelation(JaxprEvaluation): Returns: The list of the traced weight operations. 
""" - __module__ = 'brainscale' + __module__ = 'braintrace' def __init__( self, diff --git a/brainscale/_etrace_compiler_hid_param_op_test.py b/braintrace/_etrace_compiler_hid_param_op_test.py similarity index 94% rename from brainscale/_etrace_compiler_hid_param_op_test.py rename to braintrace/_etrace_compiler_hid_param_op_test.py index 6adc21a..04d2869 100644 --- a/brainscale/_etrace_compiler_hid_param_op_test.py +++ b/braintrace/_etrace_compiler_hid_param_op_test.py @@ -20,9 +20,9 @@ import brainunit as u import pytest -import brainscale -from brainscale import find_hidden_param_op_relations_from_module -from brainscale._etrace_model_test import ( +import braintrace +from braintrace import find_hidden_param_op_relations_from_module +from braintrace._etrace_model_test import ( IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, ALIF_ExpCo_Dense_Layer, @@ -41,7 +41,7 @@ def test_gru_one_layer(self): n_in = 3 n_out = 4 - gru = brainscale.nn.GRUCell(n_in, n_out) + gru = braintrace.nn.GRUCell(n_in, n_out) brainstate.nn.init_all_states(gru) input = brainstate.random.rand(n_in) diff --git a/brainscale/_etrace_compiler_hidden_group.py b/braintrace/_etrace_compiler_hidden_group.py similarity index 99% rename from brainscale/_etrace_compiler_hidden_group.py rename to braintrace/_etrace_compiler_hidden_group.py index 059a98b..9338654 100644 --- a/brainscale/_etrace_compiler_hidden_group.py +++ b/braintrace/_etrace_compiler_hidden_group.py @@ -88,12 +88,12 @@ class HiddenGroup(NamedTuple): Example:: - >>> import brainscale + >>> import braintrace >>> import brainstate - >>> gru = brainscale.nn.GRUCell(10, 20) + >>> gru = braintrace.nn.GRUCell(10, 20) >>> gru.init_state() >>> inputs = brainstate.random.randn(10) - >>> hidden_groups, _ = brainscale.find_hidden_groups_from_module(gru, inputs) + >>> hidden_groups, _ = braintrace.find_hidden_groups_from_module(gru, inputs) >>> for group in hidden_groups: ... print(group.hidden_paths) """ @@ -244,7 +244,7 @@ def __repr__(self) -> str: return repr(brainstate.util.PrettyMapping(self._asdict(), type_name=self.__class__.__name__)) -HiddenGroup.__module__ = 'brainscale' +HiddenGroup.__module__ = 'braintrace' def jacrev_last_dim( @@ -491,7 +491,7 @@ class JaxprEvalForHiddenGroup(JaxprEvaluation): outvar_to_hidden_path: The mapping from the hidden output variable to the hidden state path. path_to_state: The mapping from the hidden state path to the state. 
""" - __module__ = 'brainscale' + __module__ = 'braintrace' def __init__( self, diff --git a/brainscale/_etrace_compiler_hidden_group_test.py b/braintrace/_etrace_compiler_hidden_group_test.py similarity index 95% rename from brainscale/_etrace_compiler_hidden_group_test.py rename to braintrace/_etrace_compiler_hidden_group_test.py index 7408d07..0515b59 100644 --- a/brainscale/_etrace_compiler_hidden_group_test.py +++ b/braintrace/_etrace_compiler_hidden_group_test.py @@ -23,11 +23,11 @@ import numpy as np import pytest -import brainscale -from brainscale import _etrace_model_with_group_state as group_etrace_model -from brainscale._etrace_compiler_hidden_group import find_hidden_groups_from_module -from brainscale._etrace_compiler_hidden_group import group_merging -from brainscale._etrace_model_test import ( +import braintrace +from braintrace import _etrace_model_with_group_state as group_etrace_model +from braintrace._etrace_compiler_hidden_group import find_hidden_groups_from_module +from braintrace._etrace_compiler_hidden_group import group_merging +from braintrace._etrace_model_test import ( IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, ALIF_ExpCo_Dense_Layer, @@ -122,11 +122,11 @@ class Test_find_hidden_groups_from_module: @pytest.mark.parametrize( "cls", [ - brainscale.nn.GRUCell, - brainscale.nn.LSTMCell, - brainscale.nn.LRUCell, - brainscale.nn.MGUCell, - brainscale.nn.MinimalRNNCell, + braintrace.nn.GRUCell, + braintrace.nn.LSTMCell, + braintrace.nn.LRUCell, + braintrace.nn.MGUCell, + braintrace.nn.MinimalRNNCell, ] ) def test_gru_one_layer(self, cls): @@ -277,8 +277,8 @@ def test_snn_single_layer_state_transition( brainstate.nn.init_all_states(layer_without_group) brainstate.nn.init_all_states(layer_with_group) - graph_without_group = brainscale.compile_etrace_graph(layer_without_group, input) - graph_with_group = brainscale.compile_etrace_graph(layer_with_group, input) + graph_without_group = braintrace.compile_etrace_graph(layer_without_group, input) + graph_with_group = braintrace.compile_etrace_graph(layer_with_group, input) out1, etrace1, other1, temp1 = graph_with_group.module_info.jaxpr_call(input) out2, etrace2, other2, temp2 = graph_without_group.module_info.jaxpr_call(input) @@ -345,8 +345,8 @@ def test_snn_two_layer_state_transition( brainstate.nn.init_all_states(layer_without_group) brainstate.nn.init_all_states(layer_with_group) - graph_without_group = brainscale.compile_etrace_graph(layer_without_group, input) - graph_with_group = brainscale.compile_etrace_graph(layer_with_group, input) + graph_without_group = braintrace.compile_etrace_graph(layer_without_group, input) + graph_with_group = braintrace.compile_etrace_graph(layer_with_group, input) out1, etrace1, other1, temp1 = graph_with_group.module_info.jaxpr_call(input) out2, etrace2, other2, temp2 = graph_without_group.module_info.jaxpr_call(input) @@ -414,8 +414,8 @@ def test_snn_single_layer_diagonal_jacobian( brainstate.nn.init_all_states(layer_without_group) brainstate.nn.init_all_states(layer_with_group) - graph_without_group = brainscale.compile_etrace_graph(layer_without_group, input) - graph_with_group = brainscale.compile_etrace_graph(layer_with_group, input) + graph_without_group = braintrace.compile_etrace_graph(layer_without_group, input) + graph_with_group = braintrace.compile_etrace_graph(layer_with_group, input) out1, etrace1, other1, temp1 = graph_with_group.module_info.jaxpr_call(input) out2, etrace2, other2, temp2 = graph_without_group.module_info.jaxpr_call(input) @@ -482,8 +482,8 @@ def 
test_snn_two_layer_diagonal_jacobian( brainstate.nn.init_all_states(layer_without_group) brainstate.nn.init_all_states(layer_with_group) - graph_without_group = brainscale.compile_etrace_graph(layer_without_group, input) - graph_with_group = brainscale.compile_etrace_graph(layer_with_group, input) + graph_without_group = braintrace.compile_etrace_graph(layer_without_group, input) + graph_with_group = braintrace.compile_etrace_graph(layer_with_group, input) out1, etrace1, other1, temp1 = graph_with_group.module_info.jaxpr_call(input) out2, etrace2, other2, temp2 = graph_without_group.module_info.jaxpr_call(input) @@ -530,11 +530,11 @@ class TestHiddenGroup_state_transition: @pytest.mark.parametrize( "cls", [ - brainscale.nn.GRUCell, - brainscale.nn.LSTMCell, - brainscale.nn.LRUCell, - brainscale.nn.MGUCell, - brainscale.nn.MinimalRNNCell, + braintrace.nn.GRUCell, + braintrace.nn.LSTMCell, + braintrace.nn.LRUCell, + braintrace.nn.MGUCell, + braintrace.nn.MinimalRNNCell, ] ) def test_gru(self, cls): @@ -629,11 +629,11 @@ class TestHiddenGroup_diagonal_jacobian: @pytest.mark.parametrize( "cls", [ - brainscale.nn.GRUCell, - brainscale.nn.LSTMCell, - brainscale.nn.LRUCell, - brainscale.nn.MGUCell, - brainscale.nn.MinimalRNNCell, + braintrace.nn.GRUCell, + braintrace.nn.LSTMCell, + braintrace.nn.LRUCell, + braintrace.nn.MGUCell, + braintrace.nn.MinimalRNNCell, ] ) def test_gru(self, cls): @@ -656,11 +656,11 @@ def test_gru(self, cls): @pytest.mark.parametrize( "cls", [ - brainscale.nn.GRUCell, - brainscale.nn.LSTMCell, - brainscale.nn.LRUCell, - brainscale.nn.MGUCell, - brainscale.nn.MinimalRNNCell, + braintrace.nn.GRUCell, + braintrace.nn.LSTMCell, + braintrace.nn.LRUCell, + braintrace.nn.MGUCell, + braintrace.nn.MinimalRNNCell, ] ) def test_gru_accuracy(self, cls): diff --git a/brainscale/_etrace_compiler_hidden_pertubation.py b/braintrace/_etrace_compiler_hidden_pertubation.py similarity index 98% rename from brainscale/_etrace_compiler_hidden_pertubation.py rename to braintrace/_etrace_compiler_hidden_pertubation.py index 8a8ccb3..d60e293 100644 --- a/brainscale/_etrace_compiler_hidden_pertubation.py +++ b/braintrace/_etrace_compiler_hidden_pertubation.py @@ -83,12 +83,12 @@ class HiddenPerturbation(NamedTuple): Example:: - >>> import brainscale + >>> import braintrace >>> import brainstate - >>> gru = brainscale.nn.GRUCell(10, 20) + >>> gru = braintrace.nn.GRUCell(10, 20) >>> gru.init_state() >>> inputs = brainstate.random.randn(10) - >>> hidden_perturb = brainscale.add_hidden_perturbation_in_module(gru, inputs) + >>> hidden_perturb = braintrace.add_hidden_perturbation_in_module(gru, inputs) """ @@ -152,7 +152,7 @@ def __repr__(self) -> str: return repr(brainstate.util.PrettyMapping(self._asdict(), type_name=self.__class__.__name__)) -HiddenPerturbation.__module__ = 'brainscale' +HiddenPerturbation.__module__ = 'braintrace' class JaxprEvalForHiddenPerturbation(JaxprEvaluation): @@ -170,7 +170,7 @@ class JaxprEvalForHiddenPerturbation(JaxprEvaluation): The revised closed jaxpr with the perturbations. 
""" - __module__ = 'brainscale' + __module__ = 'braintrace' def __init__( self, diff --git a/brainscale/_etrace_compiler_hidden_pertubation_test.py b/braintrace/_etrace_compiler_hidden_pertubation_test.py similarity index 92% rename from brainscale/_etrace_compiler_hidden_pertubation_test.py rename to braintrace/_etrace_compiler_hidden_pertubation_test.py index 0a72dd9..4dbc569 100644 --- a/brainscale/_etrace_compiler_hidden_pertubation_test.py +++ b/braintrace/_etrace_compiler_hidden_pertubation_test.py @@ -19,9 +19,9 @@ import brainunit as u import pytest -import brainscale -from brainscale._etrace_compiler_hidden_pertubation import add_hidden_perturbation_in_module -from brainscale._etrace_model_test import ( +import braintrace +from braintrace._etrace_compiler_hidden_pertubation import add_hidden_perturbation_in_module +from braintrace._etrace_model_test import ( IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, ALIF_ExpCo_Dense_Layer, @@ -39,11 +39,11 @@ class TestFindHiddenGroupsFromModule: @pytest.mark.parametrize( "cls", [ - brainscale.nn.GRUCell, - brainscale.nn.LSTMCell, - brainscale.nn.LRUCell, - brainscale.nn.MGUCell, - brainscale.nn.MinimalRNNCell, + braintrace.nn.GRUCell, + braintrace.nn.LSTMCell, + braintrace.nn.LRUCell, + braintrace.nn.MGUCell, + braintrace.nn.MinimalRNNCell, ] ) def test_rnn_one_layer(self, cls): diff --git a/brainscale/_etrace_compiler_module_info.py b/braintrace/_etrace_compiler_module_info.py similarity index 99% rename from brainscale/_etrace_compiler_module_info.py rename to braintrace/_etrace_compiler_module_info.py index bd2c390..b657430 100644 --- a/brainscale/_etrace_compiler_module_info.py +++ b/braintrace/_etrace_compiler_module_info.py @@ -217,12 +217,12 @@ class ModuleInfo(NamedTuple): Example:: - >>> import brainscale + >>> import braintrace >>> import brainstate - >>> gru = brainscale.nn.GRUCell(10, 20) + >>> gru = braintrace.nn.GRUCell(10, 20) >>> gru.init_state() >>> inputs = brainstate.random.randn(10) - >>> module_info = brainscale.extract_module_info(gru, inputs) + >>> module_info = braintrace.extract_module_info(gru, inputs) """ # stateful model @@ -400,7 +400,7 @@ def __repr__(self) -> str: return repr(brainstate.util.PrettyMapping(self._asdict(), type_name=self.__class__.__name__)) -ModuleInfo.__module__ = 'brainscale' +ModuleInfo.__module__ = 'braintrace' def extract_module_info( diff --git a/brainscale/_etrace_compiler_module_info_test.py b/braintrace/_etrace_compiler_module_info_test.py similarity index 87% rename from brainscale/_etrace_compiler_module_info_test.py rename to braintrace/_etrace_compiler_module_info_test.py index c850442..624a78b 100644 --- a/brainscale/_etrace_compiler_module_info_test.py +++ b/braintrace/_etrace_compiler_module_info_test.py @@ -19,8 +19,8 @@ import brainunit as u import pytest -import brainscale -from brainscale._etrace_model_test import ( +import braintrace +from braintrace._etrace_model_test import ( IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, ALIF_ExpCo_Dense_Layer, @@ -38,11 +38,11 @@ class Test_extract_model_info: @pytest.mark.parametrize( "cls", [ - brainscale.nn.GRUCell, - brainscale.nn.LSTMCell, - brainscale.nn.LRUCell, - brainscale.nn.MGUCell, - brainscale.nn.MinimalRNNCell, + braintrace.nn.GRUCell, + braintrace.nn.LSTMCell, + braintrace.nn.LRUCell, + braintrace.nn.MGUCell, + braintrace.nn.MinimalRNNCell, ] ) def test_rnn_one_layer(self, cls): @@ -53,7 +53,7 @@ def test_rnn_one_layer(self, cls): brainstate.nn.init_all_states(rnn) input = brainstate.random.rand(n_in) - minfo = 
brainscale.extract_module_info(rnn, input) + minfo = braintrace.extract_module_info(rnn, input) pprint(minfo) @pytest.mark.parametrize( @@ -81,7 +81,7 @@ def test_snn_single_layer(self, cls): with brainstate.environ.context(dt=0.1 * u.ms): layer = cls(n_in, n_out) brainstate.nn.init_all_states(layer) - minfo = brainscale.extract_module_info(layer, input) + minfo = braintrace.extract_module_info(layer, input) pprint(minfo) @pytest.mark.parametrize( @@ -110,5 +110,5 @@ def test_snn_two_layers(self, cls): with brainstate.environ.context(dt=0.1 * u.ms): layer = brainstate.nn.Sequential(cls(n_in, n_out), cls(n_out, n_out)) brainstate.nn.init_all_states(layer) - minfo = brainscale.extract_module_info(layer, input) + minfo = braintrace.extract_module_info(layer, input) pprint(minfo) diff --git a/brainscale/_etrace_concepts.py b/braintrace/_etrace_concepts.py similarity index 98% rename from brainscale/_etrace_concepts.py rename to braintrace/_etrace_concepts.py index 6eda62c..3e25232 100644 --- a/brainscale/_etrace_concepts.py +++ b/braintrace/_etrace_concepts.py @@ -94,7 +94,7 @@ class ETraceParam(brainstate.ParamState): grad: The gradient type for the ETrace. Default is `adaptive`. name: The name of the weight-operator. """ - __module__ = 'brainscale' + __module__ = 'braintrace' value: brainstate.typing.PyTree # weight op: ETraceOp # operator @@ -203,7 +203,7 @@ class ElemWiseParam(ETraceParam): op: The operator for the ETrace. See :py:class:`ElemWiseOp`. name: The name of the weight-operator. """ - __module__ = 'brainscale' + __module__ = 'braintrace' value: brainstate.typing.PyTree # weight op: ElemWiseOp # operator @@ -263,7 +263,7 @@ class NonTempParam(brainstate.ParamState): value: The value of the parameter. op: The operator for the parameter. See `ETraceOp`. """ - __module__ = 'brainscale' + __module__ = 'braintrace' op: Callable[[X, W], Y] # operator value: brainstate.typing.PyTree # weight @@ -306,7 +306,7 @@ class FakeETraceParam(object): value: The value of the parameter. op: The operator for the parameter. """ - __module__ = 'brainscale' + __module__ = 'braintrace' op: Callable[[X, W], Y] # operator value: brainstate.typing.PyTree # weight @@ -348,7 +348,7 @@ class FakeElemWiseParam(object): op: The operator for the ETrace. See :py:class:`ElemWiseOp`. name: The name of the weight-operator. 
""" - __module__ = 'brainscale' + __module__ = 'braintrace' op: Callable[[W], Y] # operator value: brainstate.typing.PyTree # weight diff --git a/brainscale/_etrace_concepts_test.py b/braintrace/_etrace_concepts_test.py similarity index 78% rename from brainscale/_etrace_concepts_test.py rename to braintrace/_etrace_concepts_test.py index 742b765..858235b 100644 --- a/brainscale/_etrace_concepts_test.py +++ b/braintrace/_etrace_concepts_test.py @@ -19,8 +19,8 @@ import brainstate import brainunit as u -import brainscale -from brainscale._etrace_concepts import ETraceGrad +import braintrace +from braintrace._etrace_concepts import ETraceGrad class TestETraceState(unittest.TestCase): @@ -38,18 +38,18 @@ def test_check_value(self): class TestETraceGroupState(unittest.TestCase): def test_init(self): value = brainstate.random.randn(10, 10, 5) - state = brainscale.ETraceGroupState(value) + state = braintrace.ETraceGroupState(value) self.assertEqual(state.varshape, value.shape[:-1]) self.assertEqual(state.num_state, value.shape[-1]) def test_get_value(self): value = brainstate.random.randn(10, 10, 5) - state = brainscale.ETraceGroupState(value) + state = braintrace.ETraceGroupState(value) self.assertTrue(u.math.allclose(state.get_value(0), value[..., 0])) def test_set_value(self): value = brainstate.random.randn(10, 10, 5) - state = brainscale.ETraceGroupState(value) + state = braintrace.ETraceGroupState(value) new_value = brainstate.random.randn(10, 10) state.set_value({0: new_value}) self.assertTrue(u.math.allclose(state.get_value(0), new_value)) @@ -59,19 +59,19 @@ class TestETraceTreeState(unittest.TestCase): def test_init(self): value = {'v': brainstate.random.randn(10, 10) * u.mV, 'i': brainstate.random.randn(10, 10) * u.mA} - state = brainscale.ETraceTreeState(value) + state = braintrace.ETraceTreeState(value) self.assertEqual(state.varshape, (10, 10)) self.assertEqual(state.num_state, 2) def test_get_value(self): value = {'v': brainstate.random.randn(10, 10) * u.mV, 'i': brainstate.random.randn(10, 10) * u.mA} - state = brainscale.ETraceTreeState(value) + state = braintrace.ETraceTreeState(value) # print(state.get_value('v'), value['v']) self.assertTrue(u.math.allclose(state.get_value('v'), value['v'], )) def test_set_value(self): value = {'v': brainstate.random.randn(10, 10) * u.mV, 'i': brainstate.random.randn(10, 10) * u.mA} - state = brainscale.ETraceTreeState(value) + state = braintrace.ETraceTreeState(value) new_value = brainstate.random.randn(10, 10) * u.mV state.set_value({'v': new_value}) self.assertTrue(u.math.allclose(state.get_value('v'), new_value)) @@ -80,14 +80,14 @@ def test_set_value(self): # class TestETraceParam(unittest.TestCase): # def test_init(self): # weight = brainstate.random.randn(10, 10) -# op = brainscale.ETraceOp(lambda x, w: x + w) -# param = brainscale.ETraceParam(weight, op) -# self.assertEqual(param.gradient, brainscale.ETraceGrad.adaptive) +# op = braintrace.ETraceOp(lambda x, w: x + w) +# param = braintrace.ETraceParam(weight, op) +# self.assertEqual(param.gradient, braintrace.ETraceGrad.adaptive) # # def test_execute(self): # weight = brainstate.random.randn(10, 10) -# op = brainscale.ETraceOp(lambda x, w: x + w) -# param = brainscale.ETraceParam(weight, op) +# op = braintrace.ETraceOp(lambda x, w: x + w) +# param = braintrace.ETraceParam(weight, op) # x = brainstate.random.randn(10, 10) # result = param.execute(x) # self.assertTrue(u.math.allclose(result, x + weight)) @@ -96,14 +96,14 @@ def test_set_value(self): class 
TestElemWiseParam(unittest.TestCase): def test_init(self): weight = brainstate.random.randn(10, 10) - op = brainscale.ElemWiseOp(lambda w: w) - param = brainscale.ElemWiseParam(weight, op) + op = braintrace.ElemWiseOp(lambda w: w) + param = braintrace.ElemWiseParam(weight, op) self.assertEqual(param.gradient, ETraceGrad.full) def test_execute(self): weight = brainstate.random.randn(10, 10) - op = brainscale.ElemWiseOp(lambda w: w) - param = brainscale.ElemWiseParam(weight, op) + op = braintrace.ElemWiseOp(lambda w: w) + param = braintrace.ElemWiseParam(weight, op) result = param.execute() self.assertTrue(u.math.allclose(result, weight)) @@ -112,13 +112,13 @@ class TestNonTempParam(unittest.TestCase): def test_init(self): weight = brainstate.random.randn(10, 10) op = lambda x, w: x + w - param = brainscale.NonTempParam(weight, op) + param = braintrace.NonTempParam(weight, op) self.assertEqual(param.value.shape, weight.shape) def test_execute(self): weight = brainstate.random.randn(10, 10) op = lambda x, w: x + w - param = brainscale.NonTempParam(weight, op) + param = braintrace.NonTempParam(weight, op) x = brainstate.random.randn(10, 10) result = param.execute(x) self.assertTrue(u.math.allclose(result, x + weight)) @@ -128,13 +128,13 @@ class TestFakeETraceParam(unittest.TestCase): def test_init(self): weight = brainstate.random.randn(10, 10) op = lambda x, w: x + w - param = brainscale.FakeETraceParam(weight, op) + param = braintrace.FakeETraceParam(weight, op) self.assertEqual(param.value.shape, weight.shape) def test_execute(self): weight = brainstate.random.randn(10, 10) op = lambda x, w: x + w - param = brainscale.FakeETraceParam(weight, op) + param = braintrace.FakeETraceParam(weight, op) x = brainstate.random.randn(10, 10) result = param.execute(x) self.assertTrue(u.math.allclose(result, x + weight)) @@ -143,14 +143,14 @@ def test_execute(self): class TestFakeElemWiseParam(unittest.TestCase): def test_init(self): weight = brainstate.random.randn(10, 10) - op = brainscale.ElemWiseOp(lambda w: w) - param = brainscale.FakeElemWiseParam(weight, op) + op = braintrace.ElemWiseOp(lambda w: w) + param = braintrace.FakeElemWiseParam(weight, op) self.assertEqual(param.value.shape, weight.shape) def test_execute(self): weight = brainstate.random.randn(10, 10) - op = brainscale.ElemWiseOp(lambda w: w) - param = brainscale.FakeElemWiseParam(weight, op) + op = braintrace.ElemWiseOp(lambda w: w) + param = braintrace.FakeElemWiseParam(weight, op) result = param.execute() self.assertTrue(u.math.allclose(result, weight)) diff --git a/brainscale/_etrace_debug_jaxpr2code.py b/braintrace/_etrace_debug_jaxpr2code.py similarity index 100% rename from brainscale/_etrace_debug_jaxpr2code.py rename to braintrace/_etrace_debug_jaxpr2code.py diff --git a/brainscale/_etrace_debug_visualize.py b/braintrace/_etrace_debug_visualize.py similarity index 100% rename from brainscale/_etrace_debug_visualize.py rename to braintrace/_etrace_debug_visualize.py diff --git a/brainscale/_etrace_graph_executor.py b/braintrace/_etrace_graph_executor.py similarity index 99% rename from brainscale/_etrace_graph_executor.py rename to braintrace/_etrace_graph_executor.py index a13bdca..82893ba 100644 --- a/brainscale/_etrace_graph_executor.py +++ b/braintrace/_etrace_graph_executor.py @@ -64,7 +64,7 @@ class ETraceGraphExecutor: model: brainstate.nn.Module The model to build the eligibility trace graph. The models should only define the one-step behavior. 
""" - __module__ = 'brainscale' + __module__ = 'braintrace' def __init__( self, diff --git a/brainscale/_etrace_graph_executor_test.py b/braintrace/_etrace_graph_executor_test.py similarity index 79% rename from brainscale/_etrace_graph_executor_test.py rename to braintrace/_etrace_graph_executor_test.py index 93cd8e3..553fbdc 100644 --- a/brainscale/_etrace_graph_executor_test.py +++ b/braintrace/_etrace_graph_executor_test.py @@ -19,8 +19,8 @@ import brainunit as u import jax.numpy as jnp -import brainscale -from brainscale._etrace_model_test import ( +import braintrace +from braintrace._etrace_model_test import ( ALIF_STPExpCu_Dense_Layer, ) @@ -32,26 +32,26 @@ def __init__(self, *args, **kwargs): brainstate.environ.set(dt=0.1 * u.ms) def test_show_lstm_graph(self): - cell = brainscale.nn.LSTMCell(10, 20, activation=jnp.tanh) + cell = braintrace.nn.LSTMCell(10, 20, activation=jnp.tanh) brainstate.nn.init_all_states(cell, 16) - graph = brainscale.ETraceGraphExecutor(cell) + graph = braintrace.ETraceGraphExecutor(cell) graph.compile_graph(jnp.zeros((16, 10))) graph.show_graph() def test_show_gru_graph(self): - cell = brainscale.nn.GRUCell(10, 20, activation=jnp.tanh) + cell = braintrace.nn.GRUCell(10, 20, activation=jnp.tanh) brainstate.nn.init_all_states(cell, 16) - graph = brainscale.ETraceGraphExecutor(cell) + graph = braintrace.ETraceGraphExecutor(cell) graph.compile_graph(jnp.zeros((16, 10))) graph.show_graph() def test_show_lru_graph(self): - cell = brainscale.nn.LRUCell(10, 20) + cell = braintrace.nn.LRUCell(10, 20) brainstate.nn.init_all_states(cell) - graph = brainscale.ETraceGraphExecutor(cell) + graph = braintrace.ETraceGraphExecutor(cell) graph.compile_graph(jnp.zeros((10,))) graph.show_graph() @@ -62,6 +62,6 @@ def test_show_alig_stp_graph(self): net = ALIF_STPExpCu_Dense_Layer(n_in, n_rec) brainstate.nn.init_all_states(net) - graph = brainscale.ETraceGraphExecutor(net) + graph = braintrace.ETraceGraphExecutor(net) graph.compile_graph(brainstate.random.rand(n_in)) graph.show_graph() diff --git a/brainscale/_etrace_input_data.py b/braintrace/_etrace_input_data.py similarity index 98% rename from brainscale/_etrace_input_data.py rename to braintrace/_etrace_input_data.py index 5b7aed6..77354af 100644 --- a/brainscale/_etrace_input_data.py +++ b/braintrace/_etrace_input_data.py @@ -26,7 +26,7 @@ class ETraceInputData: - __module__ = 'brainscale' + __module__ = 'braintrace' def __init__(self, data: Any): """ @@ -72,7 +72,7 @@ class SingleStepData(ETraceInputData): >>> data = SingleStepData(brainstate.random.randn(2, 3)) """ - __module__ = 'brainscale' + __module__ = 'braintrace' @register_pytree_node_class @@ -98,7 +98,7 @@ class MultiStepData(ETraceInputData): ... 
) """ - __module__ = 'brainscale' + __module__ = 'braintrace' def is_input(x): diff --git a/brainscale/_etrace_input_data_test.py b/braintrace/_etrace_input_data_test.py similarity index 77% rename from brainscale/_etrace_input_data_test.py rename to braintrace/_etrace_input_data_test.py index 5121e2b..404a7ee 100644 --- a/brainscale/_etrace_input_data_test.py +++ b/braintrace/_etrace_input_data_test.py @@ -19,7 +19,7 @@ import brainstate import jax -import brainscale +import braintrace class TestEtraceInputData(unittest.TestCase): @@ -29,16 +29,16 @@ def test_jittable(self): def f(x): return x.data ** 2 - f(brainscale.SingleStepData(3)) - f(brainscale.MultiStepData(3)) - f(brainscale.SingleStepData(brainstate.random.rand(10))) - f(brainscale.MultiStepData(brainstate.random.rand(10))) + f(braintrace.SingleStepData(3)) + f(braintrace.MultiStepData(3)) + f(braintrace.SingleStepData(brainstate.random.rand(10))) + f(braintrace.MultiStepData(brainstate.random.rand(10))) def test_grad(self): def f(x): return x.data ** 2 - y, grad = jax.value_and_grad(f)(brainscale.SingleStepData(3.)) + y, grad = jax.value_and_grad(f)(braintrace.SingleStepData(3.)) self.assertEqual(y, 9) self.assertEqual(grad.data, 6) @@ -46,6 +46,6 @@ def test_grad2(self): def f(x): return x.data ** 2 - y, grad = jax.value_and_grad(f)(brainscale.MultiStepData(3.)) + y, grad = jax.value_and_grad(f)(braintrace.MultiStepData(3.)) self.assertEqual(y, 9) self.assertEqual(grad.data, 6) diff --git a/brainscale/_etrace_model_test.py b/braintrace/_etrace_model_test.py similarity index 96% rename from brainscale/_etrace_model_test.py rename to braintrace/_etrace_model_test.py index 0e22c4f..347701e 100644 --- a/brainscale/_etrace_model_test.py +++ b/braintrace/_etrace_model_test.py @@ -20,7 +20,7 @@ import brainunit as u import jax.numpy as jnp -import brainscale +import braintrace class IF_Delta_Dense_Layer(brainstate.nn.Module): @@ -32,7 +32,7 @@ def __init__( super().__init__() self.neu = brainpy.state.IF(n_rec, tau=tau_mem, spk_reset=spk_reset, V_th=V_th) w_init = u.math.concatenate([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) - self.syn = brainscale.nn.Linear( + self.syn = braintrace.nn.Linear( n_in + n_rec, n_rec, w_init=w_init * u.mA, @@ -74,7 +74,7 @@ def __init__( weight = jnp.concat([ff_init([self.n_exc_in, n_rec]), rec_init([self.n_exc_rec, n_rec])], axis=0) weight = weight * u.mS self.exe_syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.SignedWLinear(self.n_exc_in + self.n_exc_rec, n_rec, w_init=weight), + comm=braintrace.nn.SignedWLinear(self.n_exc_in + self.n_exc_rec, n_rec, w_init=weight), syn=brainpy.state.Expon.desc(n_rec, tau=tau_syn), out=brainpy.state.COBA.desc(E=3.5 * u.volt), post=self.neu @@ -83,7 +83,7 @@ def __init__( weight = jnp.concat([4 * ff_init([self.n_inh_in, n_rec]), 4 * rec_init([self.n_inh_rec, n_rec])], axis=0) weight = weight * u.mS self.inh_syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.SignedWLinear(self.n_inh_in + self.n_inh_rec, n_rec, w_init=weight), + comm=braintrace.nn.SignedWLinear(self.n_inh_in + self.n_inh_rec, n_rec, w_init=weight), syn=brainpy.state.Expon.desc(n_rec, tau=tau_syn), out=brainpy.state.COBA.desc(E=-0.5 * u.volt), post=self.neu @@ -92,14 +92,14 @@ def __init__( b_init = braintools.init.ZeroInit(unit=u.mS) self.inp_syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear(n_in, n_rec, w_init=ff_init([n_in, n_rec]) * u.mS, b_init=b_init), + comm=braintrace.nn.Linear(n_in, n_rec, w_init=ff_init([n_in, n_rec]) * u.mS, b_init=b_init), 
syn=brainpy.state.Expon.desc(n_rec, tau=tau_syn), out=brainpy.state.CUBA.desc(), post=self.neu ) self.exe_syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.SignedWLinear( + comm=braintrace.nn.SignedWLinear( self.n_exc_rec, n_rec, w_init=rec_init([self.n_exc_rec, n_rec]) * u.mS), syn=brainpy.state.Expon.desc(n_rec, tau=tau_syn), out=brainpy.state.COBA.desc(E=1.5 * u.volt), @@ -107,7 +107,7 @@ def __init__( ) self.inh_syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.SignedWLinear( + comm=braintrace.nn.SignedWLinear( self.n_inh_rec, n_rec, w_init=4 * rec_init([self.n_inh_rec, n_rec]) * u.mS), syn=brainpy.state.Expon.desc(n_rec, tau=tau_syn), out=brainpy.state.COBA.desc(E=-0.5 * u.volt), @@ -194,7 +194,7 @@ def __init__( super().__init__() self.neu = brainpy.state.LIF(n_rec, tau=tau_mem, spk_fun=spk_fun, spk_reset=spk_reset, V_th=V_th) self.syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear( + comm=braintrace.nn.Linear( n_in + n_rec, n_rec, jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) * u.mS, b_init=braintools.init.ZeroInit(unit=u.mS) @@ -233,7 +233,7 @@ def __init__( self.std_inp = None self.syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear( + comm=braintrace.nn.Linear( n_in + n_rec, n_rec, jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) * u.mS, b_init=braintools.init.ZeroInit(unit=u.mS) @@ -271,7 +271,7 @@ def __init__( self.stp_inp = brainpy.state.STP(n_in, tau_f=tau_f, tau_d=tau_d) self.syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear( + comm=braintrace.nn.Linear( n_in + n_rec, n_rec, jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])]) * u.mS, b_init=braintools.init.ZeroInit(unit=u.mS) @@ -312,7 +312,7 @@ def __init__( V_th=V_th ) self.syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear( + comm=braintrace.nn.Linear( n_in + n_rec, n_rec, jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) * u.mS, b_init=braintools.init.ZeroInit(unit=u.mS) @@ -348,7 +348,7 @@ def __init__( ) w_init = jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) self.syn = brainpy.state.DeltaProj( - comm=brainscale.nn.Linear(n_in + n_rec, n_rec, w_init=w_init * u.mV, + comm=braintrace.nn.Linear(n_in + n_rec, n_rec, w_init=w_init * u.mV, b_init=braintools.init.ZeroInit(unit=u.mV)), post=self.neu ) @@ -383,7 +383,7 @@ def __init__( self.std_inp = None self.syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear( + comm=braintrace.nn.Linear( n_in + n_rec, n_rec, jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) * u.mS, b_init=braintools.init.ZeroInit(unit=u.mS) @@ -430,7 +430,7 @@ def __init__( self.stp_inp = brainpy.state.STP(n_in, tau_f=tau_f, tau_d=tau_d) self.syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear( + comm=braintrace.nn.Linear( n_in + n_rec, n_rec, jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])]) * u.mS, b_init=braintools.init.ZeroInit(unit=u.mS) diff --git a/brainscale/_etrace_model_with_group_state.py b/braintrace/_etrace_model_with_group_state.py similarity index 97% rename from brainscale/_etrace_model_with_group_state.py rename to braintrace/_etrace_model_with_group_state.py index 77fac53..bd72d29 100644 --- a/brainscale/_etrace_model_with_group_state.py +++ b/braintrace/_etrace_model_with_group_state.py @@ -22,7 +22,7 @@ import jax import jax.numpy as jnp -import brainscale +import braintrace import braintools @@ -64,7 +64,7 @@ def __init__( self.a_initializer = a_initializer def 
init_state(self, batch_size: int = None, **kwargs): - self.st = brainscale.ETraceTreeState( + self.st = braintrace.ETraceTreeState( { 'V': braintools.init.param(self.V_initializer, self.varshape, batch_size), 'a': braintools.init.param(self.a_initializer, self.varshape, batch_size), @@ -124,7 +124,7 @@ def __init__( V_th=V_th ) self.syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear( + comm=braintrace.nn.Linear( n_in + n_rec, n_rec, jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) * u.mS, b_init=braintools.init.ZeroInit(unit=u.mS) @@ -165,7 +165,7 @@ def __init__( ) w_init = jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) self.syn = brainpy.state.DeltaProj( - comm=brainscale.nn.Linear(n_in + n_rec, n_rec, w_init=w_init * u.mV, + comm=braintrace.nn.Linear(n_in + n_rec, n_rec, w_init=w_init * u.mV, b_init=braintools.init.ZeroInit(unit=u.mV)), post=self.neu ) @@ -200,7 +200,7 @@ def __init__( self.std_inp = None self.syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear( + comm=braintrace.nn.Linear( n_in + n_rec, n_rec, jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) * u.mS, b_init=braintools.init.ZeroInit(unit=u.mS) @@ -247,7 +247,7 @@ def __init__( self.stp_inp = brainpy.state.STP(n_in, tau_f=tau_f, tau_d=tau_d) self.syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear( + comm=braintrace.nn.Linear( n_in + n_rec, n_rec, jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])]) * u.mS, b_init=braintools.init.ZeroInit(unit=u.mS) diff --git a/brainscale/_etrace_operators.py b/braintrace/_etrace_operators.py similarity index 99% rename from brainscale/_etrace_operators.py rename to braintrace/_etrace_operators.py index 8041ac9..49bcf33 100644 --- a/brainscale/_etrace_operators.py +++ b/braintrace/_etrace_operators.py @@ -73,8 +73,8 @@ def stop_param_gradients(stop_or_not: bool = True): Example:: - >>> import brainscale - >>> with brainscale.stop_param_gradients(): + >>> import braintrace + >>> with braintrace.stop_param_gradients(): >>> # do something Args: @@ -203,7 +203,7 @@ class ETraceOp(brainstate.util.PrettyObject): Args: is_diagonal: bool. Whether the operator is in the hidden diagonal or not. """ - __module__ = 'brainscale' + __module__ = 'braintrace' def __init__( self, diff --git a/brainscale/_etrace_operators_test.py b/braintrace/_etrace_operators_test.py similarity index 91% rename from brainscale/_etrace_operators_test.py rename to braintrace/_etrace_operators_test.py index a609e48..b27784a 100644 --- a/brainscale/_etrace_operators_test.py +++ b/braintrace/_etrace_operators_test.py @@ -17,13 +17,13 @@ import jax import jax.numpy as jnp -import brainscale +import braintrace class Test_MatMulOp: def test1(self): - fn = brainscale.MatMulOp(weight_fn=jnp.abs) - fn = brainscale.MatMulOp() + fn = braintrace.MatMulOp(weight_fn=jnp.abs) + fn = braintrace.MatMulOp() x = brainstate.random.rand(10) w = {'weight': brainstate.random.randn(10, 20)} y1 = fn(x, w) @@ -44,7 +44,7 @@ def test1(self): weight = jnp.where(mask, brainstate.random.rand(10, 20), 0.) 
csr = brainevent.CSR.fromdense(weight) - fn = brainscale.SpMatMulOp(csr, weight_fn=jnp.abs) + fn = braintrace.SpMatMulOp(csr, weight_fn=jnp.abs) x = brainstate.random.rand(10) y1 = fn(x, {'weight': csr.data}) dy = brainstate.random.randn(20) diff --git a/brainscale/_etrace_vjp/__init__.py b/braintrace/_etrace_vjp/__init__.py similarity index 100% rename from brainscale/_etrace_vjp/__init__.py rename to braintrace/_etrace_vjp/__init__.py diff --git a/brainscale/_etrace_vjp/base.py b/braintrace/_etrace_vjp/base.py similarity index 97% rename from brainscale/_etrace_vjp/base.py rename to braintrace/_etrace_vjp/base.py index 7687633..eed224b 100644 --- a/brainscale/_etrace_vjp/base.py +++ b/braintrace/_etrace_vjp/base.py @@ -21,12 +21,12 @@ import jax import jax.numpy as jnp -from brainscale._etrace_algorithms import ( +from braintrace._etrace_algorithms import ( ETraceAlgorithm, ) -from brainscale._etrace_input_data import has_multistep_data -from brainscale._state_managment import assign_state_values_v2 -from brainscale._typing import ( +from braintrace._etrace_input_data import has_multistep_data +from braintrace._state_managment import assign_state_values_v2 +from braintrace._typing import ( PyTree, Outputs, WeightID, @@ -99,7 +99,7 @@ class ETraceVjpAlgorithm(ETraceAlgorithm): """ - __module__ = 'brainscale' + __module__ = 'braintrace' graph_executor: ETraceVjpGraphExecutor def __init__( @@ -157,8 +157,8 @@ def update(self, *args) -> Any: .. code-block:: python - x = [brainscale.SingleStepData(jnp.ones((10,))), - brainscale.SingleStepData(jnp.zeros((10,)))] + x = [braintrace.SingleStepData(jnp.ones((10,))), + braintrace.SingleStepData(jnp.zeros((10,)))] This is the same as the previous case, they are all considered as the input at the current time step. @@ -166,15 +166,15 @@ def update(self, *args) -> Any: .. code-block:: python - x = [brainscale.MultiStepData(jnp.ones((5, 10)), + x = [braintrace.MultiStepData(jnp.ones((5, 10)), jnp.zeros((10,)))] or, .. 
code-block:: python - x = [brainscale.MultiStepData(jnp.ones((5, 10)), - brainscale.SingleStepData(jnp.zeros((10,)))] + x = [braintrace.MultiStepData(jnp.ones((5, 10)), + braintrace.SingleStepData(jnp.zeros((10,))))] Then, the first input argument is considered as the :py:class:`MultiStepData`, and its data will be fed into the model within five consecutive steps, and the second input argument will be fed diff --git a/brainscale/_etrace_vjp/d_rtrl.py b/braintrace/_etrace_vjp/d_rtrl.py similarity index 98% rename from brainscale/_etrace_vjp/d_rtrl.py rename to braintrace/_etrace_vjp/d_rtrl.py index d303ec0..cad3f5b 100644 --- a/brainscale/_etrace_vjp/d_rtrl.py +++ b/braintrace/_etrace_vjp/d_rtrl.py @@ -21,16 +21,16 @@ import jax import jax.numpy as jnp -from brainscale._etrace_algorithms import EligibilityTrace -from brainscale._etrace_compiler_hid_param_op import HiddenParamOpRelation -from brainscale._etrace_compiler_hidden_group import HiddenGroup -from brainscale._etrace_concepts import ElemWiseParam -from brainscale._etrace_operators import ETraceOp -from brainscale._misc import ( +from braintrace._etrace_algorithms import EligibilityTrace +from braintrace._etrace_compiler_hid_param_op import HiddenParamOpRelation +from braintrace._etrace_compiler_hidden_group import HiddenGroup +from braintrace._etrace_concepts import ElemWiseParam +from braintrace._etrace_operators import ETraceOp +from braintrace._misc import ( etrace_param_key, etrace_df_key, ) -from brainscale._typing import ( +from braintrace._typing import ( PyTree, WeightID, Path, @@ -498,7 +498,7 @@ class ParamDimVjpAlgorithm(ETraceVjpAlgorithm): where $k$ is determined by the data input. name: str, optional The name of the etrace algorithm. - mode: brainscale.mixin.Mode + mode: braintrace.mixin.Mode The computing mode, indicating the batching behavior.
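Examples
--------
A minimal usage sketch, added for orientation; it mirrors the single-step and multi-step test patterns appearing later in this diff (the ``GRUCell`` model, layer sizes, and sequence length are illustrative assumptions, not part of the original docstring):

.. code-block:: python

    >>> import brainstate
    >>> import braintrace
    >>>
    >>> # a small recurrent model with initialized states (illustrative sizes)
    >>> model = brainstate.nn.init_all_states(braintrace.nn.GRUCell(3, 4))
    >>> inputs = brainstate.random.randn(10, 3)  # (n_steps, n_in)
    >>>
    >>> # single-step VJP (default): trace the graph once, then step the model
    >>> algorithm = braintrace.ParamDimVjpAlgorithm(model)
    >>> algorithm.compile_graph(inputs[0])
    >>> outs = brainstate.transform.for_loop(algorithm, inputs)
    >>>
    >>> # multi-step VJP: feed the whole sequence at once
    >>> algorithm = braintrace.ParamDimVjpAlgorithm(model, vjp_method='multi-step')
    >>> algorithm.compile_graph(inputs[0])
    >>> outs = algorithm(braintrace.MultiStepData(inputs))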
""" diff --git a/brainscale/_etrace_vjp/d_rtrl_test.py b/braintrace/_etrace_vjp/d_rtrl_test.py similarity index 86% rename from brainscale/_etrace_vjp/d_rtrl_test.py rename to braintrace/_etrace_vjp/d_rtrl_test.py index 38e1d7a..7a6aa65 100644 --- a/brainscale/_etrace_vjp/d_rtrl_test.py +++ b/braintrace/_etrace_vjp/d_rtrl_test.py @@ -17,8 +17,8 @@ import brainunit as u import pytest -import brainscale -from brainscale._etrace_model_test import ( +import braintrace +from braintrace._etrace_model_test import ( IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, ALIF_ExpCo_Dense_Layer, @@ -37,11 +37,11 @@ class TestDiagOn2: @pytest.mark.parametrize( "cls", [ - # brainscale.nn.GRUCell, - # brainscale.nn.LSTMCell, - brainscale.nn.LRUCell, - # brainscale.nn.MGUCell, - # brainscale.nn.MinimalRNNCell, + # braintrace.nn.GRUCell, + # braintrace.nn.LSTMCell, + braintrace.nn.LRUCell, + # braintrace.nn.MGUCell, + # braintrace.nn.MinimalRNNCell, ] ) def test_rnn_single_step_vjp(self, cls): @@ -52,7 +52,7 @@ def test_rnn_single_step_vjp(self, cls): model = brainstate.nn.init_all_states(model) inputs = brainstate.random.randn(n_seq, n_in) - algorithm = brainscale.ParamDimVjpAlgorithm(model) + algorithm = braintrace.ParamDimVjpAlgorithm(model) algorithm.compile_graph(inputs[0]) outs = brainstate.transform.for_loop(algorithm, inputs) @@ -72,11 +72,11 @@ def grad_single_step_vjp(inp): @pytest.mark.parametrize( "cls", [ - brainscale.nn.GRUCell, - brainscale.nn.LSTMCell, - brainscale.nn.LRUCell, - brainscale.nn.MGUCell, - brainscale.nn.MinimalRNNCell, + braintrace.nn.GRUCell, + braintrace.nn.LSTMCell, + braintrace.nn.LRUCell, + braintrace.nn.MGUCell, + braintrace.nn.MinimalRNNCell, ] ) def test_rnn_multi_step_vjp(self, cls): @@ -87,16 +87,16 @@ def test_rnn_multi_step_vjp(self, cls): model = brainstate.nn.init_all_states(model) inputs = brainstate.random.randn(n_seq, n_in) - algorithm = brainscale.ParamDimVjpAlgorithm(model, vjp_method='multi-step') + algorithm = braintrace.ParamDimVjpAlgorithm(model, vjp_method='multi-step') algorithm.compile_graph(inputs[0]) - outs = algorithm(brainscale.MultiStepData(inputs)) + outs = algorithm(braintrace.MultiStepData(inputs)) print(outs.shape) @brainstate.transform.jit def grad_single_step_vjp(inp): return brainstate.transform.grad( - lambda inp: algorithm(brainscale.MultiStepData(inp)).sum(), + lambda inp: algorithm(braintrace.MultiStepData(inp)).sum(), model.states(brainstate.ParamState) )(inp) @@ -134,7 +134,7 @@ def test_snn_single_step_vjp(self, cls): param_states = model.states(brainstate.ParamState).to_dict_values() inputs = brainstate.random.randn(n_seq, n_in) - algorithm = brainscale.ParamDimVjpAlgorithm(model) + algorithm = braintrace.ParamDimVjpAlgorithm(model) algorithm.compile_graph(inputs[0]) outs = brainstate.transform.for_loop(algorithm, inputs) @@ -182,16 +182,16 @@ def test_snn_multi_step_vjp(self, cls): param_states = model.states(brainstate.ParamState).to_dict_values() inputs = brainstate.random.randn(n_seq, n_in) - algorithm = brainscale.ParamDimVjpAlgorithm(model, vjp_method='multi-step') + algorithm = braintrace.ParamDimVjpAlgorithm(model, vjp_method='multi-step') algorithm.compile_graph(inputs[0]) - outs = algorithm(brainscale.MultiStepData(inputs)) + outs = algorithm(braintrace.MultiStepData(inputs)) print(outs.shape) @brainstate.transform.jit def grad_single_step_vjp(inp): return brainstate.transform.grad( - lambda inp: algorithm(brainscale.MultiStepData(inp)).sum(), + lambda inp: algorithm(braintrace.MultiStepData(inp)).sum(), 
model.states(brainstate.ParamState) )(inp) diff --git a/brainscale/_etrace_vjp/esd_rtrl.py b/braintrace/_etrace_vjp/esd_rtrl.py similarity index 98% rename from brainscale/_etrace_vjp/esd_rtrl.py rename to braintrace/_etrace_vjp/esd_rtrl.py index 28af321..c475f81 100644 --- a/brainscale/_etrace_vjp/esd_rtrl.py +++ b/braintrace/_etrace_vjp/esd_rtrl.py @@ -33,16 +33,16 @@ import jax import jax.numpy as jnp -from brainscale._etrace_algorithms import EligibilityTrace -from brainscale._etrace_compiler_hid_param_op import HiddenParamOpRelation -from brainscale._etrace_compiler_hidden_group import HiddenGroup -from brainscale._etrace_concepts import ElemWiseParam -from brainscale._misc import ( +from braintrace._etrace_algorithms import EligibilityTrace +from braintrace._etrace_compiler_hid_param_op import HiddenParamOpRelation +from braintrace._etrace_compiler_hidden_group import HiddenGroup +from braintrace._etrace_concepts import ElemWiseParam +from braintrace._misc import ( check_dict_keys, etrace_x_key, etrace_df_key, ) -from brainscale._typing import ( +from braintrace._typing import ( PyTree, WeightVals, Path, @@ -507,11 +507,11 @@ class IODimVjpAlgorithm(ETraceVjpAlgorithm): If it is an integer, it is the number of approximation rank for the algorithm, should be greater than 0. name: str, optional The name of the etrace algorithm. - mode: brainscale.mixin.Mode + mode: braintrace.mixin.Mode The computing mode, indicating the batching information. """ - __module__ = 'brainscale' + __module__ = 'braintrace' # the spatial gradients of the weights etrace_xs: Dict[ETraceX_Key, brainstate.State] diff --git a/brainscale/_etrace_vjp/esd_rtrl_test.py b/braintrace/_etrace_vjp/esd_rtrl_test.py similarity index 85% rename from brainscale/_etrace_vjp/esd_rtrl_test.py rename to braintrace/_etrace_vjp/esd_rtrl_test.py index 1e0add5..e3360e9 100644 --- a/brainscale/_etrace_vjp/esd_rtrl_test.py +++ b/braintrace/_etrace_vjp/esd_rtrl_test.py @@ -17,8 +17,8 @@ import brainunit as u import pytest -import brainscale -from brainscale._etrace_model_test import ( +import braintrace +from braintrace._etrace_model_test import ( IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, ALIF_ExpCo_Dense_Layer, @@ -37,11 +37,11 @@ class TestDiagOn: @pytest.mark.parametrize( "cls", [ - brainscale.nn.GRUCell, - brainscale.nn.LSTMCell, - brainscale.nn.LRUCell, - brainscale.nn.MGUCell, - brainscale.nn.MinimalRNNCell, + braintrace.nn.GRUCell, + braintrace.nn.LSTMCell, + braintrace.nn.LRUCell, + braintrace.nn.MGUCell, + braintrace.nn.MinimalRNNCell, ] ) def test_rnn_single_step_vjp(self, cls): @@ -52,7 +52,7 @@ def test_rnn_single_step_vjp(self, cls): model = brainstate.nn.init_all_states(model) inputs = brainstate.random.randn(n_seq, n_in) - algorithm = brainscale.IODimVjpAlgorithm(model, decay_or_rank=0.9) + algorithm = braintrace.IODimVjpAlgorithm(model, decay_or_rank=0.9) algorithm.compile_graph(inputs[0]) outs = brainstate.transform.for_loop(algorithm, inputs) @@ -72,11 +72,11 @@ def grad_single_step_vjp(inp): @pytest.mark.parametrize( "cls", [ - brainscale.nn.GRUCell, - brainscale.nn.LSTMCell, - brainscale.nn.LRUCell, - brainscale.nn.MGUCell, - brainscale.nn.MinimalRNNCell, + braintrace.nn.GRUCell, + braintrace.nn.LSTMCell, + braintrace.nn.LRUCell, + braintrace.nn.MGUCell, + braintrace.nn.MinimalRNNCell, ] ) def test_rnn_multi_step_vjp(self, cls): @@ -87,16 +87,16 @@ def test_rnn_multi_step_vjp(self, cls): model = brainstate.nn.init_all_states(model) inputs = brainstate.random.randn(n_seq, n_in) - algorithm = 
brainscale.IODimVjpAlgorithm(model, decay_or_rank=0.9, vjp_method='multi-step') + algorithm = braintrace.IODimVjpAlgorithm(model, decay_or_rank=0.9, vjp_method='multi-step') algorithm.compile_graph(inputs[0]) - outs = algorithm(brainscale.MultiStepData(inputs)) + outs = algorithm(braintrace.MultiStepData(inputs)) print(outs.shape) @brainstate.transform.jit def grad_single_step_vjp(inp): return brainstate.transform.grad( - lambda inp: algorithm(brainscale.MultiStepData(inp)).sum(), + lambda inp: algorithm(braintrace.MultiStepData(inp)).sum(), model.states(brainstate.ParamState) )(inp) @@ -130,7 +130,7 @@ def test_snn_single_step_vjp(self, cls): model = brainstate.nn.init_all_states(model) inputs = brainstate.random.randn(n_seq, n_in) - algorithm = brainscale.IODimVjpAlgorithm(model, decay_or_rank=0.9) + algorithm = braintrace.IODimVjpAlgorithm(model, decay_or_rank=0.9) algorithm.compile_graph(inputs[0]) outs = brainstate.transform.for_loop(algorithm, inputs) @@ -173,16 +173,16 @@ def test_snn_multi_step_vjp(self, cls): model = brainstate.nn.init_all_states(model) inputs = brainstate.random.randn(n_seq, n_in) - algorithm = brainscale.IODimVjpAlgorithm(model, decay_or_rank=0.9, vjp_method='multi-step') + algorithm = braintrace.IODimVjpAlgorithm(model, decay_or_rank=0.9, vjp_method='multi-step') algorithm.compile_graph(inputs[0]) - outs = algorithm(brainscale.MultiStepData(inputs)) + outs = algorithm(braintrace.MultiStepData(inputs)) print(outs.shape) @brainstate.transform.jit def grad_single_step_vjp(inp): return brainstate.transform.grad( - lambda inp: algorithm(brainscale.MultiStepData(inp)).sum(), + lambda inp: algorithm(braintrace.MultiStepData(inp)).sum(), model.states(brainstate.ParamState) )(inp) diff --git a/brainscale/_etrace_vjp/graph_executor.py b/braintrace/_etrace_vjp/graph_executor.py similarity index 98% rename from brainscale/_etrace_vjp/graph_executor.py rename to braintrace/_etrace_vjp/graph_executor.py index 011be5a..0a66c21 100644 --- a/brainscale/_etrace_vjp/graph_executor.py +++ b/braintrace/_etrace_vjp/graph_executor.py @@ -47,25 +47,25 @@ from jax.interpreters import partial_eval as pe from jax.tree_util import register_pytree_node_class -from brainscale._compatible_imports import Var -from brainscale._etrace_compiler_graph import compile_etrace_graph -from brainscale._etrace_compiler_hidden_group import HiddenGroup -from brainscale._etrace_graph_executor import ETraceGraphExecutor -from brainscale._etrace_input_data import ( +from braintrace._compatible_imports import Var +from braintrace._etrace_compiler_graph import compile_etrace_graph +from braintrace._etrace_compiler_hidden_group import HiddenGroup +from braintrace._etrace_graph_executor import ETraceGraphExecutor +from braintrace._etrace_input_data import ( get_single_step_data, split_input_data_types, merge_data, has_multistep_data, ) -from brainscale._misc import ( +from braintrace._misc import ( etrace_df_key, etrace_x_key, ) -from brainscale._state_managment import ( +from braintrace._state_managment import ( assign_dict_state_values, split_dict_states_v2 ) -from brainscale._typing import ( +from braintrace._typing import ( Outputs, ETraceVals, StateVals, @@ -148,7 +148,7 @@ class ETraceVjpGraphExecutor(ETraceGraphExecutor): - "multi-step": The VJP is computed at multiple time steps, i.e., $\partial L^t/\partial h^{t-k}$, where $k$ is determined by the data input. 
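Examples
--------
A brief sketch of selecting the VJP method, following the executor tests later in this diff (the ``GRUCell`` model, the ``dt`` value, and the input shape are illustrative assumptions):

.. code-block:: python

    >>> import brainunit as u
    >>> import jax.numpy as jnp
    >>> import brainstate
    >>> import braintrace
    >>>
    >>> brainstate.environ.set(dt=0.1 * u.ms)
    >>> model = brainstate.nn.init_all_states(braintrace.nn.GRUCell(3, 4))
    >>>
    >>> # default is the single-step VJP
    >>> executor = braintrace.ETraceVjpGraphExecutor(model)
    >>> # or compute gradients across multiple time steps
    >>> executor = braintrace.ETraceVjpGraphExecutor(model, vjp_method='multi-step')
    >>>
    >>> # compile the eligibility-trace graph from a sample input
    >>> executor.compile_graph(jnp.ones((3,)))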
""" - __module__ = 'brainscale' + __module__ = 'braintrace' def __init__( self, diff --git a/brainscale/_etrace_vjp/graph_executor_test.py b/braintrace/_etrace_vjp/graph_executor_test.py similarity index 83% rename from brainscale/_etrace_vjp/graph_executor_test.py rename to braintrace/_etrace_vjp/graph_executor_test.py index 4c29c69..0d0756b 100644 --- a/brainscale/_etrace_vjp/graph_executor_test.py +++ b/braintrace/_etrace_vjp/graph_executor_test.py @@ -21,7 +21,7 @@ import jax.numpy as jnp import numpy as np -import brainscale +import braintrace class TestETraceVjpGraphExecutor(unittest.TestCase): @@ -31,39 +31,39 @@ def in_size(self): return 3 def setUp(self): - self.model = brainscale.nn.GRUCell(self.in_size, 4) + self.model = braintrace.nn.GRUCell(self.in_size, 4) brainstate.nn.init_all_states(self.model) brainstate.environ.set(dt=0.1 * u.ms) def test_initialization(self): - executor = brainscale.ETraceVjpGraphExecutor(self.model) + executor = braintrace.ETraceVjpGraphExecutor(self.model) self.assertEqual(executor.vjp_method, 'single-step') - executor = brainscale.ETraceVjpGraphExecutor(self.model, vjp_method='multi-step') + executor = braintrace.ETraceVjpGraphExecutor(self.model, vjp_method='multi-step') self.assertEqual(executor.vjp_method, 'multi-step') def test_invalid_vjp_method(self): with self.assertRaises(AssertionError): - brainscale.ETraceVjpGraphExecutor(self.model, vjp_method='invalid') + braintrace.ETraceVjpGraphExecutor(self.model, vjp_method='invalid') def test_is_single_step_vjp(self): - executor = brainscale.ETraceVjpGraphExecutor(self.model) + executor = braintrace.ETraceVjpGraphExecutor(self.model) self.assertTrue(executor.is_single_step_vjp) self.assertFalse(executor.is_multi_step_vjp) def test_is_multi_step_vjp(self): - executor = brainscale.ETraceVjpGraphExecutor(self.model, vjp_method='multi-step') + executor = braintrace.ETraceVjpGraphExecutor(self.model, vjp_method='multi-step') self.assertFalse(executor.is_single_step_vjp) self.assertTrue(executor.is_multi_step_vjp) def test_compile_graph(self): - executor = brainscale.ETraceVjpGraphExecutor(self.model) + executor = braintrace.ETraceVjpGraphExecutor(self.model) x = jnp.ones((self.in_size,)) executor.compile_graph(x) self.assertIsNotNone(executor._compiled_graph) def test_solve_h2w_h2h_jacobian(self): - executor = brainscale.ETraceVjpGraphExecutor(self.model) + executor = braintrace.ETraceVjpGraphExecutor(self.model) x = jnp.ones((self.in_size,)) executor.compile_graph(x) @@ -76,8 +76,8 @@ def test_solve_h2w_h2h_jacobian(self): self.assertIsInstance(h2h_jacobian, list) def test_single_step_vs_multi_step(self): - single_step_executor = brainscale.ETraceVjpGraphExecutor(self.model, vjp_method='single-step') - multi_step_executor = brainscale.ETraceVjpGraphExecutor(self.model, vjp_method='multi-step') + single_step_executor = braintrace.ETraceVjpGraphExecutor(self.model, vjp_method='single-step') + multi_step_executor = braintrace.ETraceVjpGraphExecutor(self.model, vjp_method='multi-step') x = jnp.ones((self.in_size,)) single_step_executor.compile_graph(x) diff --git a/brainscale/_etrace_vjp/hybrid.py b/braintrace/_etrace_vjp/hybrid.py similarity index 98% rename from brainscale/_etrace_vjp/hybrid.py rename to braintrace/_etrace_vjp/hybrid.py index ccf4e47..85c109d 100644 --- a/brainscale/_etrace_vjp/hybrid.py +++ b/braintrace/_etrace_vjp/hybrid.py @@ -21,19 +21,19 @@ import brainunit as u import jax -from brainscale._etrace_compiler_hid_param_op import HiddenParamOpRelation -from 
brainscale._etrace_compiler_hidden_group import HiddenGroup -from brainscale._etrace_concepts import ( +from braintrace._etrace_compiler_hid_param_op import HiddenParamOpRelation +from braintrace._etrace_compiler_hidden_group import HiddenGroup +from braintrace._etrace_concepts import ( ETraceParam, ElemWiseParam, ETraceGrad, ) -from brainscale._misc import ( +from braintrace._misc import ( etrace_x_key, etrace_param_key, etrace_df_key, ) -from brainscale._typing import ( +from braintrace._typing import ( PyTree, Path, ETraceX_Key, @@ -177,7 +177,7 @@ class HybridDimVjpAlgorithm(ETraceVjpAlgorithm): The exponential smoothing factor for the eligibility trace. If it is a float, it is the decay factor, should be in the range of (0, 1). If it is an integer, it is the number of approximation rank for the algorithm, should be greater than 0. - mode: brainscale.mixin.Mode + mode: braintrace.mixin.Mode The computing mode, indicating the batching behavior. """ diff --git a/brainscale/_etrace_vjp/misc.py b/braintrace/_etrace_vjp/misc.py similarity index 100% rename from brainscale/_etrace_vjp/misc.py rename to braintrace/_etrace_vjp/misc.py diff --git a/brainscale/_grad_exponential.py b/braintrace/_grad_exponential.py similarity index 100% rename from brainscale/_grad_exponential.py rename to braintrace/_grad_exponential.py diff --git a/brainscale/_misc.py b/braintrace/_misc.py similarity index 98% rename from brainscale/_misc.py rename to braintrace/_misc.py index c63171a..9da2970 100644 --- a/brainscale/_misc.py +++ b/braintrace/_misc.py @@ -212,7 +212,7 @@ def remove_units(xs): ) -git_issue_addr = 'https://github.com/chaobrain/brainscale/issues' +git_issue_addr = 'https://github.com/chaobrain/braintrace/issues' def deprecation_getattr(module, deprecations): @@ -258,7 +258,7 @@ class NotSupportedError(Exception): functionality is not supported within the context of the application. """ - __module__ = 'brainscale' + __module__ = 'braintrace' class CompilationError(Exception): @@ -268,7 +268,7 @@ class CompilationError(Exception): This exception is used to indicate that a compilation error has occurred within the context of the application. """ - __module__ = 'brainscale' + __module__ = 'braintrace' def state_traceback(states: Sequence[brainstate.State]): @@ -303,7 +303,7 @@ def state_traceback(states: Sequence[brainstate.State]): return '\n'.join(state_info) -def set_module_as(module: str = 'brainscale'): +def set_module_as(module: str = 'braintrace'): """ Decorator to set the module attribute of a function. @@ -313,7 +313,7 @@ def set_module_as(module: str = 'brainscale'): Parameters ---------- module : str, optional - The name of the module to set for the function, by default 'brainscale'. + The name of the module to set for the function, by default 'braintrace'. Returns ------- diff --git a/brainscale/_state_managment.py b/braintrace/_state_managment.py similarity index 100% rename from brainscale/_state_managment.py rename to braintrace/_state_managment.py diff --git a/brainscale/_typing.py b/braintrace/_typing.py similarity index 100% rename from brainscale/_typing.py rename to braintrace/_typing.py diff --git a/brainscale/nn/__init__.py b/braintrace/nn/__init__.py similarity index 95% rename from brainscale/nn/__init__.py rename to braintrace/nn/__init__.py index 3e20adc..8decc64 100644 --- a/brainscale/nn/__init__.py +++ b/braintrace/nn/__init__.py @@ -36,7 +36,7 @@ def __getattr__(name): 'Expon', 'Alpha', 'DualExpon', 'STP', 'STD', ]: warnings.warn( - f'brainscale.nn.{name} is deprecated. 
Use brainstate.state.{name} instead.', + f'braintrace.nn.{name} is deprecated. Use brainstate.state.{name} instead.', DeprecationWarning, stacklevel=2 ) @@ -58,7 +58,7 @@ def __getattr__(name): 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', ]: warnings.warn( - f'brainscale.nn.{name} is deprecated. Use brainstate.nn.{name} instead.', + f'braintrace.nn.{name} is deprecated. Use brainstate.nn.{name} instead.', DeprecationWarning, stacklevel=2 ) diff --git a/brainscale/nn/_conv.py b/braintrace/nn/_conv.py similarity index 96% rename from brainscale/nn/_conv.py rename to braintrace/nn/_conv.py index c0c506b..2cc99c5 100644 --- a/brainscale/nn/_conv.py +++ b/braintrace/nn/_conv.py @@ -22,9 +22,9 @@ import jax from braintools import init -from brainscale._etrace_concepts import ETraceParam -from brainscale._etrace_operators import ConvOp -from brainscale._typing import ArrayLike +from braintrace._etrace_concepts import ETraceParam +from braintrace._etrace_operators import ConvOp +from braintrace._typing import ArrayLike __all__ = [ 'Conv1d', @@ -248,11 +248,11 @@ class Conv1d(_Conv): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a 1D convolution layer - >>> conv1d = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) + >>> conv1d = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) >>> >>> # Input with batch size 4 >>> x = brainstate.random.randn(4, 10, 3) @@ -260,7 +260,7 @@ class Conv1d(_Conv): >>> print(y.shape) (4, 10, 16) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' num_spatial_dims: int = 1 @@ -313,11 +313,11 @@ class Conv2d(_Conv): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a 2D convolution layer - >>> conv2d = brainscale.nn.Conv2d(in_size=(28, 28, 1), out_channels=32, kernel_size=3, stride=1) + >>> conv2d = braintrace.nn.Conv2d(in_size=(28, 28, 1), out_channels=32, kernel_size=3, stride=1) >>> >>> # Input with batch size 8 >>> x = brainstate.random.randn(8, 28, 28, 1) @@ -325,7 +325,7 @@ class Conv2d(_Conv): >>> print(y.shape) (8, 28, 28, 32) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' num_spatial_dims: int = 2 @@ -378,11 +378,11 @@ class Conv3d(_Conv): -------- .. 
code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a 3D convolution layer - >>> conv3d = brainscale.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, stride=2) + >>> conv3d = braintrace.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, stride=2) >>> >>> # Input with batch size 2 >>> x = brainstate.random.randn(2, 16, 16, 16, 3) @@ -390,6 +390,6 @@ class Conv3d(_Conv): >>> print(y.shape) (2, 8, 8, 8, 64) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' num_spatial_dims: int = 3 diff --git a/brainscale/nn/_conv_test.py b/braintrace/nn/_conv_test.py similarity index 86% rename from brainscale/nn/_conv_test.py rename to braintrace/nn/_conv_test.py index 6be19d4..71a8d7f 100644 --- a/brainscale/nn/_conv_test.py +++ b/braintrace/nn/_conv_test.py @@ -28,8 +28,8 @@ jnp = pytest.importorskip("jax.numpy") braintools = pytest.importorskip("braintools") init = braintools.init -brainscale = pytest.importorskip("brainscale") -from brainscale.nn._conv import to_dimension_numbers, replicate +braintrace = pytest.importorskip("braintrace") +from braintrace.nn._conv import to_dimension_numbers, replicate class TestUtilityFunctions: @@ -113,7 +113,7 @@ class TestConv1d: def test_conv1d_basic_creation(self): """Test basic Conv1d layer creation.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) assert conv.in_channels == 3 assert conv.out_channels == 16 assert conv.kernel_size == (3,) @@ -121,21 +121,21 @@ def test_conv1d_basic_creation(self): def test_conv1d_forward_with_batch(self): """Test Conv1d forward pass with batch dimension.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) x = brainstate.random.randn(4, 10, 3) y = conv(x) assert y.shape == (4, 10, 16) def test_conv1d_forward_without_batch(self): """Test Conv1d forward pass without batch dimension.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) x = brainstate.random.randn(10, 3) y = conv(x) assert y.shape == (10, 16) def test_conv1d_different_strides(self): """Test Conv1d with different stride values.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, stride=2) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, stride=2) x = brainstate.random.randn(4, 10, 3) y = conv(x) # With stride=2 and SAME padding, output size should be ceil(10/2) = 5 @@ -143,7 +143,7 @@ def test_conv1d_different_strides(self): def test_conv1d_valid_padding(self): """Test Conv1d with VALID padding.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, padding='VALID') + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, padding='VALID') x = brainstate.random.randn(4, 10, 3) y = conv(x) # With VALID padding and kernel_size=3, output size should be 10-3+1 = 8 @@ -151,7 +151,7 @@ def test_conv1d_valid_padding(self): def test_conv1d_same_padding(self): """Test Conv1d with SAME padding.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, padding='SAME') + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, padding='SAME') x = brainstate.random.randn(4, 10, 3) y = conv(x) # With SAME 
padding, output size should be same as input @@ -159,7 +159,7 @@ def test_conv1d_same_padding(self): def test_conv1d_explicit_padding_int(self): """Test Conv1d with explicit integer padding.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, padding=1) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, padding=1) x = brainstate.random.randn(4, 10, 3) y = conv(x) assert y.ndim == 3 @@ -168,7 +168,7 @@ def test_conv1d_explicit_padding_int(self): def test_conv1d_explicit_padding_tuple(self): """Test Conv1d with explicit tuple padding.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, padding=(1, 1)) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, padding=(1, 1)) x = brainstate.random.randn(4, 10, 3) y = conv(x) assert y.ndim == 3 @@ -177,7 +177,7 @@ def test_conv1d_explicit_padding_tuple(self): def test_conv1d_with_bias(self): """Test Conv1d with bias initialization.""" - conv = brainscale.nn.Conv1d( + conv = braintrace.nn.Conv1d( in_size=(10, 3), out_channels=16, kernel_size=3, @@ -189,14 +189,14 @@ def test_conv1d_with_bias(self): def test_conv1d_without_bias(self): """Test Conv1d without bias.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, b_init=None) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, b_init=None) x = brainstate.random.randn(4, 10, 3) y = conv(x) assert y.shape == (4, 10, 16) def test_conv1d_with_groups(self): """Test Conv1d with grouped convolution.""" - conv = brainscale.nn.Conv1d(in_size=(10, 4), out_channels=8, kernel_size=3, groups=2) + conv = braintrace.nn.Conv1d(in_size=(10, 4), out_channels=8, kernel_size=3, groups=2) x = brainstate.random.randn(4, 10, 4) y = conv(x) assert y.shape == (4, 10, 8) @@ -204,7 +204,7 @@ def test_conv1d_with_groups(self): def test_conv1d_depthwise(self): """Test Conv1d with depthwise convolution (groups = in_channels).""" in_channels = 4 - conv = brainscale.nn.Conv1d( + conv = braintrace.nn.Conv1d( in_size=(10, in_channels), out_channels=in_channels, kernel_size=3, @@ -216,7 +216,7 @@ def test_conv1d_depthwise(self): # def test_conv1d_lhs_dilation(self): # """Test Conv1d with lhs_dilation (atrous convolution on input).""" - # conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, lhs_dilation=2) + # conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, lhs_dilation=2) # x = brainstate.random.randn(4, 10, 3) # y = conv(x) # assert y.ndim == 3 @@ -225,7 +225,7 @@ def test_conv1d_depthwise(self): def test_conv1d_rhs_dilation(self): """Test Conv1d with rhs_dilation (atrous convolution on kernel).""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, rhs_dilation=2) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, rhs_dilation=2) x = brainstate.random.randn(4, 10, 3) y = conv(x) assert y.ndim == 3 @@ -234,7 +234,7 @@ def test_conv1d_rhs_dilation(self): def test_conv1d_custom_initializers(self): """Test Conv1d with custom weight and bias initializers.""" - conv = brainscale.nn.Conv1d( + conv = braintrace.nn.Conv1d( in_size=(10, 3), out_channels=16, kernel_size=3, @@ -249,40 +249,40 @@ def test_conv1d_with_weight_mask(self): """Test Conv1d with weight mask.""" kernel_shape = (3, 3, 16) # (kernel_size, in_channels, out_channels) mask = jnp.ones(kernel_shape) - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, w_mask=mask) + conv = 
braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, w_mask=mask) x = brainstate.random.randn(4, 10, 3) y = conv(x) assert y.shape == (4, 10, 16) def test_conv1d_kernel_shape(self): """Test Conv1d kernel shape is correct.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=5) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=5) assert conv.kernel_shape == (5, 3, 16) def test_conv1d_input_validation_wrong_ndim(self): """Test Conv1d input validation rejects wrong number of dimensions.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) x = brainstate.random.randn(10) # 1D input - should fail with pytest.raises(ValueError): conv(x) def test_conv1d_input_validation_wrong_shape(self): """Test Conv1d input validation rejects wrong shape.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) x = brainstate.random.randn(4, 12, 3) # Wrong spatial dimension with pytest.raises(ValueError): conv(x) def test_conv1d_sequence_stride(self): """Test Conv1d with stride as sequence.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, stride=[2]) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3, stride=[2]) x = brainstate.random.randn(4, 10, 3) y = conv(x) assert y.shape == (4, 5, 16) def test_conv1d_sequence_kernel_size(self): """Test Conv1d with kernel_size as sequence.""" - conv = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=[3]) + conv = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=[3]) x = brainstate.random.randn(4, 10, 3) y = conv(x) assert y.shape == (4, 10, 16) @@ -293,7 +293,7 @@ class TestConv2d: def test_conv2d_basic_creation(self): """Test basic Conv2d layer creation.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) assert conv.in_channels == 3 assert conv.out_channels == 32 assert conv.kernel_size == (3, 3) @@ -301,21 +301,21 @@ def test_conv2d_basic_creation(self): def test_conv2d_forward_with_batch(self): """Test Conv2d forward pass with batch dimension.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) x = brainstate.random.randn(8, 28, 28, 3) y = conv(x) assert y.shape == (8, 28, 28, 32) def test_conv2d_forward_without_batch(self): """Test Conv2d forward pass without batch dimension.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) x = brainstate.random.randn(28, 28, 3) y = conv(x) assert y.shape == (28, 28, 32) def test_conv2d_different_strides(self): """Test Conv2d with different stride values.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, stride=2) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, stride=2) x = brainstate.random.randn(8, 28, 28, 3) y = conv(x) # With stride=2 and SAME padding, output size should be ceil(28/2) = 14 @@ -323,7 +323,7 @@ def test_conv2d_different_strides(self): def test_conv2d_asymmetric_strides(self): """Test Conv2d with different stride values for each 
dimension.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, stride=(2, 1)) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, stride=(2, 1)) x = brainstate.random.randn(8, 28, 28, 3) y = conv(x) # With stride=(2,1) and SAME padding @@ -331,7 +331,7 @@ def test_conv2d_asymmetric_strides(self): def test_conv2d_valid_padding(self): """Test Conv2d with VALID padding.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, padding='VALID') + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, padding='VALID') x = brainstate.random.randn(8, 28, 28, 3) y = conv(x) # With VALID padding and kernel_size=3, output size should be 28-3+1 = 26 @@ -339,7 +339,7 @@ def test_conv2d_valid_padding(self): def test_conv2d_same_padding(self): """Test Conv2d with SAME padding.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, padding='SAME') + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, padding='SAME') x = brainstate.random.randn(8, 28, 28, 3) y = conv(x) # With SAME padding, output size should be same as input @@ -347,7 +347,7 @@ def test_conv2d_same_padding(self): def test_conv2d_explicit_padding_int(self): """Test Conv2d with explicit integer padding.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, padding=1) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, padding=1) x = brainstate.random.randn(8, 28, 28, 3) y = conv(x) assert y.ndim == 4 @@ -356,7 +356,7 @@ def test_conv2d_explicit_padding_int(self): def test_conv2d_explicit_padding_tuple(self): """Test Conv2d with explicit tuple padding.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, padding=(1, 1)) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, padding=(1, 1)) x = brainstate.random.randn(8, 28, 28, 3) y = conv(x) assert y.ndim == 4 @@ -365,7 +365,7 @@ def test_conv2d_explicit_padding_tuple(self): def test_conv2d_explicit_padding_sequence(self): """Test Conv2d with explicit sequence of tuples padding.""" - conv = brainscale.nn.Conv2d( + conv = braintrace.nn.Conv2d( in_size=(28, 28, 3), out_channels=32, kernel_size=3, @@ -379,7 +379,7 @@ def test_conv2d_explicit_padding_sequence(self): def test_conv2d_with_bias(self): """Test Conv2d with bias initialization.""" - conv = brainscale.nn.Conv2d( + conv = braintrace.nn.Conv2d( in_size=(28, 28, 3), out_channels=32, kernel_size=3, @@ -391,14 +391,14 @@ def test_conv2d_with_bias(self): def test_conv2d_without_bias(self): """Test Conv2d without bias.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, b_init=None) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, b_init=None) x = brainstate.random.randn(8, 28, 28, 3) y = conv(x) assert y.shape == (8, 28, 28, 32) def test_conv2d_with_groups(self): """Test Conv2d with grouped convolution.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 4), out_channels=8, kernel_size=3, groups=2) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 4), out_channels=8, kernel_size=3, groups=2) x = brainstate.random.randn(8, 28, 28, 4) y = conv(x) assert y.shape == (8, 28, 28, 8) @@ -406,7 +406,7 @@ def test_conv2d_with_groups(self): def test_conv2d_depthwise(self): """Test Conv2d with depthwise convolution (groups = in_channels).""" in_channels = 4 - conv = brainscale.nn.Conv2d( + 
conv = braintrace.nn.Conv2d( in_size=(28, 28, in_channels), out_channels=in_channels, kernel_size=3, @@ -418,7 +418,7 @@ def test_conv2d_depthwise(self): # def test_conv2d_lhs_dilation(self): # """Test Conv2d with lhs_dilation (atrous convolution on input).""" - # conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, lhs_dilation=2) + # conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, lhs_dilation=2) # x = brainstate.random.randn(8, 28, 28, 3) # y = conv(x) # assert y.ndim == 4 @@ -427,7 +427,7 @@ def test_conv2d_depthwise(self): def test_conv2d_rhs_dilation(self): """Test Conv2d with rhs_dilation (atrous convolution on kernel).""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, rhs_dilation=2) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, rhs_dilation=2) x = brainstate.random.randn(8, 28, 28, 3) y = conv(x) assert y.ndim == 4 @@ -436,7 +436,7 @@ def test_conv2d_rhs_dilation(self): def test_conv2d_asymmetric_kernel(self): """Test Conv2d with asymmetric kernel size.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=(3, 5)) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=(3, 5)) x = brainstate.random.randn(8, 28, 28, 3) y = conv(x) assert y.shape == (8, 28, 28, 32) @@ -444,7 +444,7 @@ def test_conv2d_asymmetric_kernel(self): def test_conv2d_custom_initializers(self): """Test Conv2d with custom weight and bias initializers.""" - conv = brainscale.nn.Conv2d( + conv = braintrace.nn.Conv2d( in_size=(28, 28, 3), out_channels=32, kernel_size=3, @@ -459,26 +459,26 @@ def test_conv2d_with_weight_mask(self): """Test Conv2d with weight mask.""" kernel_shape = (3, 3, 3, 32) # (H, W, in_channels, out_channels) mask = jnp.ones(kernel_shape) - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, w_mask=mask) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3, w_mask=mask) x = brainstate.random.randn(8, 28, 28, 3) y = conv(x) assert y.shape == (8, 28, 28, 32) def test_conv2d_kernel_shape(self): """Test Conv2d kernel shape is correct.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=5) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=5) assert conv.kernel_shape == (5, 5, 3, 32) def test_conv2d_input_validation_wrong_ndim(self): """Test Conv2d input validation rejects wrong number of dimensions.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) x = brainstate.random.randn(28, 3) # 2D input - should fail with pytest.raises(ValueError): conv(x) def test_conv2d_input_validation_wrong_shape(self): """Test Conv2d input validation rejects wrong shape.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) x = brainstate.random.randn(8, 32, 32, 3) # Wrong spatial dimensions with pytest.raises(ValueError): conv(x) @@ -489,7 +489,7 @@ class TestConv3d: def test_conv3d_basic_creation(self): """Test basic Conv3d layer creation.""" - conv = brainscale.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3) + conv = braintrace.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3) assert conv.in_channels == 3 assert conv.out_channels == 64 assert 
conv.kernel_size == (3, 3, 3) @@ -497,21 +497,21 @@ def test_conv3d_basic_creation(self): def test_conv3d_forward_with_batch(self): """Test Conv3d forward pass with batch dimension.""" - conv = brainscale.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3) + conv = braintrace.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3) x = brainstate.random.randn(2, 16, 16, 16, 3) y = conv(x) assert y.shape == (2, 16, 16, 16, 64) def test_conv3d_forward_without_batch(self): """Test Conv3d forward pass without batch dimension.""" - conv = brainscale.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3) + conv = braintrace.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3) x = brainstate.random.randn(16, 16, 16, 3) y = conv(x) assert y.shape == (16, 16, 16, 64) def test_conv3d_different_strides(self): """Test Conv3d with different stride values.""" - conv = brainscale.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, stride=2) + conv = braintrace.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, stride=2) x = brainstate.random.randn(2, 16, 16, 16, 3) y = conv(x) # With stride=2 and SAME padding, output size should be ceil(16/2) = 8 @@ -519,7 +519,7 @@ def test_conv3d_different_strides(self): def test_conv3d_asymmetric_strides(self): """Test Conv3d with different stride values for each dimension.""" - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, @@ -532,7 +532,7 @@ def test_conv3d_asymmetric_strides(self): def test_conv3d_valid_padding(self): """Test Conv3d with VALID padding.""" - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, @@ -545,7 +545,7 @@ def test_conv3d_valid_padding(self): def test_conv3d_same_padding(self): """Test Conv3d with SAME padding.""" - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, @@ -558,7 +558,7 @@ def test_conv3d_same_padding(self): def test_conv3d_explicit_padding_int(self): """Test Conv3d with explicit integer padding.""" - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, @@ -572,7 +572,7 @@ def test_conv3d_explicit_padding_int(self): def test_conv3d_explicit_padding_tuple(self): """Test Conv3d with explicit tuple padding.""" - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, @@ -586,7 +586,7 @@ def test_conv3d_explicit_padding_tuple(self): def test_conv3d_explicit_padding_sequence(self): """Test Conv3d with explicit sequence of tuples padding.""" - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, @@ -600,7 +600,7 @@ def test_conv3d_explicit_padding_sequence(self): def test_conv3d_with_bias(self): """Test Conv3d with bias initialization.""" - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, @@ -612,7 +612,7 @@ def test_conv3d_with_bias(self): def test_conv3d_without_bias(self): """Test Conv3d without bias.""" - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, @@ -624,7 +624,7 @@ def test_conv3d_without_bias(self): def test_conv3d_with_groups(self): """Test Conv3d with grouped convolution.""" - conv = brainscale.nn.Conv3d( + conv = 
braintrace.nn.Conv3d( in_size=(16, 16, 16, 4), out_channels=8, kernel_size=3, @@ -637,7 +637,7 @@ def test_conv3d_with_groups(self): def test_conv3d_depthwise(self): """Test Conv3d with depthwise convolution (groups = in_channels).""" in_channels = 4 - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, in_channels), out_channels=in_channels, kernel_size=3, @@ -649,7 +649,7 @@ def test_conv3d_depthwise(self): # def test_conv3d_lhs_dilation(self): # """Test Conv3d with lhs_dilation (atrous convolution on input).""" - # conv = brainscale.nn.Conv3d( + # conv = braintrace.nn.Conv3d( # in_size=(16, 16, 16, 3), # out_channels=64, # kernel_size=3, @@ -663,7 +663,7 @@ def test_conv3d_depthwise(self): def test_conv3d_rhs_dilation(self): """Test Conv3d with rhs_dilation (atrous convolution on kernel).""" - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, @@ -677,7 +677,7 @@ def test_conv3d_rhs_dilation(self): def test_conv3d_asymmetric_kernel(self): """Test Conv3d with asymmetric kernel size.""" - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=(3, 5, 3) @@ -689,7 +689,7 @@ def test_conv3d_asymmetric_kernel(self): def test_conv3d_custom_initializers(self): """Test Conv3d with custom weight and bias initializers.""" - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, @@ -704,7 +704,7 @@ def test_conv3d_with_weight_mask(self): """Test Conv3d with weight mask.""" kernel_shape = (3, 3, 3, 3, 64) # (H, W, D, in_channels, out_channels) mask = jnp.ones(kernel_shape) - conv = brainscale.nn.Conv3d( + conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, @@ -716,19 +716,19 @@ def test_conv3d_with_weight_mask(self): def test_conv3d_kernel_shape(self): """Test Conv3d kernel shape is correct.""" - conv = brainscale.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=5) + conv = braintrace.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=5) assert conv.kernel_shape == (5, 5, 5, 3, 64) def test_conv3d_input_validation_wrong_ndim(self): """Test Conv3d input validation rejects wrong number of dimensions.""" - conv = brainscale.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3) + conv = braintrace.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3) x = brainstate.random.randn(16, 16, 3) # 3D input - should fail with pytest.raises(ValueError): conv(x) def test_conv3d_input_validation_wrong_shape(self): """Test Conv3d input validation rejects wrong shape.""" - conv = brainscale.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3) + conv = braintrace.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3) x = brainstate.random.randn(2, 20, 20, 20, 3) # Wrong spatial dimensions with pytest.raises(ValueError): conv(x) @@ -740,17 +740,17 @@ class TestConvEdgeCases: def test_conv_invalid_groups_out_channels(self): """Test that out_channels must be divisible by groups.""" with pytest.raises(AssertionError): - brainscale.nn.Conv2d(in_size=(28, 28, 4), out_channels=9, kernel_size=3, groups=2) + braintrace.nn.Conv2d(in_size=(28, 28, 4), out_channels=9, kernel_size=3, groups=2) def test_conv_invalid_groups_in_channels(self): """Test that in_channels must be divisible by groups.""" with pytest.raises(AssertionError): - brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=8, kernel_size=3, 
groups=2) + braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=8, kernel_size=3, groups=2) def test_conv_invalid_padding_string(self): """Test that only SAME and VALID are accepted as string padding.""" with pytest.raises(AssertionError): - brainscale.nn.Conv2d( + braintrace.nn.Conv2d( in_size=(28, 28, 3), out_channels=32, kernel_size=3, @@ -761,7 +761,7 @@ def test_conv_invalid_padding_wrong_length(self): """Test that padding sequence with wrong number of tuples raises error.""" # Padding with 3 tuples for 2D conv should fail (needs 1 or 2 tuples) with pytest.raises(ValueError): - brainscale.nn.Conv2d( + braintrace.nn.Conv2d( in_size=(28, 28, 3), out_channels=32, kernel_size=3, @@ -771,7 +771,7 @@ def test_conv_invalid_padding_wrong_length(self): def test_conv_invalid_in_size_length(self): """Test that in_size must have correct length.""" with pytest.raises(AssertionError): - brainscale.nn.Conv2d( + braintrace.nn.Conv2d( in_size=(28, 3), # Should be 3D for Conv2d out_channels=32, kernel_size=3 @@ -783,8 +783,8 @@ class TestConvIntegration: def test_conv1d_stacked_layers(self): """Test stacking multiple Conv1d layers.""" - conv1 = brainscale.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) - conv2 = brainscale.nn.Conv1d(in_size=(10, 16), out_channels=32, kernel_size=3) + conv1 = braintrace.nn.Conv1d(in_size=(10, 3), out_channels=16, kernel_size=3) + conv2 = braintrace.nn.Conv1d(in_size=(10, 16), out_channels=32, kernel_size=3) x = brainstate.random.randn(4, 10, 3) y1 = conv1(x) @@ -795,8 +795,8 @@ def test_conv1d_stacked_layers(self): def test_conv2d_stacked_layers(self): """Test stacking multiple Conv2d layers.""" - conv1 = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) - conv2 = brainscale.nn.Conv2d(in_size=(28, 28, 32), out_channels=64, kernel_size=3) + conv1 = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) + conv2 = braintrace.nn.Conv2d(in_size=(28, 28, 32), out_channels=64, kernel_size=3) x = brainstate.random.randn(8, 28, 28, 3) y1 = conv1(x) @@ -807,8 +807,8 @@ def test_conv2d_stacked_layers(self): def test_conv3d_stacked_layers(self): """Test stacking multiple Conv3d layers.""" - conv1 = brainscale.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=32, kernel_size=3) - conv2 = brainscale.nn.Conv3d(in_size=(16, 16, 16, 32), out_channels=64, kernel_size=3) + conv1 = braintrace.nn.Conv3d(in_size=(16, 16, 16, 3), out_channels=32, kernel_size=3) + conv2 = braintrace.nn.Conv3d(in_size=(16, 16, 16, 32), out_channels=64, kernel_size=3) x = brainstate.random.randn(2, 16, 16, 16, 3) y1 = conv1(x) @@ -820,9 +820,9 @@ def test_conv3d_stacked_layers(self): def test_conv2d_with_pooling_like_stride(self): """Test Conv2d with stride > 1 for downsampling.""" # Create a simple CNN-like architecture with downsampling - conv1 = brainscale.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3, stride=1) - conv2 = brainscale.nn.Conv2d(in_size=(32, 32, 16), out_channels=32, kernel_size=3, stride=2) - conv3 = brainscale.nn.Conv2d(in_size=(16, 16, 32), out_channels=64, kernel_size=3, stride=2) + conv1 = braintrace.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3, stride=1) + conv2 = braintrace.nn.Conv2d(in_size=(32, 32, 16), out_channels=32, kernel_size=3, stride=2) + conv3 = braintrace.nn.Conv2d(in_size=(16, 16, 32), out_channels=64, kernel_size=3, stride=2) x = brainstate.random.randn(4, 32, 32, 3) y1 = conv1(x) @@ -835,14 +835,14 @@ def test_conv2d_with_pooling_like_stride(self): def 
test_conv_output_shape_computation(self): """Test that output shape is correctly computed during initialization.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) assert conv.out_size is not None assert len(conv.out_size) == 3 assert conv.out_size[-1] == 32 def test_conv_with_jit_compilation(self): """Test that convolution works with JAX JIT compilation.""" - conv = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) + conv = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) @brainstate.transform.jit def forward(x): @@ -856,10 +856,10 @@ def test_conv_deterministic_with_same_seed(self): """Test that convolution is deterministic with same random seed.""" # Create two identical convolutions with the same seed brainstate.random.seed(42) - conv1 = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) + conv1 = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) brainstate.random.seed(42) - conv2 = brainscale.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) + conv2 = braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=32, kernel_size=3) x = brainstate.random.randn(8, 28, 28, 3) y1 = conv1(x) diff --git a/brainscale/nn/_linear.py b/braintrace/nn/_linear.py similarity index 95% rename from brainscale/nn/_linear.py rename to braintrace/nn/_linear.py index d0c5872..ab49bfc 100644 --- a/brainscale/nn/_linear.py +++ b/braintrace/nn/_linear.py @@ -21,9 +21,9 @@ import brainunit as u from braintools import init -from brainscale._etrace_concepts import ETraceParam -from brainscale._etrace_operators import MatMulOp, LoraOp, SpMatMulOp -from brainscale._typing import ArrayLike +from braintrace._etrace_concepts import ETraceParam +from braintrace._etrace_operators import MatMulOp, LoraOp, SpMatMulOp +from braintrace._typing import ArrayLike __all__ = [ 'Linear', @@ -73,12 +73,12 @@ class Linear(brainstate.nn.Module): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a linear layer >>> brainstate.environ.set(precision=64) - >>> linear = brainscale.nn.Linear(in_size=128, out_size=64) + >>> linear = braintrace.nn.Linear(in_size=128, out_size=64) >>> >>> # Input with batch size 10 >>> x = brainstate.random.randn(10, 128) @@ -86,7 +86,7 @@ class Linear(brainstate.nn.Module): >>> print(y.shape) (10, 64) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -157,13 +157,13 @@ class SignedWLinear(brainstate.nn.Module): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a signed weight linear layer >>> brainstate.environ.set(precision=64) >>> w_sign = brainstate.random.choice([-1, 1], size=(64, 32)) - >>> linear = brainscale.nn.SignedWLinear(in_size=64, out_size=32, w_sign=w_sign) + >>> linear = braintrace.nn.SignedWLinear(in_size=64, out_size=32, w_sign=w_sign) >>> >>> # Input with batch size 5 >>> x = brainstate.random.randn(5, 64) @@ -171,7 +171,7 @@ class SignedWLinear(brainstate.nn.Module): >>> print(y.shape) (5, 32) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -244,12 +244,12 @@ class ScaledWSLinear(brainstate.nn.Module): -------- .. 
code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a weight standardization linear layer >>> brainstate.environ.set(precision=64) - >>> linear = brainscale.nn.ScaledWSLinear(in_size=256, out_size=128, ws_gain=True, eps=1e-4) + >>> linear = braintrace.nn.ScaledWSLinear(in_size=256, out_size=128, ws_gain=True, eps=1e-4) >>> >>> # Input with batch size 16 >>> x = brainstate.random.randn(16, 256) @@ -257,7 +257,7 @@ class ScaledWSLinear(brainstate.nn.Module): >>> print(y.shape) (16, 128) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -354,7 +354,7 @@ class SparseLinear(brainstate.nn.Module): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> import brainunit as u >>> @@ -365,7 +365,7 @@ class SparseLinear(brainstate.nn.Module): >>> sparse_mat = u.sparse.COO((values, indices), shape=(512, 256)) >>> >>> # Create a sparse linear layer - >>> linear = brainscale.nn.SparseLinear(sparse_mat, b_init=None) + >>> linear = braintrace.nn.SparseLinear(sparse_mat, b_init=None) >>> >>> # Input with batch size 8 >>> x = brainstate.random.randn(8, 512) @@ -373,7 +373,7 @@ class SparseLinear(brainstate.nn.Module): >>> print(y.shape) (8, 256) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -463,11 +463,11 @@ class LoRA(brainstate.nn.Module): .. code-block:: python >>> import brainstate - >>> import brainscale + >>> import braintrace >>> >>> # Create a standalone LoRA layer >>> brainstate.environ.set(precision=64) - >>> layer = brainscale.nn.LoRA(3, 2, 4) + >>> layer = braintrace.nn.LoRA(3, 2, 4) >>> x = brainstate.random.randn(16, 3) >>> y = layer(x) >>> print(y.shape) @@ -475,7 +475,7 @@ class LoRA(brainstate.nn.Module): >>> >>> # Wrap around existing linear layer >>> linear = brainstate.nn.Linear(3, 4) - >>> wrapper = brainscale.nn.LoRA(3, 2, 4, base_module=linear) + >>> wrapper = braintrace.nn.LoRA(3, 2, 4, base_module=linear) >>> assert wrapper.base_module == linear >>> y = wrapper(x) >>> print(y.shape) diff --git a/brainscale/nn/_linear_test.py b/braintrace/nn/_linear_test.py similarity index 83% rename from brainscale/nn/_linear_test.py rename to braintrace/nn/_linear_test.py index 3db3e5b..901ff37 100644 --- a/brainscale/nn/_linear_test.py +++ b/braintrace/nn/_linear_test.py @@ -23,7 +23,7 @@ - LoRA: Low-Rank Adaptation layer for fine-tuning """ -import brainscale +import braintrace import brainstate import brainunit as u import jax.numpy as jnp @@ -36,7 +36,7 @@ class TestLinear: def test_linear_basic_creation(self): """Test basic Linear layer creation.""" - linear = brainscale.nn.Linear(in_size=128, out_size=64) + linear = braintrace.nn.Linear(in_size=128, out_size=64) # in_size and out_size may be scalar or sequence assert hasattr(linear, 'in_size') assert hasattr(linear, 'out_size') @@ -44,28 +44,28 @@ def test_linear_basic_creation(self): def test_linear_forward_with_batch(self): """Test Linear forward pass with batch dimension.""" - linear = brainscale.nn.Linear(in_size=128, out_size=64) + linear = braintrace.nn.Linear(in_size=128, out_size=64) x = brainstate.random.randn(10, 128) y = linear(x) assert y.shape == (10, 64) def test_linear_forward_without_batch(self): """Test Linear forward pass without batch dimension.""" - linear = brainscale.nn.Linear(in_size=128, out_size=64) + linear = braintrace.nn.Linear(in_size=128, out_size=64) x = brainstate.random.randn(128) y = linear(x) assert y.shape == (64,) def 
test_linear_forward_multi_batch(self): """Test Linear forward pass with multiple batch dimensions.""" - linear = brainscale.nn.Linear(in_size=128, out_size=64) + linear = braintrace.nn.Linear(in_size=128, out_size=64) x = brainstate.random.randn(5, 10, 128) y = linear(x) assert y.shape == (5, 10, 64) def test_linear_with_bias(self): """Test Linear layer with bias initialization.""" - linear = brainscale.nn.Linear( + linear = braintrace.nn.Linear( in_size=128, out_size=64, b_init=init.Constant(0.1) @@ -76,14 +76,14 @@ def test_linear_with_bias(self): def test_linear_without_bias(self): """Test Linear layer without bias.""" - linear = brainscale.nn.Linear(in_size=128, out_size=64, b_init=None) + linear = braintrace.nn.Linear(in_size=128, out_size=64, b_init=None) x = brainstate.random.randn(10, 128) y = linear(x) assert y.shape == (10, 64) def test_linear_custom_weight_init(self): """Test Linear layer with custom weight initializer.""" - linear = brainscale.nn.Linear( + linear = braintrace.nn.Linear( in_size=128, out_size=64, w_init=init.Constant(0.5) @@ -94,7 +94,7 @@ def test_linear_custom_weight_init(self): def test_linear_custom_bias_init(self): """Test Linear layer with custom bias initializer.""" - linear = brainscale.nn.Linear( + linear = braintrace.nn.Linear( in_size=128, out_size=64, w_init=init.KaimingNormal(), @@ -107,7 +107,7 @@ def test_linear_custom_bias_init(self): def test_linear_with_weight_mask(self): """Test Linear layer with weight mask.""" mask = jnp.ones((128, 64)) - linear = brainscale.nn.Linear(in_size=128, out_size=64, w_mask=mask) + linear = braintrace.nn.Linear(in_size=128, out_size=64, w_mask=mask) x = brainstate.random.randn(10, 128) y = linear(x) assert y.shape == (10, 64) @@ -116,7 +116,7 @@ def test_linear_with_partial_weight_mask(self): """Test Linear layer with partial weight mask.""" mask = jnp.zeros((128, 64)) mask = mask.at[:64, :32].set(1.0) # Only connect first half to first half - linear = brainscale.nn.Linear(in_size=128, out_size=64, w_mask=mask) + linear = braintrace.nn.Linear(in_size=128, out_size=64, w_mask=mask) x = brainstate.random.randn(10, 128) y = linear(x) assert y.shape == (10, 64) @@ -127,44 +127,44 @@ def test_linear_with_callable_weight_mask(self): def mask_fn(shape): return jnp.ones(shape) - linear = brainscale.nn.Linear(in_size=128, out_size=64, w_mask=mask_fn) + linear = braintrace.nn.Linear(in_size=128, out_size=64, w_mask=mask_fn) x = brainstate.random.randn(10, 128) y = linear(x) assert y.shape == (10, 64) def test_linear_sequence_sizes(self): """Test Linear layer with sequence sizes.""" - linear = brainscale.nn.Linear(in_size=(64, 128), out_size=(64, 64)) + linear = braintrace.nn.Linear(in_size=(64, 128), out_size=(64, 64)) x = brainstate.random.randn(10, 128) y = linear(x) assert y.shape == (10, 64) def test_linear_large_dimensions(self): """Test Linear layer with large dimensions.""" - linear = brainscale.nn.Linear(in_size=2048, out_size=1024) + linear = braintrace.nn.Linear(in_size=2048, out_size=1024) x = brainstate.random.randn(8, 2048) y = linear(x) assert y.shape == (8, 1024) def test_linear_small_dimensions(self): """Test Linear layer with small dimensions.""" - linear = brainscale.nn.Linear(in_size=4, out_size=2) + linear = braintrace.nn.Linear(in_size=4, out_size=2) x = brainstate.random.randn(3, 4) y = linear(x) assert y.shape == (3, 2) def test_linear_with_name(self): """Test Linear layer with custom name.""" - linear = brainscale.nn.Linear(in_size=128, out_size=64, name="test_linear") + linear = 
braintrace.nn.Linear(in_size=128, out_size=64, name="test_linear") assert linear.name == "test_linear" def test_linear_deterministic_with_same_seed(self): """Test that Linear is deterministic with same random seed.""" brainstate.random.seed(42) - linear1 = brainscale.nn.Linear(in_size=128, out_size=64) + linear1 = braintrace.nn.Linear(in_size=128, out_size=64) brainstate.random.seed(42) - linear2 = brainscale.nn.Linear(in_size=128, out_size=64) + linear2 = braintrace.nn.Linear(in_size=128, out_size=64) x = brainstate.random.randn(10, 128) y1 = linear1(x) @@ -178,21 +178,21 @@ class TestSignedWLinear: def test_signed_w_linear_basic_creation(self): """Test basic SignedWLinear layer creation.""" - linear = brainscale.nn.SignedWLinear(in_size=64, out_size=32) + linear = braintrace.nn.SignedWLinear(in_size=64, out_size=32) assert hasattr(linear, 'in_size') assert hasattr(linear, 'out_size') assert hasattr(linear, 'weight_op') def test_signed_w_linear_forward_with_batch(self): """Test SignedWLinear forward pass with batch dimension.""" - linear = brainscale.nn.SignedWLinear(in_size=64, out_size=32) + linear = braintrace.nn.SignedWLinear(in_size=64, out_size=32) x = brainstate.random.randn(5, 64) y = linear(x) assert y.shape == (5, 32) def test_signed_w_linear_forward_without_batch(self): """Test SignedWLinear forward pass without batch dimension.""" - linear = brainscale.nn.SignedWLinear(in_size=64, out_size=32) + linear = braintrace.nn.SignedWLinear(in_size=64, out_size=32) x = brainstate.random.randn(64) y = linear(x) assert y.shape == (32,) @@ -200,7 +200,7 @@ def test_signed_w_linear_forward_without_batch(self): def test_signed_w_linear_with_positive_signs(self): """Test SignedWLinear with positive sign matrix.""" w_sign = jnp.ones((64, 32)) - linear = brainscale.nn.SignedWLinear(in_size=64, out_size=32, w_sign=w_sign) + linear = braintrace.nn.SignedWLinear(in_size=64, out_size=32, w_sign=w_sign) x = brainstate.random.randn(5, 64) y = linear(x) assert y.shape == (5, 32) @@ -208,7 +208,7 @@ def test_signed_w_linear_with_positive_signs(self): def test_signed_w_linear_with_negative_signs(self): """Test SignedWLinear with negative sign matrix.""" w_sign = -jnp.ones((64, 32)) - linear = brainscale.nn.SignedWLinear(in_size=64, out_size=32, w_sign=w_sign) + linear = braintrace.nn.SignedWLinear(in_size=64, out_size=32, w_sign=w_sign) x = brainstate.random.randn(5, 64) y = linear(x) assert y.shape == (5, 32) @@ -217,21 +217,21 @@ def test_signed_w_linear_with_mixed_signs(self): """Test SignedWLinear with mixed sign matrix.""" brainstate.random.seed(123) w_sign = brainstate.random.choice(jnp.array([-1, 1]), size=(64, 32)) - linear = brainscale.nn.SignedWLinear(in_size=64, out_size=32, w_sign=w_sign) + linear = braintrace.nn.SignedWLinear(in_size=64, out_size=32, w_sign=w_sign) x = brainstate.random.randn(5, 64) y = linear(x) assert y.shape == (5, 32) def test_signed_w_linear_without_sign_matrix(self): """Test SignedWLinear without sign matrix (defaults to None).""" - linear = brainscale.nn.SignedWLinear(in_size=64, out_size=32, w_sign=None) + linear = braintrace.nn.SignedWLinear(in_size=64, out_size=32, w_sign=None) x = brainstate.random.randn(5, 64) y = linear(x) assert y.shape == (5, 32) def test_signed_w_linear_custom_weight_init(self): """Test SignedWLinear with custom weight initializer.""" - linear = brainscale.nn.SignedWLinear( + linear = braintrace.nn.SignedWLinear( in_size=64, out_size=32, w_init=init.Constant(0.5) @@ -242,19 +242,19 @@ def test_signed_w_linear_custom_weight_init(self): def 
test_signed_w_linear_sequence_sizes(self): """Test SignedWLinear with sequence sizes.""" - linear = brainscale.nn.SignedWLinear(in_size=(32, 64), out_size=(32, 32)) + linear = braintrace.nn.SignedWLinear(in_size=(32, 64), out_size=(32, 32)) x = brainstate.random.randn(5, 64) y = linear(x) assert y.shape == (5, 32) def test_signed_w_linear_with_name(self): """Test SignedWLinear with custom name.""" - linear = brainscale.nn.SignedWLinear(in_size=64, out_size=32, name="test_signed") + linear = braintrace.nn.SignedWLinear(in_size=64, out_size=32, name="test_signed") assert linear.name == "test_signed" def test_signed_w_linear_multi_batch(self): """Test SignedWLinear with multiple batch dimensions.""" - linear = brainscale.nn.SignedWLinear(in_size=64, out_size=32) + linear = braintrace.nn.SignedWLinear(in_size=64, out_size=32) x = brainstate.random.randn(3, 5, 64) y = linear(x) assert y.shape == (3, 5, 32) @@ -270,7 +270,7 @@ def test_sparse_linear_basic_creation_coo(self): values = brainstate.random.randn(1000) sparse_mat = u.sparse.COO((values, rows, cols), shape=(512, 256)) - linear = brainscale.nn.SparseLinear(sparse_mat) + linear = braintrace.nn.SparseLinear(sparse_mat) assert hasattr(linear, 'out_size') assert hasattr(linear, 'weight_op') @@ -281,7 +281,7 @@ def test_sparse_linear_forward_with_batch(self): values = brainstate.random.randn(1000) sparse_mat = u.sparse.COO((values, rows, cols), shape=(512, 256)) - linear = brainscale.nn.SparseLinear(sparse_mat) + linear = braintrace.nn.SparseLinear(sparse_mat) x = brainstate.random.randn(8, 512) y = linear(x) assert y.shape == (8, 256) @@ -293,7 +293,7 @@ def test_sparse_linear_forward_without_batch(self): values = brainstate.random.randn(1000) sparse_mat = u.sparse.COO((values, rows, cols), shape=(512, 256)) - linear = brainscale.nn.SparseLinear(sparse_mat) + linear = braintrace.nn.SparseLinear(sparse_mat) x = brainstate.random.randn(512) y = linear(x) assert y.shape == (256,) @@ -305,7 +305,7 @@ def test_sparse_linear_with_bias(self): values = brainstate.random.randn(1000) sparse_mat = u.sparse.COO((values, rows, cols), shape=(512, 256)) - linear = brainscale.nn.SparseLinear(sparse_mat, b_init=init.Constant(0.1)) + linear = braintrace.nn.SparseLinear(sparse_mat, b_init=init.Constant(0.1)) x = brainstate.random.randn(8, 512) y = linear(x) assert y.shape == (8, 256) @@ -317,7 +317,7 @@ def test_sparse_linear_without_bias(self): values = brainstate.random.randn(1000) sparse_mat = u.sparse.COO((values, rows, cols), shape=(512, 256)) - linear = brainscale.nn.SparseLinear(sparse_mat, b_init=None) + linear = braintrace.nn.SparseLinear(sparse_mat, b_init=None) x = brainstate.random.randn(8, 512) y = linear(x) assert y.shape == (8, 256) @@ -329,7 +329,7 @@ def test_sparse_linear_with_in_size(self): values = brainstate.random.randn(1000) sparse_mat = u.sparse.COO((values, rows, cols), shape=(512, 256)) - linear = brainscale.nn.SparseLinear(sparse_mat, in_size=512) + linear = braintrace.nn.SparseLinear(sparse_mat, in_size=512) assert hasattr(linear, 'in_size') assert hasattr(linear, 'out_size') @@ -340,7 +340,7 @@ def test_sparse_linear_high_sparsity(self): values = brainstate.random.randn(100) sparse_mat = u.sparse.COO((values, rows, cols), shape=(512, 256)) - linear = brainscale.nn.SparseLinear(sparse_mat) + linear = braintrace.nn.SparseLinear(sparse_mat) x = brainstate.random.randn(8, 512) y = linear(x) assert y.shape == (8, 256) @@ -352,7 +352,7 @@ def test_sparse_linear_low_sparsity(self): values = brainstate.random.randn(50000) sparse_mat = 
u.sparse.COO((values, rows, cols), shape=(512, 256)) - linear = brainscale.nn.SparseLinear(sparse_mat) + linear = braintrace.nn.SparseLinear(sparse_mat) x = brainstate.random.randn(8, 512) y = linear(x) assert y.shape == (8, 256) @@ -364,7 +364,7 @@ def test_sparse_linear_with_name(self): values = brainstate.random.randn(1000) sparse_mat = u.sparse.COO((values, rows, cols), shape=(512, 256)) - linear = brainscale.nn.SparseLinear(sparse_mat, name="test_sparse") + linear = braintrace.nn.SparseLinear(sparse_mat, name="test_sparse") assert linear.name == "test_sparse" def test_sparse_linear_multi_batch(self): @@ -374,7 +374,7 @@ def test_sparse_linear_multi_batch(self): values = brainstate.random.randn(1000) sparse_mat = u.sparse.COO((values, rows, cols), shape=(512, 256)) - linear = brainscale.nn.SparseLinear(sparse_mat) + linear = braintrace.nn.SparseLinear(sparse_mat) x = brainstate.random.randn(32, 512) y = linear(x) assert y.shape == (32, 256) @@ -384,7 +384,7 @@ def test_sparse_linear_invalid_matrix_type(self): # Pass a regular array instead of sparse matrix regular_mat = jnp.ones((512, 256)) with pytest.raises(AssertionError): - brainscale.nn.SparseLinear(regular_mat) + braintrace.nn.SparseLinear(regular_mat) class TestLoRA: @@ -392,7 +392,7 @@ class TestLoRA: def test_lora_basic_creation(self): """Test basic LoRA layer creation.""" - lora = brainscale.nn.LoRA(in_features=3, lora_rank=2, out_features=4) + lora = braintrace.nn.LoRA(in_features=3, lora_rank=2, out_features=4) assert lora.in_features == 3 assert lora.lora_rank == 2 assert lora.out_features == 4 @@ -401,21 +401,21 @@ def test_lora_basic_creation(self): def test_lora_forward_with_batch(self): """Test LoRA forward pass with batch dimension.""" - lora = brainscale.nn.LoRA(in_features=3, lora_rank=2, out_features=4) + lora = braintrace.nn.LoRA(in_features=3, lora_rank=2, out_features=4) x = brainstate.random.randn(16, 3) y = lora(x) assert y.shape == (16, 4) def test_lora_forward_without_batch(self): """Test LoRA forward pass without batch dimension.""" - lora = brainscale.nn.LoRA(in_features=3, lora_rank=2, out_features=4) + lora = braintrace.nn.LoRA(in_features=3, lora_rank=2, out_features=4) x = brainstate.random.randn(3) y = lora(x) assert y.shape == (4,) def test_lora_custom_alpha(self): """Test LoRA with custom alpha value.""" - lora = brainscale.nn.LoRA(in_features=3, lora_rank=2, out_features=4, alpha=0.5) + lora = braintrace.nn.LoRA(in_features=3, lora_rank=2, out_features=4, alpha=0.5) assert lora.alpha == 0.5 x = brainstate.random.randn(16, 3) y = lora(x) @@ -424,7 +424,7 @@ def test_lora_custom_alpha(self): def test_lora_with_base_module(self): """Test LoRA wrapping an existing base module.""" base_linear = brainstate.nn.Linear(3, 4) - lora = brainscale.nn.LoRA( + lora = braintrace.nn.LoRA( in_features=3, lora_rank=2, out_features=4, @@ -438,7 +438,7 @@ def test_lora_with_base_module(self): def test_lora_custom_b_init(self): """Test LoRA with custom B initializer.""" - lora = brainscale.nn.LoRA( + lora = braintrace.nn.LoRA( in_features=3, lora_rank=2, out_features=4, @@ -450,7 +450,7 @@ def test_lora_custom_b_init(self): def test_lora_custom_a_init(self): """Test LoRA with custom A initializer.""" - lora = brainscale.nn.LoRA( + lora = braintrace.nn.LoRA( in_features=3, lora_rank=2, out_features=4, @@ -462,7 +462,7 @@ def test_lora_custom_a_init(self): def test_lora_default_b_init_zero(self): """Test that LoRA B is initialized to zero by default.""" - lora = brainscale.nn.LoRA( + lora = braintrace.nn.LoRA( 
in_features=3, lora_rank=2, out_features=4, @@ -474,7 +474,7 @@ def test_lora_default_b_init_zero(self): def test_lora_large_rank(self): """Test LoRA with large rank.""" - lora = brainscale.nn.LoRA(in_features=128, lora_rank=64, out_features=256) + lora = braintrace.nn.LoRA(in_features=128, lora_rank=64, out_features=256) assert lora.lora_rank == 64 x = brainstate.random.randn(8, 128) y = lora(x) @@ -482,7 +482,7 @@ def test_lora_large_rank(self): def test_lora_small_rank(self): """Test LoRA with small rank.""" - lora = brainscale.nn.LoRA(in_features=128, lora_rank=1, out_features=256) + lora = braintrace.nn.LoRA(in_features=128, lora_rank=1, out_features=256) assert lora.lora_rank == 1 x = brainstate.random.randn(8, 128) y = lora(x) @@ -490,14 +490,14 @@ def test_lora_small_rank(self): def test_lora_multi_batch(self): """Test LoRA with multiple batch dimensions.""" - lora = brainscale.nn.LoRA(in_features=3, lora_rank=2, out_features=4) + lora = braintrace.nn.LoRA(in_features=3, lora_rank=2, out_features=4) x = brainstate.random.randn(4, 8, 3) y = lora(x) assert y.shape == (4, 8, 4) def test_lora_large_dimensions(self): """Test LoRA with large input/output dimensions.""" - lora = brainscale.nn.LoRA(in_features=1024, lora_rank=16, out_features=2048) + lora = braintrace.nn.LoRA(in_features=1024, lora_rank=16, out_features=2048) x = brainstate.random.randn(4, 1024) y = lora(x) assert y.shape == (4, 2048) @@ -508,7 +508,7 @@ def test_lora_with_callable_base_module(self): def custom_layer(x): return x @ jnp.ones((3, 4)) - lora = brainscale.nn.LoRA( + lora = braintrace.nn.LoRA( in_features=3, lora_rank=2, out_features=4, @@ -523,7 +523,7 @@ def custom_layer(x): def test_lora_invalid_base_module(self): """Test LoRA raises error with invalid base module.""" with pytest.raises(AssertionError): - brainscale.nn.LoRA( + braintrace.nn.LoRA( in_features=3, lora_rank=2, out_features=4, @@ -536,9 +536,9 @@ class TestLinearIntegration: def test_linear_stacked_layers(self): """Test stacking multiple Linear layers.""" - linear1 = brainscale.nn.Linear(in_size=128, out_size=64) - linear2 = brainscale.nn.Linear(in_size=64, out_size=32) - linear3 = brainscale.nn.Linear(in_size=32, out_size=16) + linear1 = braintrace.nn.Linear(in_size=128, out_size=64) + linear2 = braintrace.nn.Linear(in_size=64, out_size=32) + linear3 = braintrace.nn.Linear(in_size=32, out_size=16) x = brainstate.random.randn(10, 128) y1 = linear1(x) @@ -551,8 +551,8 @@ def test_linear_stacked_layers(self): def test_mixed_linear_types(self): """Test mixing different types of linear layers.""" - linear1 = brainscale.nn.Linear(in_size=128, out_size=64) - signed_linear = brainscale.nn.SignedWLinear(in_size=64, out_size=32) + linear1 = braintrace.nn.Linear(in_size=128, out_size=64) + signed_linear = braintrace.nn.SignedWLinear(in_size=64, out_size=32) x = brainstate.random.randn(10, 128) y1 = linear1(x) @@ -567,7 +567,7 @@ def test_lora_fine_tuning_scenario(self): base_linear = brainstate.nn.Linear(128, 64) # Add LoRA adaptation - lora = brainscale.nn.LoRA( + lora = braintrace.nn.LoRA( in_features=128, lora_rank=8, out_features=64, @@ -584,7 +584,7 @@ def test_lora_fine_tuning_scenario(self): def test_linear_with_jit_compilation(self): """Test that Linear works with JAX JIT compilation.""" - linear = brainscale.nn.Linear(in_size=128, out_size=64) + linear = braintrace.nn.Linear(in_size=128, out_size=64) @brainstate.transform.jit def forward(x): @@ -597,7 +597,7 @@ def forward(x): def test_sparse_linear_vs_dense_linear(self): """Test that sparse 
linear can approximate dense linear.""" # Create a dense linear layer - dense_linear = brainscale.nn.Linear(in_size=64, out_size=32, b_init=None) + dense_linear = braintrace.nn.Linear(in_size=64, out_size=32, b_init=None) # Create a fully connected sparse matrix (simulating dense) row_indices = jnp.repeat(jnp.arange(64), 32) @@ -605,7 +605,7 @@ def test_sparse_linear_vs_dense_linear(self): values = brainstate.random.randn(64 * 32) sparse_mat = u.sparse.COO((values, row_indices, col_indices), shape=(64, 32)) - sparse_linear = brainscale.nn.SparseLinear(sparse_mat, b_init=None) + sparse_linear = braintrace.nn.SparseLinear(sparse_mat, b_init=None) x = brainstate.random.randn(8, 64) y_dense = dense_linear(x) @@ -617,7 +617,7 @@ def test_sparse_linear_vs_dense_linear(self): def test_linear_gradient_flow(self): """Test that gradients flow through Linear layer.""" - linear = brainscale.nn.Linear(in_size=10, out_size=5) + linear = braintrace.nn.Linear(in_size=10, out_size=5) def loss_fn(x): y = linear(x) @@ -635,7 +635,7 @@ def loss_fn(x): def test_lora_without_base_module(self): """Test LoRA as standalone layer without base module.""" - lora = brainscale.nn.LoRA(in_features=64, lora_rank=8, out_features=32) + lora = braintrace.nn.LoRA(in_features=64, lora_rank=8, out_features=32) x = brainstate.random.randn(10, 64) y = lora(x) @@ -645,8 +645,8 @@ def test_batch_size_consistency(self): """Test that all linear layers handle different batch sizes correctly.""" - linear = brainscale.nn.Linear(in_size=64, out_size=32) - signed_linear = brainscale.nn.SignedWLinear(in_size=64, out_size=32) + linear = braintrace.nn.Linear(in_size=64, out_size=32) + signed_linear = braintrace.nn.SignedWLinear(in_size=64, out_size=32) for batch_size in [1, 8, 32, 128]: x = brainstate.random.randn(batch_size, 64)
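
test_sparse_linear_vs_dense_linear above leans on one fact: a COO matrix with every (row, col) pair populated is a dense matrix in a different layout, so the sparse and dense products must agree. A small self-contained check of that idea in plain JAX, with toy 4x3 shapes rather than the library's sparse API:

```python
import jax.numpy as jnp

rows = jnp.repeat(jnp.arange(4), 3)          # every row index, repeated per column
cols = jnp.tile(jnp.arange(3), 4)            # every column index, cycled per row
values = jnp.arange(12, dtype=jnp.float32)   # one value per (row, col) pair

# Scatter the COO triplets into a dense matrix; full coverage leaves no zeros.
dense = jnp.zeros((4, 3)).at[rows, cols].set(values)
x = jnp.ones((2, 4))
assert jnp.allclose(x @ dense, x @ values.reshape(4, 3))
```

diff --git a/brainscale/nn/_normalizations.py b/braintrace/nn/_normalizations.py similarity index 94% rename from brainscale/nn/_normalizations.py rename to braintrace/nn/_normalizations.py index 0af85d3..62226ac 100644 --- a/brainscale/nn/_normalizations.py +++ b/braintrace/nn/_normalizations.py @@ -25,9 +25,9 @@ from brainstate import BatchState from brainstate.nn._normalizations import _BatchNorm -from brainscale._etrace_concepts import ETraceParam -from brainscale._etrace_operators import ETraceOp, Y, W, general_y2w -from brainscale._typing import ArrayLike, Size, Axes +from braintrace._etrace_concepts import ETraceParam +from braintrace._etrace_operators import ETraceOp, Y, W, general_y2w +from braintrace._typing import ArrayLike, Size, Axes __all__ = [ 'BatchNorm0d', @@ -58,7 +58,7 @@ def yw_to_w( class _BatchNormETrace(_BatchNorm): - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -144,11 +144,11 @@ class BatchNorm0d(_BatchNormETrace): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a BatchNorm0d layer - >>> bn = brainscale.nn.BatchNorm0d(in_size=128) + >>> bn = braintrace.nn.BatchNorm0d(in_size=128) >>> >>> # Input with batch size 32 >>> x = brainstate.random.randn(32, 128) @@ -156,7 +156,7 @@ class BatchNorm0d(_BatchNormETrace): >>> print(y.shape) (32, 128) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' num_spatial_dims: int = 0 @@ -203,11 +203,11 @@ class BatchNorm1d(_BatchNormETrace): -------- ..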
code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a BatchNorm1d layer - >>> bn = brainscale.nn.BatchNorm1d(in_size=(100, 64)) + >>> bn = braintrace.nn.BatchNorm1d(in_size=(100, 64)) >>> >>> # Input with batch size 16 >>> x = brainstate.random.randn(16, 100, 64) @@ -215,7 +215,7 @@ class BatchNorm1d(_BatchNormETrace): >>> print(y.shape) (16, 100, 64) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' num_spatial_dims: int = 1 @@ -262,11 +262,11 @@ class BatchNorm2d(_BatchNormETrace): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a BatchNorm2d layer - >>> bn = brainscale.nn.BatchNorm2d(in_size=(28, 28, 32)) + >>> bn = braintrace.nn.BatchNorm2d(in_size=(28, 28, 32)) >>> >>> # Input with batch size 8 >>> x = brainstate.random.randn(8, 28, 28, 32) @@ -274,7 +274,7 @@ class BatchNorm2d(_BatchNormETrace): >>> print(y.shape) (8, 28, 28, 32) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' num_spatial_dims: int = 2 @@ -321,11 +321,11 @@ class BatchNorm3d(_BatchNormETrace): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a BatchNorm3d layer - >>> bn = brainscale.nn.BatchNorm3d(in_size=(16, 16, 16, 64)) + >>> bn = braintrace.nn.BatchNorm3d(in_size=(16, 16, 16, 64)) >>> >>> # Input with batch size 4 >>> x = brainstate.random.randn(4, 16, 16, 16, 64) @@ -333,7 +333,7 @@ class BatchNorm3d(_BatchNormETrace): >>> print(y.shape) (4, 16, 16, 16, 64) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' num_spatial_dims: int = 3 @@ -366,11 +366,11 @@ class LayerNorm(brainstate.nn.LayerNorm): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a LayerNorm layer - >>> ln = brainscale.nn.LayerNorm(in_size=512) + >>> ln = braintrace.nn.LayerNorm(in_size=512) >>> >>> # Input with batch size 10 and sequence length 20 >>> x = brainstate.random.randn(10, 20, 512) @@ -378,7 +378,7 @@ class LayerNorm(brainstate.nn.LayerNorm): >>> print(y.shape) (10, 20, 512) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -422,11 +422,11 @@ class RMSNorm(brainstate.nn.RMSNorm): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create an RMSNorm layer - >>> rms = brainscale.nn.RMSNorm(in_size=768) + >>> rms = braintrace.nn.RMSNorm(in_size=768) >>> >>> # Input with batch size 8 and sequence length 128 >>> x = brainstate.random.randn(8, 128, 768) @@ -434,7 +434,7 @@ class RMSNorm(brainstate.nn.RMSNorm): >>> print(y.shape) (8, 128, 768) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -480,11 +480,11 @@ class GroupNorm(brainstate.nn.GroupNorm): -------- .. 
code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a GroupNorm layer with 8 groups for 64 channels - >>> gn = brainscale.nn.GroupNorm(num_groups=8, num_channels=64) + >>> gn = braintrace.nn.GroupNorm(num_groups=8, num_channels=64) >>> >>> # Input with batch size 4 and spatial dimensions >>> x = brainstate.random.randn(4, 32, 32, 64) @@ -492,7 +492,7 @@ class GroupNorm(brainstate.nn.GroupNorm): >>> print(y.shape) (4, 32, 32, 64) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, diff --git a/brainscale/nn/_normalizations_test.py b/braintrace/nn/_normalizations_test.py similarity index 96% rename from brainscale/nn/_normalizations_test.py rename to braintrace/nn/_normalizations_test.py index 2fa3b20..148bcfe 100644 --- a/brainscale/nn/_normalizations_test.py +++ b/braintrace/nn/_normalizations_test.py @@ -18,7 +18,7 @@ import numpy as np import brainstate -import brainscale +import braintrace class TestBatchNorm0d(parameterized.TestCase): @@ -46,7 +46,7 @@ def test_batchnorm0d_with_batch(self, fit, feature_axis, track_running_stats): # affine can only be True when track_running_stats is True affine = track_running_stats - net = brainscale.nn.BatchNorm0d( + net = braintrace.nn.BatchNorm0d( in_size, feature_axis=feature_axis, track_running_stats=track_running_stats, @@ -73,7 +73,7 @@ def test_batchnorm0d_without_batch(self): channels = 10 in_size = (channels,) - net = brainscale.nn.BatchNorm0d(in_size, track_running_stats=True) + net = braintrace.nn.BatchNorm0d(in_size, track_running_stats=True) brainstate.environ.set(fit=False) # Use running stats # Run with batch first to populate running stats @@ -94,11 +94,11 @@ def test_batchnorm0d_affine(self): in_size = (channels,) # With affine - net_affine = brainscale.nn.BatchNorm0d(in_size, affine=True) + net_affine = braintrace.nn.BatchNorm0d(in_size, affine=True) self.assertIsNotNone(net_affine.weight) # Without affine (track_running_stats must be False) - net_no_affine = brainscale.nn.BatchNorm0d( + net_no_affine = braintrace.nn.BatchNorm0d( in_size, affine=False, track_running_stats=False ) self.assertIsNone(net_no_affine.weight) @@ -132,7 +132,7 @@ def test_batchnorm1d_with_batch(self, fit, feature_axis, track_running_stats): # affine can only be True when track_running_stats is True affine = track_running_stats - net = brainscale.nn.BatchNorm1d( + net = braintrace.nn.BatchNorm1d( in_size, feature_axis=feature_axis_param, track_running_stats=track_running_stats, @@ -152,7 +152,7 @@ def test_batchnorm1d_without_batch(self): channels = 10 in_size = (length, channels) - net = brainscale.nn.BatchNorm1d(in_size, track_running_stats=True) + net = braintrace.nn.BatchNorm1d(in_size, track_running_stats=True) # Populate running stats first brainstate.environ.set(fit=True) @@ -182,7 +182,7 @@ def test_batchnorm1d_channel_consistency(self, feature_axis): in_size = (channels, length) input_shape = (batch_size, channels, length) - net = brainscale.nn.BatchNorm1d(in_size, feature_axis=feature_axis) + net = braintrace.nn.BatchNorm1d(in_size, feature_axis=feature_axis) brainstate.environ.set(fit=True) x = brainstate.random.randn(*input_shape) @@ -220,7 +220,7 @@ def test_batchnorm2d_with_batch(self, fit, feature_axis, track_running_stats): # affine can only be True when track_running_stats is True affine = track_running_stats - net = brainscale.nn.BatchNorm2d( + net = braintrace.nn.BatchNorm2d( in_size, feature_axis=feature_axis_param, 
track_running_stats=track_running_stats, @@ -256,7 +256,7 @@ def test_batchnorm2d_without_batch(self): channels = 3 in_size = (height, width, channels) - net = brainscale.nn.BatchNorm2d(in_size, track_running_stats=True) + net = braintrace.nn.BatchNorm2d(in_size, track_running_stats=True) # Populate running stats brainstate.environ.set(fit=True) @@ -299,7 +299,7 @@ def test_batchnorm3d_with_batch(self, fit, feature_axis, track_running_stats): # affine can only be True when track_running_stats is True affine = track_running_stats - net = brainscale.nn.BatchNorm3d( + net = braintrace.nn.BatchNorm3d( in_size, feature_axis=feature_axis_param, track_running_stats=track_running_stats, @@ -319,7 +319,7 @@ def test_batchnorm3d_without_batch(self): channels = 2 in_size = (depth, height, width, channels) - net = brainscale.nn.BatchNorm3d(in_size, track_running_stats=True) + net = braintrace.nn.BatchNorm3d(in_size, track_running_stats=True) # Populate running stats brainstate.environ.set(fit=True) @@ -624,7 +624,7 @@ class TestNormalizationEdgeCases(parameterized.TestCase): def test_batchnorm_shape_mismatch(self): """Test that BatchNorm raises error on shape mismatch.""" - net = brainscale.nn.BatchNorm2d((28, 28, 3)) + net = braintrace.nn.BatchNorm2d((28, 28, 3)) # Wrong shape should raise error with self.assertRaises(ValueError): @@ -635,7 +635,7 @@ def test_batchnorm_without_track_and_affine(self): """Test that affine=True requires track_running_stats=True.""" # This should raise an assertion error with self.assertRaises(AssertionError): - net = brainscale.nn.BatchNorm2d( + net = braintrace.nn.BatchNorm2d( (28, 28, 3), track_running_stats=False, affine=True # Requires track_running_stats=True @@ -666,7 +666,7 @@ class TestNormalizationConsistency(parameterized.TestCase): def test_batchnorm2d_consistency_across_batches(self): """Test that BatchNorm2d behaves consistently across different batch sizes.""" in_size = (28, 28, 3) - net = brainscale.nn.BatchNorm2d(in_size, track_running_stats=True) + net = braintrace.nn.BatchNorm2d(in_size, track_running_stats=True) # Train on larger batch brainstate.environ.set(fit=True) diff --git a/brainscale/nn/_readout.py b/braintrace/nn/_readout.py similarity index 94% rename from brainscale/nn/_readout.py rename to braintrace/nn/_readout.py index 9720949..3ded42c 100644 --- a/brainscale/nn/_readout.py +++ b/braintrace/nn/_readout.py @@ -24,9 +24,9 @@ import jax import brainpy -from brainscale._etrace_concepts import ETraceParam -from brainscale._etrace_operators import MatMulOp -from brainscale._typing import Size, ArrayLike, Spike +from braintrace._etrace_concepts import ETraceParam +from braintrace._etrace_operators import MatMulOp +from braintrace._typing import Size, ArrayLike, Spike __all__ = [ 'LeakyRateReadout', @@ -43,7 +43,7 @@ class LeakyRateReadout(brainstate.nn.Module): leaky integration to the input and producing a continuous output signal. - This class is part of the BrainScale project and integrates with + This class is part of the BrainTrace project and integrates with the Brain Dynamics Programming ecosystem, providing a biologically inspired approach to neural computation. @@ -83,12 +83,12 @@ class LeakyRateReadout(brainstate.nn.Module): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> import brainunit as u >>> >>> # Create a leaky rate readout layer - >>> readout = brainscale.nn.LeakyRateReadout( + >>> readout = braintrace.nn.LeakyRateReadout( ... in_size=256, ... out_size=10, ... 
tau=5.0 * u.ms @@ -101,7 +101,7 @@ class LeakyRateReadout(brainstate.nn.Module): >>> print(output.shape) (32, 10) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -149,7 +149,7 @@ class LeakySpikeReadout(brainpy.state.Neuron): based on input spikes, using specified parameters such as time constant, threshold voltage, and spike function. - This class is part of the BrainScale project and is designed to + This class is part of the BrainTrace project and is designed to integrate with the Brain Dynamics Programming ecosystem, providing a biologically inspired approach to neural computation. @@ -196,12 +196,12 @@ class LeakySpikeReadout(brainpy.state.Neuron): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> import brainunit as u >>> >>> # Create a leaky spike readout layer - >>> readout = brainscale.nn.LeakySpikeReadout( + >>> readout = braintrace.nn.LeakySpikeReadout( ... in_size=512, ... out_size=10, ... tau=10.0 * u.ms, @@ -216,7 +216,7 @@ class LeakySpikeReadout(brainpy.state.Neuron): (64, 10) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, diff --git a/brainscale/nn/_readout_test.py b/braintrace/nn/_readout_test.py similarity index 86% rename from brainscale/nn/_readout_test.py rename to braintrace/nn/_readout_test.py index d2e5c12..5dc61a8 100644 --- a/brainscale/nn/_readout_test.py +++ b/braintrace/nn/_readout_test.py @@ -22,8 +22,7 @@ """ import pytest -import brainscale -import brainscale +import braintrace import brainstate import braintools import brainunit as u
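
The LeakyRateReadout tests that follow all revolve around one recurrence: decay the running state, then add the projected input. A hypothetical sketch of that update in plain JAX, assuming an exponential-decay discretization `decay = exp(-dt / tau)`; the layer's actual discretization may differ:

```python
import jax.numpy as jnp

dt, tau = 0.1, 5.0                 # time step and time constant in ms (units dropped)
decay = jnp.exp(-dt / tau)         # assumed discretization of the leak

def leaky_rate_step(r, x, W):
    """One readout step: decay the running rate, then add the projected input."""
    return decay * r + x @ W

r = jnp.zeros((4, 10))             # (batch, out_size), as produced by init_state
x = jnp.ones((4, 64))              # (batch, in_size)
W = jnp.full((64, 10), 0.01)
r = leaky_rate_step(r, x, W)       # repeated calls accumulate state over time
assert r.shape == (4, 10)
```

With zero input the state shrinks geometrically toward zero, which is the behavior test_leaky_rate_readout_zero_input exercises further below.

@@ -39,7 +39,7 @@ class TestLeakyRateReadout: def test_leaky_rate_readout_basic_creation(self): """Test basic LeakyRateReadout layer creation.""" - readout = brainscale.nn.LeakyRateReadout(in_size=256, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=256, out_size=10) assert hasattr(readout, 'in_size') assert hasattr(readout, 'out_size') assert hasattr(readout, 'tau') @@ -48,33 +48,33 @@ def test_leaky_rate_readout_default_tau(self): """Test LeakyRateReadout with default tau value.""" - readout = brainscale.nn.LeakyRateReadout(in_size=128, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=128, out_size=10) assert readout.tau is not None assert readout.decay is not None def test_leaky_rate_readout_custom_tau(self): """Test LeakyRateReadout with custom tau value.""" tau_value = 10.0 * u.ms - readout = brainscale.nn.LeakyRateReadout(in_size=128, out_size=10, tau=tau_value) + readout = braintrace.nn.LeakyRateReadout(in_size=128, out_size=10, tau=tau_value) assert readout.tau is not None def test_leaky_rate_readout_init_state_with_batch(self): """Test LeakyRateReadout state initialization with batch size.""" - readout = brainscale.nn.LeakyRateReadout(in_size=256, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=256, out_size=10) readout.init_state(batch_size=32) assert hasattr(readout, 'r') assert readout.r.value.shape == (32, 10) def test_leaky_rate_readout_init_state_without_batch(self): """Test LeakyRateReadout state initialization without batch size.""" - readout = brainscale.nn.LeakyRateReadout(in_size=256, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=256, out_size=10) readout.init_state(batch_size=None) assert hasattr(readout, 'r') assert readout.r.value.shape == (10,) def test_leaky_rate_readout_reset_state(self): """Test LeakyRateReadout state reset.""" - 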
readout = brainscale.nn.LeakyRateReadout(in_size=256, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=256, out_size=10) readout.init_state(batch_size=32) # Modify state @@ -86,7 +86,7 @@ def test_leaky_rate_readout_reset_state(self): def test_leaky_rate_readout_forward_with_batch(self): """Test LeakyRateReadout forward pass with batch dimension.""" - readout = brainscale.nn.LeakyRateReadout(in_size=256, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=256, out_size=10) readout.init_state(batch_size=32) x = brainstate.random.randn(32, 256) @@ -95,7 +95,7 @@ def test_leaky_rate_readout_forward_with_batch(self): def test_leaky_rate_readout_forward_without_batch(self): """Test LeakyRateReadout forward pass without batch dimension.""" - readout = brainscale.nn.LeakyRateReadout(in_size=256, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=256, out_size=10) readout.init_state(batch_size=None) x = brainstate.random.randn(256) @@ -104,7 +104,7 @@ def test_leaky_rate_readout_forward_without_batch(self): def test_leaky_rate_readout_sequential_updates(self): """Test LeakyRateReadout with sequential updates.""" - readout = brainscale.nn.LeakyRateReadout(in_size=128, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=128, out_size=10) readout.init_state(batch_size=16) outputs = [] @@ -118,7 +118,7 @@ def test_leaky_rate_readout_sequential_updates(self): def test_leaky_rate_readout_state_accumulation(self): """Test that LeakyRateReadout accumulates state over time.""" - readout = brainscale.nn.LeakyRateReadout(in_size=64, out_size=8, tau=10.0 * u.ms) + readout = braintrace.nn.LeakyRateReadout(in_size=64, out_size=8, tau=10.0 * u.ms) readout.init_state(batch_size=4) # First update @@ -134,7 +134,7 @@ def test_leaky_rate_readout_state_accumulation(self): def test_leaky_rate_readout_custom_w_init(self): """Test LeakyRateReadout with custom weight initializer.""" - readout = brainscale.nn.LeakyRateReadout( + readout = braintrace.nn.LeakyRateReadout( in_size=128, out_size=10, w_init=init.Constant(0.5) @@ -147,7 +147,7 @@ def test_leaky_rate_readout_custom_w_init(self): def test_leaky_rate_readout_custom_r_init(self): """Test LeakyRateReadout with custom state initializer.""" - readout = brainscale.nn.LeakyRateReadout( + readout = braintrace.nn.LeakyRateReadout( in_size=128, out_size=10, r_init=init.Constant(1.0) @@ -159,7 +159,7 @@ def test_leaky_rate_readout_custom_r_init(self): def test_leaky_rate_readout_with_name(self): """Test LeakyRateReadout with custom name.""" - readout = brainscale.nn.LeakyRateReadout( + readout = braintrace.nn.LeakyRateReadout( in_size=256, out_size=10, name="test_readout" @@ -168,7 +168,7 @@ def test_leaky_rate_readout_with_name(self): def test_leaky_rate_readout_large_dimensions(self): """Test LeakyRateReadout with large dimensions.""" - readout = brainscale.nn.LeakyRateReadout(in_size=2048, out_size=512) + readout = braintrace.nn.LeakyRateReadout(in_size=2048, out_size=512) readout.init_state(batch_size=8) x = brainstate.random.randn(8, 2048) @@ -177,7 +177,7 @@ def test_leaky_rate_readout_large_dimensions(self): def test_leaky_rate_readout_small_dimensions(self): """Test LeakyRateReadout with small dimensions.""" - readout = brainscale.nn.LeakyRateReadout(in_size=4, out_size=2) + readout = braintrace.nn.LeakyRateReadout(in_size=4, out_size=2) readout.init_state(batch_size=2) x = brainstate.random.randn(2, 4) @@ -186,7 +186,7 @@ def test_leaky_rate_readout_small_dimensions(self): def 
test_leaky_rate_readout_different_batch_sizes(self): """Test LeakyRateReadout with different batch sizes.""" - readout = brainscale.nn.LeakyRateReadout(in_size=128, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=128, out_size=10) for batch_size in [1, 8, 32, 64]: readout.init_state(batch_size=batch_size) @@ -197,7 +197,7 @@ def test_leaky_rate_readout_different_batch_sizes(self): def test_leaky_rate_readout_decay_computation(self): """Test that decay is computed correctly from tau.""" tau_value = 5.0 * u.ms - readout = brainscale.nn.LeakyRateReadout( + readout = braintrace.nn.LeakyRateReadout( in_size=128, out_size=10, tau=tau_value @@ -210,7 +210,7 @@ def test_leaky_rate_readout_decay_computation(self): def test_leaky_rate_readout_zero_input(self): """Test LeakyRateReadout with zero input.""" - readout = brainscale.nn.LeakyRateReadout(in_size=128, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=128, out_size=10) readout.init_state(batch_size=16) # Initialize with non-zero state @@ -227,11 +227,11 @@ def test_leaky_rate_readout_zero_input(self): def test_leaky_rate_readout_deterministic_with_seed(self): """Test that LeakyRateReadout is deterministic with same random seed.""" brainstate.random.seed(42) - readout1 = brainscale.nn.LeakyRateReadout(in_size=128, out_size=10) + readout1 = braintrace.nn.LeakyRateReadout(in_size=128, out_size=10) readout1.init_state(batch_size=16) brainstate.random.seed(42) - readout2 = brainscale.nn.LeakyRateReadout(in_size=128, out_size=10) + readout2 = braintrace.nn.LeakyRateReadout(in_size=128, out_size=10) readout2.init_state(batch_size=16) brainstate.random.seed(123) @@ -248,7 +248,7 @@ class TestLeakySpikeReadout: def test_leaky_spike_readout_basic_creation(self): """Test basic LeakySpikeReadout layer creation.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=512, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=512, out_size=10) assert hasattr(readout, 'in_size') assert hasattr(readout, 'out_size') assert hasattr(readout, 'tau') @@ -257,14 +257,14 @@ def test_leaky_spike_readout_basic_creation(self): def test_leaky_spike_readout_default_params(self): """Test LeakySpikeReadout with default parameters.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=256, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=256, out_size=10) assert readout.tau is not None assert readout.V_th is not None def test_leaky_spike_readout_custom_tau(self): """Test LeakySpikeReadout with custom tau value.""" tau_value = 10.0 * u.ms - readout = brainscale.nn.LeakySpikeReadout( + readout = braintrace.nn.LeakySpikeReadout( in_size=256, out_size=10, tau=tau_value @@ -274,7 +274,7 @@ def test_leaky_spike_readout_custom_tau(self): def test_leaky_spike_readout_custom_v_th(self): """Test LeakySpikeReadout with custom threshold voltage.""" V_th_value = 2.0 * u.mV - readout = brainscale.nn.LeakySpikeReadout( + readout = braintrace.nn.LeakySpikeReadout( in_size=256, out_size=10, V_th=V_th_value @@ -283,20 +283,20 @@ def test_leaky_spike_readout_custom_v_th(self): def test_leaky_spike_readout_init_state_with_batch(self): """Test LeakySpikeReadout state initialization with batch size.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=512, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=512, out_size=10) readout.init_state(batch_size=64) assert hasattr(readout, 'V') assert readout.V.value.shape == (64, 10) def test_leaky_spike_readout_init_state_without_batch(self): """Test LeakySpikeReadout state 
initialization without batch size.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=512, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=512, out_size=10) readout.init_state(batch_size=1) assert hasattr(readout, 'V') def test_leaky_spike_readout_reset_state(self): """Test LeakySpikeReadout state reset.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=512, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=512, out_size=10) readout.init_state(batch_size=64) # Modify state @@ -308,7 +308,7 @@ def test_leaky_spike_readout_reset_state(self): def test_leaky_spike_readout_forward_with_batch(self): """Test LeakySpikeReadout forward pass with batch dimension.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=512, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=512, out_size=10) readout.init_state(batch_size=64) spike_input = brainstate.random.randn(64, 512) > 0.5 @@ -317,7 +317,7 @@ def test_leaky_spike_readout_forward_with_batch(self): def test_leaky_spike_readout_forward_without_batch(self): """Test LeakySpikeReadout forward pass without explicit batch.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=512, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=512, out_size=10) readout.init_state(batch_size=1) spike_input = brainstate.random.randn(1, 512) > 0.5 @@ -326,7 +326,7 @@ def test_leaky_spike_readout_forward_without_batch(self): def test_leaky_spike_readout_sequential_updates(self): """Test LeakySpikeReadout with sequential spike updates.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=256, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=256, out_size=10) readout.init_state(batch_size=32) outputs = [] @@ -340,7 +340,7 @@ def test_leaky_spike_readout_sequential_updates(self): def test_leaky_spike_readout_spike_generation(self): """Test that LeakySpikeReadout generates spikes above threshold.""" - readout = brainscale.nn.LeakySpikeReadout( + readout = braintrace.nn.LeakySpikeReadout( in_size=128, out_size=10, tau=5.0 * u.ms, @@ -359,7 +359,7 @@ def test_leaky_spike_readout_spike_generation(self): def test_leaky_spike_readout_soft_reset(self): """Test LeakySpikeReadout with soft reset.""" - readout = brainscale.nn.LeakySpikeReadout( + readout = braintrace.nn.LeakySpikeReadout( in_size=128, out_size=10, spk_reset='soft' @@ -372,7 +372,7 @@ def test_leaky_spike_readout_soft_reset(self): def test_leaky_spike_readout_hard_reset(self): """Test LeakySpikeReadout with hard reset.""" - readout = brainscale.nn.LeakySpikeReadout( + readout = braintrace.nn.LeakySpikeReadout( in_size=128, out_size=10, spk_reset='hard' @@ -385,7 +385,7 @@ def test_leaky_spike_readout_hard_reset(self): def test_leaky_spike_readout_custom_w_init(self): """Test LeakySpikeReadout with custom weight initializer.""" - readout = brainscale.nn.LeakySpikeReadout( + readout = braintrace.nn.LeakySpikeReadout( in_size=256, out_size=10, w_init=init.Constant(0.5 * u.mV) @@ -398,7 +398,7 @@ def test_leaky_spike_readout_custom_w_init(self): def test_leaky_spike_readout_custom_v_init(self): """Test LeakySpikeReadout with custom voltage initializer.""" - readout = brainscale.nn.LeakySpikeReadout( + readout = braintrace.nn.LeakySpikeReadout( in_size=256, out_size=10, V_init=init.Constant(-0.5 * u.mV) @@ -410,7 +410,7 @@ def test_leaky_spike_readout_custom_v_init(self): def test_leaky_spike_readout_get_spike_property(self): """Test LeakySpikeReadout spike property.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=128, 
out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=128, out_size=10) readout.init_state(batch_size=16) print(readout.V) @@ -422,7 +422,7 @@ def test_leaky_spike_readout_get_spike_property(self): def test_leaky_spike_readout_get_spike_method(self): """Test LeakySpikeReadout get_spike method.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=128, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=128, out_size=10) readout.init_state(batch_size=16) V_test = jnp.ones((16, 10)) * 2.0 * u.mV @@ -431,7 +431,7 @@ def test_leaky_spike_readout_get_spike_method(self): def test_leaky_spike_readout_with_name(self): """Test LeakySpikeReadout with custom name.""" - readout = brainscale.nn.LeakySpikeReadout( + readout = braintrace.nn.LeakySpikeReadout( in_size=512, out_size=10, name="test_spike_readout" @@ -440,7 +440,7 @@ def test_leaky_spike_readout_with_name(self): def test_leaky_spike_readout_large_dimensions(self): """Test LeakySpikeReadout with large dimensions.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=2048, out_size=512) + readout = braintrace.nn.LeakySpikeReadout(in_size=2048, out_size=512) readout.init_state(batch_size=8) spike_input = brainstate.random.randn(8, 2048) > 0.5 @@ -449,7 +449,7 @@ def test_leaky_spike_readout_large_dimensions(self): def test_leaky_spike_readout_small_dimensions(self): """Test LeakySpikeReadout with small dimensions.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=8, out_size=4) + readout = braintrace.nn.LeakySpikeReadout(in_size=8, out_size=4) readout.init_state(batch_size=2) spike_input = brainstate.random.randn(2, 8) > 0.5 @@ -458,7 +458,7 @@ def test_leaky_spike_readout_small_dimensions(self): def test_leaky_spike_readout_different_batch_sizes(self): """Test LeakySpikeReadout with different batch sizes.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=256, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=256, out_size=10) for batch_size in [1, 8, 32, 64]: readout.init_state(batch_size=batch_size) @@ -468,7 +468,7 @@ def test_leaky_spike_readout_different_batch_sizes(self): def test_leaky_spike_readout_zero_input(self): """Test LeakySpikeReadout with zero input (no spikes).""" - readout = brainscale.nn.LeakySpikeReadout(in_size=128, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=128, out_size=10) readout.init_state(batch_size=16) # Feed zero input (no spikes) @@ -480,7 +480,7 @@ def test_leaky_spike_readout_zero_input(self): def test_leaky_spike_readout_membrane_dynamics(self): """Test that membrane potential evolves over time.""" - readout = brainscale.nn.LeakySpikeReadout( + readout = braintrace.nn.LeakySpikeReadout( in_size=64, out_size=8, tau=10.0 * u.ms @@ -499,11 +499,11 @@ def test_leaky_spike_readout_membrane_dynamics(self): def test_leaky_spike_readout_deterministic_with_seed(self): """Test that LeakySpikeReadout is deterministic with same random seed.""" brainstate.random.seed(42) - readout1 = brainscale.nn.LeakySpikeReadout(in_size=256, out_size=10) + readout1 = braintrace.nn.LeakySpikeReadout(in_size=256, out_size=10) readout1.init_state(batch_size=32) brainstate.random.seed(42) - readout2 = brainscale.nn.LeakySpikeReadout(in_size=256, out_size=10) + readout2 = braintrace.nn.LeakySpikeReadout(in_size=256, out_size=10) readout2.init_state(batch_size=32) brainstate.random.seed(123) @@ -516,7 +516,7 @@ def test_leaky_spike_readout_deterministic_with_seed(self): def test_leaky_spike_readout_varshape(self): """Test that LeakySpikeReadout has correct varshape.""" - 
readout = brainscale.nn.LeakySpikeReadout(in_size=128, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=128, out_size=10) assert hasattr(readout, 'varshape') assert readout.varshape == (10,) @@ -527,11 +527,11 @@ class TestReadoutIntegration: def test_rate_and_spike_readout_together(self): """Test using both rate and spike readout in sequence.""" # Rate readout - rate_readout = brainscale.nn.LeakyRateReadout(in_size=256, out_size=128) + rate_readout = braintrace.nn.LeakyRateReadout(in_size=256, out_size=128) rate_readout.init_state(batch_size=16) # Spike readout - spike_readout = brainscale.nn.LeakySpikeReadout(in_size=128, out_size=10) + spike_readout = braintrace.nn.LeakySpikeReadout(in_size=128, out_size=10) spike_readout.init_state(batch_size=16) # Process through both @@ -546,9 +546,9 @@ def test_rate_and_spike_readout_together(self): def test_multiple_rate_readouts_stacked(self): """Test stacking multiple rate readout layers.""" - readout1 = brainscale.nn.LeakyRateReadout(in_size=256, out_size=128) - readout2 = brainscale.nn.LeakyRateReadout(in_size=128, out_size=64) - readout3 = brainscale.nn.LeakyRateReadout(in_size=64, out_size=10) + readout1 = braintrace.nn.LeakyRateReadout(in_size=256, out_size=128) + readout2 = braintrace.nn.LeakyRateReadout(in_size=128, out_size=64) + readout3 = braintrace.nn.LeakyRateReadout(in_size=64, out_size=10) readout1.init_state(batch_size=16) readout2.init_state(batch_size=16) @@ -565,9 +565,9 @@ def test_multiple_rate_readouts_stacked(self): def test_multiple_spike_readouts_stacked(self): """Test stacking multiple spike readout layers.""" - readout1 = brainscale.nn.LeakySpikeReadout(in_size=256, out_size=128) - readout2 = brainscale.nn.LeakySpikeReadout(in_size=128, out_size=64) - readout3 = brainscale.nn.LeakySpikeReadout(in_size=64, out_size=10) + readout1 = braintrace.nn.LeakySpikeReadout(in_size=256, out_size=128) + readout2 = braintrace.nn.LeakySpikeReadout(in_size=128, out_size=64) + readout3 = braintrace.nn.LeakySpikeReadout(in_size=64, out_size=10) readout1.init_state(batch_size=16) readout2.init_state(batch_size=16) @@ -584,7 +584,7 @@ def test_multiple_spike_readouts_stacked(self): def test_rate_readout_temporal_dynamics(self): """Test temporal dynamics of rate readout over multiple steps.""" - readout = brainscale.nn.LeakyRateReadout( + readout = braintrace.nn.LeakyRateReadout( in_size=64, out_size=10, tau=10.0 * u.ms @@ -608,7 +608,7 @@ def test_rate_readout_temporal_dynamics(self): def test_spike_readout_temporal_dynamics(self): """Test temporal dynamics of spike readout over multiple steps.""" - readout = brainscale.nn.LeakySpikeReadout( + readout = braintrace.nn.LeakySpikeReadout( in_size=64, out_size=10, tau=10.0 * u.ms @@ -627,7 +627,7 @@ def test_spike_readout_temporal_dynamics(self): def test_readout_with_jit_compilation(self): """Test that readout layers work with JAX JIT compilation.""" - readout = brainscale.nn.LeakyRateReadout(in_size=128, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=128, out_size=10) readout.init_state(batch_size=16) @brainstate.transform.jit @@ -640,7 +640,7 @@ def forward(x): def test_rate_readout_gradient_flow(self): """Test that gradients flow through rate readout layer.""" - readout = brainscale.nn.LeakyRateReadout(in_size=64, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=64, out_size=10) readout.init_state(batch_size=8) def loss_fn(x): @@ -658,7 +658,7 @@ def loss_fn(x): def test_spike_readout_gradient_flow(self): """Test that gradients flow through 
spike readout layer (via surrogate).""" - readout = brainscale.nn.LeakySpikeReadout(in_size=64, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=64, out_size=10) readout.init_state(batch_size=8) def loss_fn(x): @@ -675,7 +675,7 @@ def loss_fn(x): def test_rate_readout_state_persistence(self): """Test that rate readout state persists across updates.""" - readout = brainscale.nn.LeakyRateReadout(in_size=32, out_size=8) + readout = braintrace.nn.LeakyRateReadout(in_size=32, out_size=8) readout.init_state(batch_size=4) # First update @@ -693,7 +693,7 @@ def test_rate_readout_state_persistence(self): def test_spike_readout_state_persistence(self): """Test that spike readout state persists across updates.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=32, out_size=8) + readout = braintrace.nn.LeakySpikeReadout(in_size=32, out_size=8) readout.init_state(batch_size=4) # First update @@ -710,8 +710,8 @@ def test_spike_readout_state_persistence(self): def test_readout_batch_consistency(self): """Test that readout layers handle different batch sizes consistently.""" - rate_readout = brainscale.nn.LeakyRateReadout(in_size=64, out_size=10) - spike_readout = brainscale.nn.LeakySpikeReadout(in_size=64, out_size=10) + rate_readout = braintrace.nn.LeakyRateReadout(in_size=64, out_size=10) + spike_readout = braintrace.nn.LeakySpikeReadout(in_size=64, out_size=10) for batch_size in [1, 4, 16, 64]: rate_readout.init_state(batch_size=batch_size) @@ -728,7 +728,7 @@ def test_readout_batch_consistency(self): def test_rate_readout_reset_vs_reinit(self): """Test that reset_state and init_state produce same initial state.""" - readout = brainscale.nn.LeakyRateReadout(in_size=64, out_size=10) + readout = braintrace.nn.LeakyRateReadout(in_size=64, out_size=10) # Initialize readout.init_state(batch_size=16) @@ -746,7 +746,7 @@ def test_rate_readout_reset_vs_reinit(self): def test_spike_readout_reset_vs_reinit(self): """Test that reset_state and init_state produce same initial state.""" - readout = brainscale.nn.LeakySpikeReadout(in_size=64, out_size=10) + readout = braintrace.nn.LeakySpikeReadout(in_size=64, out_size=10) # Initialize readout.init_state(batch_size=16) diff --git a/brainscale/nn/_rnn.py b/braintrace/nn/_rnn.py similarity index 96% rename from brainscale/nn/_rnn.py rename to braintrace/nn/_rnn.py index 7ffc7b1..0118049 100644 --- a/brainscale/nn/_rnn.py +++ b/braintrace/nn/_rnn.py @@ -20,11 +20,11 @@ import braintools import brainunit as u -from brainscale._etrace_concepts import ( +from braintrace._etrace_concepts import ( ETraceParam, ElemWiseParam, ) -from brainscale._typing import ArrayLike +from braintrace._typing import ArrayLike from ._linear import Linear __all__ = [ @@ -69,11 +69,11 @@ class ValinaRNNCell(brainstate.nn.RNNCell): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a Vanilla RNN cell - >>> rnn_cell = brainscale.nn.ValinaRNNCell(in_size=32, out_size=64) + >>> rnn_cell = braintrace.nn.ValinaRNNCell(in_size=32, out_size=64) >>> rnn_cell.init_state(batch_size=8) >>> >>> # Process a sequence of inputs @@ -82,7 +82,7 @@ class ValinaRNNCell(brainstate.nn.RNNCell): >>> print(h.shape) (8, 64) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -160,11 +160,11 @@ class GRUCell(brainstate.nn.RNNCell): -------- .. 
code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a GRU cell - >>> gru_cell = brainscale.nn.GRUCell(in_size=128, out_size=256) + >>> gru_cell = braintrace.nn.GRUCell(in_size=128, out_size=256) >>> gru_cell.init_state(batch_size=16) >>> >>> # Process a sequence of inputs @@ -173,7 +173,7 @@ class GRUCell(brainstate.nn.RNNCell): >>> print(h.shape) (16, 256) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -253,11 +253,11 @@ class CFNCell(brainstate.nn.RNNCell): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a CFN cell - >>> cfn_cell = brainscale.nn.CFNCell(in_size=64, out_size=128) + >>> cfn_cell = braintrace.nn.CFNCell(in_size=64, out_size=128) >>> cfn_cell.init_state(batch_size=10) >>> >>> # Process a sequence of inputs @@ -266,7 +266,7 @@ class CFNCell(brainstate.nn.RNNCell): >>> print(h.shape) (10, 128) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -360,11 +360,11 @@ class MGUCell(brainstate.nn.RNNCell): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create an MGU cell - >>> mgu_cell = brainscale.nn.MGUCell(in_size=96, out_size=192) + >>> mgu_cell = braintrace.nn.MGUCell(in_size=96, out_size=192) >>> mgu_cell.init_state(batch_size=12) >>> >>> # Process a sequence of inputs @@ -373,7 +373,7 @@ class MGUCell(brainstate.nn.RNNCell): >>> print(h.shape) (12, 192) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -481,11 +481,11 @@ class LSTMCell(brainstate.nn.RNNCell): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create an LSTM cell - >>> lstm_cell = brainscale.nn.LSTMCell(in_size=256, out_size=512) + >>> lstm_cell = braintrace.nn.LSTMCell(in_size=256, out_size=512) >>> lstm_cell.init_state(batch_size=20) >>> >>> # Process a sequence of inputs @@ -494,7 +494,7 @@ class LSTMCell(brainstate.nn.RNNCell): >>> print(h.shape) (20, 512) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -579,11 +579,11 @@ class URLSTMCell(brainstate.nn.RNNCell): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a URLSTM cell - >>> urlstm_cell = brainscale.nn.URLSTMCell(in_size=128, out_size=256) + >>> urlstm_cell = braintrace.nn.URLSTMCell(in_size=128, out_size=256) >>> urlstm_cell.init_state(batch_size=16) >>> >>> # Process a sequence of inputs @@ -592,7 +592,7 @@ class URLSTMCell(brainstate.nn.RNNCell): >>> print(h.shape) (16, 256) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -703,11 +703,11 @@ class MinimalRNNCell(brainstate.nn.RNNCell): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a Minimal RNN cell - >>> minrnn_cell = brainscale.nn.MinimalRNNCell(in_size=100, out_size=200) + >>> minrnn_cell = braintrace.nn.MinimalRNNCell(in_size=100, out_size=200) >>> minrnn_cell.init_state(batch_size=24) >>> >>> # Process a sequence of inputs @@ -716,7 +716,7 @@ class MinimalRNNCell(brainstate.nn.RNNCell): >>> print(h.shape) (24, 200) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -796,11 +796,11 @@ class MiniGRU(brainstate.nn.RNNCell): -------- .. 
code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a Mini GRU cell - >>> minigru_cell = brainscale.nn.MiniGRU(in_size=80, out_size=160) + >>> minigru_cell = braintrace.nn.MiniGRU(in_size=80, out_size=160) >>> minigru_cell.init_state(batch_size=32) >>> >>> # Process a sequence of inputs @@ -809,7 +809,7 @@ class MiniGRU(brainstate.nn.RNNCell): >>> print(h.shape) (32, 160) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -884,11 +884,11 @@ class MiniLSTM(brainstate.nn.RNNCell): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create a Mini LSTM cell - >>> minilstm_cell = brainscale.nn.MiniLSTM(in_size=150, out_size=300) + >>> minilstm_cell = braintrace.nn.MiniLSTM(in_size=150, out_size=300) >>> minilstm_cell.init_state(batch_size=40) >>> >>> # Process a sequence of inputs @@ -897,7 +897,7 @@ class MiniLSTM(brainstate.nn.RNNCell): >>> print(h.shape) (40, 300) """ - __module__ = 'brainscale.nn' + __module__ = 'braintrace.nn' def __init__( self, @@ -971,11 +971,11 @@ class LRUCell(brainstate.nn.Module): -------- .. code-block:: python - >>> import brainscale + >>> import braintrace >>> import brainstate >>> >>> # Create an LRU cell - >>> lru_cell = brainscale.nn.LRUCell(d_model=64, d_hidden=128) + >>> lru_cell = braintrace.nn.LRUCell(d_model=64, d_hidden=128) >>> lru_cell.init_state(batch_size=16) >>> >>> # Process a sequence of inputs diff --git a/brainscale/nn/_rnn_test.py b/braintrace/nn/_rnn_test.py similarity index 82% rename from brainscale/nn/_rnn_test.py rename to braintrace/nn/_rnn_test.py index 73b948b..24e8e01 100644 --- a/brainscale/nn/_rnn_test.py +++ b/braintrace/nn/_rnn_test.py @@ -35,7 +35,7 @@ u = brainunit jnp = pytest.importorskip("jax.numpy") init = braintools.init -brainscale = pytest.importorskip("brainscale") +braintrace = pytest.importorskip("braintrace") class TestValinaRNNCell: @@ -43,7 +43,7 @@ class TestValinaRNNCell: def test_valina_rnn_basic_creation(self): """Test basic ValinaRNNCell creation.""" - cell = brainscale.nn.ValinaRNNCell(in_size=32, out_size=64) + cell = braintrace.nn.ValinaRNNCell(in_size=32, out_size=64) assert hasattr(cell, 'in_size') assert hasattr(cell, 'out_size') assert hasattr(cell, 'W') @@ -51,14 +51,14 @@ def test_valina_rnn_basic_creation(self): def test_valina_rnn_init_state(self): """Test ValinaRNNCell state initialization.""" - cell = brainscale.nn.ValinaRNNCell(in_size=32, out_size=64) + cell = braintrace.nn.ValinaRNNCell(in_size=32, out_size=64) cell.init_state(batch_size=8) assert hasattr(cell, 'h') assert cell.h.value.shape == (8, 64) def test_valina_rnn_forward_pass(self): """Test ValinaRNNCell forward pass.""" - cell = brainscale.nn.ValinaRNNCell(in_size=32, out_size=64) + cell = braintrace.nn.ValinaRNNCell(in_size=32, out_size=64) cell.init_state(batch_size=8) x = brainstate.random.randn(8, 32) h = cell(x) @@ -66,7 +66,7 @@ def test_valina_rnn_forward_pass(self): def test_valina_rnn_sequential_updates(self): """Test ValinaRNNCell with sequential updates.""" - cell = brainscale.nn.ValinaRNNCell(in_size=32, out_size=64) + cell = braintrace.nn.ValinaRNNCell(in_size=32, out_size=64) cell.init_state(batch_size=8) outputs = [] @@ -80,7 +80,7 @@ def test_valina_rnn_sequential_updates(self): def test_valina_rnn_reset_state(self): """Test ValinaRNNCell state reset.""" - cell = brainscale.nn.ValinaRNNCell(in_size=32, out_size=64) + cell = braintrace.nn.ValinaRNNCell(in_size=32, 
out_size=64) cell.init_state(batch_size=8) # Modify state @@ -92,7 +92,7 @@ def test_valina_rnn_reset_state(self): def test_valina_rnn_custom_activation(self): """Test ValinaRNNCell with custom activation.""" - cell = brainscale.nn.ValinaRNNCell(in_size=32, out_size=64, activation='tanh') + cell = braintrace.nn.ValinaRNNCell(in_size=32, out_size=64, activation='tanh') cell.init_state(batch_size=8) x = brainstate.random.randn(8, 32) h = cell(x) @@ -100,7 +100,7 @@ def test_valina_rnn_custom_activation(self): def test_valina_rnn_callable_activation(self): """Test ValinaRNNCell with callable activation function.""" - cell = brainscale.nn.ValinaRNNCell( + cell = braintrace.nn.ValinaRNNCell( in_size=32, out_size=64, activation=brainstate.nn.relu @@ -112,7 +112,7 @@ def test_valina_rnn_callable_activation(self): def test_valina_rnn_custom_initializers(self): """Test ValinaRNNCell with custom initializers.""" - cell = brainscale.nn.ValinaRNNCell( + cell = braintrace.nn.ValinaRNNCell( in_size=32, out_size=64, state_init=init.Constant(1.0), @@ -124,7 +124,7 @@ def test_valina_rnn_custom_initializers(self): def test_valina_rnn_with_name(self): """Test ValinaRNNCell with custom name.""" - cell = brainscale.nn.ValinaRNNCell(in_size=32, out_size=64, name="test_rnn") + cell = braintrace.nn.ValinaRNNCell(in_size=32, out_size=64, name="test_rnn") assert cell.name == "test_rnn" @@ -133,7 +133,7 @@ class TestGRUCell: def test_gru_basic_creation(self): """Test basic GRUCell creation.""" - cell = brainscale.nn.GRUCell(in_size=128, out_size=256) + cell = braintrace.nn.GRUCell(in_size=128, out_size=256) assert hasattr(cell, 'in_size') assert hasattr(cell, 'out_size') assert hasattr(cell, 'Wz') @@ -142,14 +142,14 @@ def test_gru_basic_creation(self): def test_gru_init_state(self): """Test GRUCell state initialization.""" - cell = brainscale.nn.GRUCell(in_size=128, out_size=256) + cell = braintrace.nn.GRUCell(in_size=128, out_size=256) cell.init_state(batch_size=16) assert hasattr(cell, 'h') assert cell.h.value.shape == (16, 256) def test_gru_forward_pass(self): """Test GRUCell forward pass.""" - cell = brainscale.nn.GRUCell(in_size=128, out_size=256) + cell = braintrace.nn.GRUCell(in_size=128, out_size=256) cell.init_state(batch_size=16) x = brainstate.random.randn(16, 128) h = cell(x) @@ -157,7 +157,7 @@ def test_gru_forward_pass(self): def test_gru_sequential_updates(self): """Test GRUCell with sequential updates.""" - cell = brainscale.nn.GRUCell(in_size=64, out_size=128) + cell = braintrace.nn.GRUCell(in_size=64, out_size=128) cell.init_state(batch_size=8) outputs = [] @@ -171,7 +171,7 @@ def test_gru_sequential_updates(self): def test_gru_reset_state(self): """Test GRUCell state reset.""" - cell = brainscale.nn.GRUCell(in_size=64, out_size=128) + cell = braintrace.nn.GRUCell(in_size=64, out_size=128) cell.init_state(batch_size=8) cell.h.value = jnp.ones_like(cell.h.value) @@ -180,7 +180,7 @@ def test_gru_reset_state(self): def test_gru_custom_activation(self): """Test GRUCell with custom activation.""" - cell = brainscale.nn.GRUCell(in_size=64, out_size=128, activation='relu') + cell = braintrace.nn.GRUCell(in_size=64, out_size=128, activation='relu') cell.init_state(batch_size=8) x = brainstate.random.randn(8, 64) h = cell(x) @@ -188,7 +188,7 @@ def test_gru_custom_activation(self): def test_gru_with_name(self): """Test GRUCell with custom name.""" - cell = brainscale.nn.GRUCell(in_size=64, out_size=128, name="test_gru") + cell = braintrace.nn.GRUCell(in_size=64, out_size=128, name="test_gru") assert 
cell.name == "test_gru" @@ -197,7 +197,7 @@ class TestMGUCell: def test_mgu_basic_creation(self): """Test basic MGUCell creation.""" - cell = brainscale.nn.MGUCell(in_size=96, out_size=192) + cell = braintrace.nn.MGUCell(in_size=96, out_size=192) assert hasattr(cell, 'in_size') assert hasattr(cell, 'out_size') assert hasattr(cell, 'Wf') @@ -205,14 +205,14 @@ def test_mgu_basic_creation(self): def test_mgu_init_state(self): """Test MGUCell state initialization.""" - cell = brainscale.nn.MGUCell(in_size=96, out_size=192) + cell = braintrace.nn.MGUCell(in_size=96, out_size=192) cell.init_state(batch_size=12) assert hasattr(cell, 'h') assert cell.h.value.shape == (12, 192) def test_mgu_forward_pass(self): """Test MGUCell forward pass.""" - cell = brainscale.nn.MGUCell(in_size=96, out_size=192) + cell = braintrace.nn.MGUCell(in_size=96, out_size=192) cell.init_state(batch_size=12) x = brainstate.random.randn(12, 96) h = cell(x) @@ -220,7 +220,7 @@ def test_mgu_forward_pass(self): def test_mgu_sequential_updates(self): """Test MGUCell with sequential updates.""" - cell = brainscale.nn.MGUCell(in_size=48, out_size=96) + cell = braintrace.nn.MGUCell(in_size=48, out_size=96) cell.init_state(batch_size=6) outputs = [] @@ -234,7 +234,7 @@ def test_mgu_sequential_updates(self): def test_mgu_reset_state(self): """Test MGUCell state reset.""" - cell = brainscale.nn.MGUCell(in_size=48, out_size=96) + cell = braintrace.nn.MGUCell(in_size=48, out_size=96) cell.init_state(batch_size=6) cell.h.value = jnp.ones_like(cell.h.value) @@ -247,7 +247,7 @@ class TestLSTMCell: def test_lstm_basic_creation(self): """Test basic LSTMCell creation.""" - cell = brainscale.nn.LSTMCell(in_size=256, out_size=512) + cell = braintrace.nn.LSTMCell(in_size=256, out_size=512) assert hasattr(cell, 'in_size') assert hasattr(cell, 'out_size') assert hasattr(cell, 'Wi') @@ -257,7 +257,7 @@ def test_lstm_basic_creation(self): def test_lstm_init_state(self): """Test LSTMCell state initialization.""" - cell = brainscale.nn.LSTMCell(in_size=256, out_size=512) + cell = braintrace.nn.LSTMCell(in_size=256, out_size=512) cell.init_state(batch_size=20) assert hasattr(cell, 'h') assert hasattr(cell, 'c') @@ -266,7 +266,7 @@ def test_lstm_init_state(self): def test_lstm_forward_pass(self): """Test LSTMCell forward pass.""" - cell = brainscale.nn.LSTMCell(in_size=256, out_size=512) + cell = braintrace.nn.LSTMCell(in_size=256, out_size=512) cell.init_state(batch_size=20) x = brainstate.random.randn(20, 256) h = cell(x) @@ -274,7 +274,7 @@ def test_lstm_forward_pass(self): def test_lstm_sequential_updates(self): """Test LSTMCell with sequential updates.""" - cell = brainscale.nn.LSTMCell(in_size=128, out_size=256) + cell = braintrace.nn.LSTMCell(in_size=128, out_size=256) cell.init_state(batch_size=10) outputs = [] @@ -288,7 +288,7 @@ def test_lstm_sequential_updates(self): def test_lstm_reset_state(self): """Test LSTMCell state reset.""" - cell = brainscale.nn.LSTMCell(in_size=128, out_size=256) + cell = braintrace.nn.LSTMCell(in_size=128, out_size=256) cell.init_state(batch_size=10) cell.h.value = jnp.ones_like(cell.h.value) @@ -300,7 +300,7 @@ def test_lstm_reset_state(self): def test_lstm_custom_activation(self): """Test LSTMCell with custom activation.""" - cell = brainscale.nn.LSTMCell(in_size=64, out_size=128, activation='relu') + cell = braintrace.nn.LSTMCell(in_size=64, out_size=128, activation='relu') cell.init_state(batch_size=8) x = brainstate.random.randn(8, 64) h = cell(x) @@ -308,7 +308,7 @@ def test_lstm_custom_activation(self): 
def test_lstm_with_name(self): """Test LSTMCell with custom name.""" - cell = brainscale.nn.LSTMCell(in_size=64, out_size=128, name="test_lstm") + cell = braintrace.nn.LSTMCell(in_size=64, out_size=128, name="test_lstm") assert cell.name == "test_lstm" @@ -317,7 +317,7 @@ class TestMinimalRNNCell: def test_minimal_rnn_basic_creation(self): """Test basic MinimalRNNCell creation.""" - cell = brainscale.nn.MinimalRNNCell(in_size=100, out_size=200) + cell = braintrace.nn.MinimalRNNCell(in_size=100, out_size=200) assert hasattr(cell, 'in_size') assert hasattr(cell, 'out_size') assert hasattr(cell, 'phi') @@ -325,14 +325,14 @@ def test_minimal_rnn_basic_creation(self): def test_minimal_rnn_init_state(self): """Test MinimalRNNCell state initialization.""" - cell = brainscale.nn.MinimalRNNCell(in_size=100, out_size=200) + cell = braintrace.nn.MinimalRNNCell(in_size=100, out_size=200) cell.init_state(batch_size=24) assert hasattr(cell, 'h') assert cell.h.value.shape == (24, 200) def test_minimal_rnn_forward_pass(self): """Test MinimalRNNCell forward pass.""" - cell = brainscale.nn.MinimalRNNCell(in_size=100, out_size=200) + cell = braintrace.nn.MinimalRNNCell(in_size=100, out_size=200) cell.init_state(batch_size=24) x = brainstate.random.randn(24, 100) h = cell(x) @@ -340,7 +340,7 @@ def test_minimal_rnn_forward_pass(self): def test_minimal_rnn_sequential_updates(self): """Test MinimalRNNCell with sequential updates.""" - cell = brainscale.nn.MinimalRNNCell(in_size=50, out_size=100) + cell = braintrace.nn.MinimalRNNCell(in_size=50, out_size=100) cell.init_state(batch_size=12) outputs = [] @@ -354,7 +354,7 @@ def test_minimal_rnn_sequential_updates(self): def test_minimal_rnn_reset_state(self): """Test MinimalRNNCell state reset.""" - cell = brainscale.nn.MinimalRNNCell(in_size=50, out_size=100) + cell = braintrace.nn.MinimalRNNCell(in_size=50, out_size=100) cell.init_state(batch_size=12) cell.h.value = jnp.ones_like(cell.h.value) @@ -367,7 +367,7 @@ class TestMiniGRU: def test_minigru_basic_creation(self): """Test basic MiniGRU creation.""" - cell = brainscale.nn.MiniGRU(in_size=80, out_size=160) + cell = braintrace.nn.MiniGRU(in_size=80, out_size=160) assert hasattr(cell, 'in_size') assert hasattr(cell, 'out_size') assert hasattr(cell, 'W_x') @@ -375,14 +375,14 @@ def test_minigru_basic_creation(self): def test_minigru_init_state(self): """Test MiniGRU state initialization.""" - cell = brainscale.nn.MiniGRU(in_size=80, out_size=160) + cell = braintrace.nn.MiniGRU(in_size=80, out_size=160) cell.init_state(batch_size=32) assert hasattr(cell, 'h') assert cell.h.value.shape == (32, 160) def test_minigru_forward_pass(self): """Test MiniGRU forward pass.""" - cell = brainscale.nn.MiniGRU(in_size=80, out_size=160) + cell = braintrace.nn.MiniGRU(in_size=80, out_size=160) cell.init_state(batch_size=32) x = brainstate.random.randn(32, 80) h = cell(x) @@ -390,7 +390,7 @@ def test_minigru_forward_pass(self): def test_minigru_sequential_updates(self): """Test MiniGRU with sequential updates.""" - cell = brainscale.nn.MiniGRU(in_size=40, out_size=80) + cell = braintrace.nn.MiniGRU(in_size=40, out_size=80) cell.init_state(batch_size=16) outputs = [] @@ -404,7 +404,7 @@ def test_minigru_sequential_updates(self): def test_minigru_reset_state(self): """Test MiniGRU state reset.""" - cell = brainscale.nn.MiniGRU(in_size=40, out_size=80) + cell = braintrace.nn.MiniGRU(in_size=40, out_size=80) cell.init_state(batch_size=16) cell.h.value = jnp.ones_like(cell.h.value) @@ -417,7 +417,7 @@ class TestMiniLSTM: def 
test_minilstm_basic_creation(self): """Test basic MiniLSTM creation.""" - cell = brainscale.nn.MiniLSTM(in_size=150, out_size=300) + cell = braintrace.nn.MiniLSTM(in_size=150, out_size=300) assert hasattr(cell, 'in_size') assert hasattr(cell, 'out_size') assert hasattr(cell, 'W_x') @@ -426,14 +426,14 @@ def test_minilstm_basic_creation(self): def test_minilstm_init_state(self): """Test MiniLSTM state initialization.""" - cell = brainscale.nn.MiniLSTM(in_size=150, out_size=300) + cell = braintrace.nn.MiniLSTM(in_size=150, out_size=300) cell.init_state(batch_size=40) assert hasattr(cell, 'h') assert cell.h.value.shape == (40, 300) def test_minilstm_forward_pass(self): """Test MiniLSTM forward pass.""" - cell = brainscale.nn.MiniLSTM(in_size=150, out_size=300) + cell = braintrace.nn.MiniLSTM(in_size=150, out_size=300) cell.init_state(batch_size=40) x = brainstate.random.randn(40, 150) h = cell(x) @@ -441,7 +441,7 @@ def test_minilstm_forward_pass(self): def test_minilstm_sequential_updates(self): """Test MiniLSTM with sequential updates.""" - cell = brainscale.nn.MiniLSTM(in_size=75, out_size=150) + cell = braintrace.nn.MiniLSTM(in_size=75, out_size=150) cell.init_state(batch_size=20) outputs = [] @@ -455,7 +455,7 @@ def test_minilstm_sequential_updates(self): def test_minilstm_reset_state(self): """Test MiniLSTM state reset.""" - cell = brainscale.nn.MiniLSTM(in_size=75, out_size=150) + cell = braintrace.nn.MiniLSTM(in_size=75, out_size=150) cell.init_state(batch_size=20) cell.h.value = jnp.ones_like(cell.h.value) @@ -468,7 +468,7 @@ class TestLRUCell: def test_lru_basic_creation(self): """Test basic LRUCell creation.""" - cell = brainscale.nn.LRUCell(d_model=64, d_hidden=128) + cell = braintrace.nn.LRUCell(d_model=64, d_hidden=128) assert hasattr(cell, 'd_model') assert hasattr(cell, 'd_hidden') assert hasattr(cell, 'B_re') @@ -478,7 +478,7 @@ def test_lru_basic_creation(self): def test_lru_init_state(self): """Test LRUCell state initialization.""" - cell = brainscale.nn.LRUCell(d_model=64, d_hidden=128) + cell = braintrace.nn.LRUCell(d_model=64, d_hidden=128) cell.init_state(batch_size=16) assert hasattr(cell, 'h_re') assert hasattr(cell, 'h_im') @@ -487,7 +487,7 @@ def test_lru_init_state(self): def test_lru_forward_pass(self): """Test LRUCell forward pass.""" - cell = brainscale.nn.LRUCell(d_model=64, d_hidden=128) + cell = braintrace.nn.LRUCell(d_model=64, d_hidden=128) cell.init_state(batch_size=16) x = brainstate.random.randn(16, 64) y = cell(x) @@ -495,7 +495,7 @@ def test_lru_forward_pass(self): def test_lru_sequential_updates(self): """Test LRUCell with sequential updates.""" - cell = brainscale.nn.LRUCell(d_model=32, d_hidden=64) + cell = braintrace.nn.LRUCell(d_model=32, d_hidden=64) cell.init_state(batch_size=8) outputs = [] @@ -509,7 +509,7 @@ def test_lru_sequential_updates(self): def test_lru_reset_state(self): """Test LRUCell state reset.""" - cell = brainscale.nn.LRUCell(d_model=32, d_hidden=64) + cell = braintrace.nn.LRUCell(d_model=32, d_hidden=64) cell.init_state(batch_size=8) cell.h_re.value = jnp.ones_like(cell.h_re.value) @@ -521,7 +521,7 @@ def test_lru_reset_state(self): def test_lru_custom_parameters(self): """Test LRUCell with custom parameters.""" - cell = brainscale.nn.LRUCell( + cell = braintrace.nn.LRUCell( d_model=64, d_hidden=128, r_min=0.1, @@ -538,7 +538,7 @@ class TestRNNCellIntegration: def test_different_batch_sizes(self): """Test RNN cells with different batch sizes.""" - cell = brainscale.nn.GRUCell(in_size=32, out_size=64) + cell = 
braintrace.nn.GRUCell(in_size=32, out_size=64) for batch_size in [1, 4, 16, 32]: cell.init_state(batch_size=batch_size) @@ -548,7 +548,7 @@ def test_different_batch_sizes(self): def test_sequence_processing(self): """Test processing a full sequence through RNN cell.""" - cell = brainscale.nn.LSTMCell(in_size=64, out_size=128) + cell = braintrace.nn.LSTMCell(in_size=64, out_size=128) cell.init_state(batch_size=8) sequence_length = 50 @@ -564,7 +564,7 @@ def test_sequence_processing(self): def test_state_persistence(self): """Test that state persists across updates.""" - cell = brainscale.nn.ValinaRNNCell(in_size=32, out_size=64) + cell = braintrace.nn.ValinaRNNCell(in_size=32, out_size=64) cell.init_state(batch_size=4) # First update @@ -581,7 +581,7 @@ def test_state_persistence(self): def test_gradient_flow(self): """Test that gradients flow through RNN cell.""" - cell = brainscale.nn.GRUCell(in_size=32, out_size=64) + cell = braintrace.nn.GRUCell(in_size=32, out_size=64) cell.init_state(batch_size=4) def loss_fn(x): @@ -597,7 +597,7 @@ def loss_fn(x): def test_jit_compilation(self): """Test RNN cell with JIT compilation.""" - cell = brainscale.nn.ValinaRNNCell(in_size=32, out_size=64) + cell = braintrace.nn.ValinaRNNCell(in_size=32, out_size=64) cell.init_state(batch_size=8) @brainstate.transform.jit @@ -611,11 +611,11 @@ def forward(x): def test_deterministic_with_seed(self): """Test RNN cell determinism with same seed.""" brainstate.random.seed(42) - cell1 = brainscale.nn.GRUCell(in_size=32, out_size=64) + cell1 = braintrace.nn.GRUCell(in_size=32, out_size=64) cell1.init_state(batch_size=8) brainstate.random.seed(42) - cell2 = brainscale.nn.GRUCell(in_size=32, out_size=64) + cell2 = braintrace.nn.GRUCell(in_size=32, out_size=64) cell2.init_state(batch_size=8) brainstate.random.seed(123) @@ -628,8 +628,8 @@ def test_deterministic_with_seed(self): def test_mixed_cell_types(self): """Test using different cell types in sequence.""" - gru_cell = brainscale.nn.GRUCell(in_size=64, out_size=128) - lstm_cell = brainscale.nn.LSTMCell(in_size=128, out_size=256) + gru_cell = braintrace.nn.GRUCell(in_size=64, out_size=128) + lstm_cell = braintrace.nn.LSTMCell(in_size=128, out_size=256) gru_cell.init_state(batch_size=8) lstm_cell.init_state(batch_size=8) @@ -643,7 +643,7 @@ def test_mixed_cell_types(self): def test_large_dimensions(self): """Test RNN cells with large dimensions.""" - cell = brainscale.nn.LSTMCell(in_size=1024, out_size=2048) + cell = braintrace.nn.LSTMCell(in_size=1024, out_size=2048) cell.init_state(batch_size=4) x = brainstate.random.randn(4, 1024) @@ -652,7 +652,7 @@ def test_large_dimensions(self): def test_small_dimensions(self): """Test RNN cells with small dimensions.""" - cell = brainscale.nn.GRUCell(in_size=4, out_size=8) + cell = braintrace.nn.GRUCell(in_size=4, out_size=8) cell.init_state(batch_size=2) x = brainstate.random.randn(2, 4) @@ -661,7 +661,7 @@ def test_small_dimensions(self): def test_zero_input(self): """Test RNN cell with zero input.""" - cell = brainscale.nn.ValinaRNNCell(in_size=32, out_size=64) + cell = braintrace.nn.ValinaRNNCell(in_size=32, out_size=64) cell.init_state(batch_size=8) # Initialize with non-zero state @@ -676,7 +676,7 @@ def test_zero_input(self): def test_reset_during_sequence(self): """Test resetting state during sequence processing.""" - cell = brainscale.nn.GRUCell(in_size=32, out_size=64) + cell = braintrace.nn.GRUCell(in_size=32, out_size=64) cell.init_state(batch_size=8) # Process first part of sequence @@ -696,10 +696,10 @@ def 
test_reset_during_sequence(self): def test_batch_size_consistency(self): """Test consistency across different batch sizes.""" cells = [ - brainscale.nn.ValinaRNNCell(in_size=32, out_size=64), - brainscale.nn.GRUCell(in_size=32, out_size=64), - brainscale.nn.MGUCell(in_size=32, out_size=64), - brainscale.nn.LSTMCell(in_size=32, out_size=64), + braintrace.nn.ValinaRNNCell(in_size=32, out_size=64), + braintrace.nn.GRUCell(in_size=32, out_size=64), + braintrace.nn.MGUCell(in_size=32, out_size=64), + braintrace.nn.LSTMCell(in_size=32, out_size=64), ] for batch_size in [1, 4, 8, 16]: diff --git a/docs/advanced/IR_analysis-en.ipynb b/docs/advanced/IR_analysis-en.ipynb index 7de5f03..fb084ef 100644 --- a/docs/advanced/IR_analysis-en.ipynb +++ b/docs/advanced/IR_analysis-en.ipynb @@ -6,9 +6,9 @@ "source": [ "# IR Analysis and Optimization\n", "\n", - "The core capability of the brainscale framework lies in converting user-defined neural network models into efficient Intermediate Representation (IR) and performing deep analysis and optimization based on this representation. By extracting information flow and dependency relationships between key components of the model, brainscale can generate efficient online learning training code, enabling fast training and inference of neural networks.\n", + "The core capability of the braintrace framework lies in converting user-defined neural network models into efficient Intermediate Representation (IR) and performing deep analysis and optimization based on this representation. By extracting information flow and dependency relationships between key components of the model, braintrace can generate efficient online learning training code, enabling fast training and inference of neural networks.\n", "\n", - "This guide provides a comprehensive introduction to the IR analysis and optimization process in brainscale, including:\n", + "This guide provides a comprehensive introduction to the IR analysis and optimization process in braintrace, including:\n", "- **Model Information Extraction**: Obtaining complete structural information of the model\n", "- **State Group Analysis**: Identifying sets of interdependent state variables\n", "- **Parameter-State Relationships**: Analyzing connection relationships between parameters and hidden states\n", @@ -29,7 +29,7 @@ }, "cell_type": "code", "source": [ - "import brainscale\n", + "import braintrace\n", "import brainstate\n", "import brainunit as u\n", "\n", @@ -44,11 +44,11 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "We use the ALIF model + STP synapse model from brainscale for testing as an example. This model is defined in `brainscale._etrace_model_test`.\n", + "We use the ALIF model + STP synapse model from braintrace for testing as an example. 
This model is defined in `braintrace._etrace_model_test`.\n", "\n", "### Example Model\n", "\n", - "We use brainscale's built-in ALIF (Adaptive Leaky Integrate-and-Fire) neuron model combined with STP (Short-Term Plasticity) synapse model as a demonstration case:" + "We use braintrace's built-in ALIF (Adaptive Leaky Integrate-and-Fire) neuron model combined with STP (Short-Term Plasticity) synapse model as a demonstration case:" ], "id": "e7dff82f20bcfd18" }, @@ -61,7 +61,7 @@ }, "cell_type": "code", "source": [ - "from brainscale._etrace_model_test import ALIF_STPExpCu_Dense_Layer\n", + "from braintrace._etrace_model_test import ALIF_STPExpCu_Dense_Layer\n", "\n", "# Create network instance\n", "n_in = 3 # Input dimension\n", @@ -85,7 +85,7 @@ "\n", "### 1.1 What is ModuleInfo\n", "\n", - "`ModuleInfo` is brainscale's complete description of neural network models, containing all key information of the model:\n", + "`ModuleInfo` is braintrace's complete description of neural network models, containing all key information of the model:\n", "- **Input/Output Interface**: Data flow entry and exit points of the model\n", "- **State Variables**: Dynamic variables such as neuron states and synaptic states\n", "- **Parameter Variables**: Trainable parameters such as weights and biases\n", @@ -105,7 +105,7 @@ "cell_type": "code", "source": [ "# Extract complete model information\n", - "info = brainscale.extract_module_info(net, input_data)\n", + "info = braintrace.extract_module_info(net, input_data)\n", "print(\"Model information extraction completed\")\n", "print(f\"Number of hidden states: {len(info.hidden_path_to_invar)}\")\n", "print(f\"Number of compiled model states: {len(info.compiled_model_states)}\")" @@ -213,7 +213,7 @@ "\n", "### 2.1 Concept of State Groups\n", "\n", - "State Groups (HiddenGroup) are an important concept in brainscale that organize state variables in the model with the following characteristics:\n", + "State Groups (HiddenGroup) are an important concept in braintrace that organize state variables in the model with the following characteristics:\n", "- **Interdependence**: Direct or indirect dependency relationships exist between state variables\n", "- **Element-wise Operations**: Operations between variables are primarily element-wise mathematical operations\n", "- **Synchronized Updates**: These states need coordinated updates within time steps\n", @@ -232,7 +232,7 @@ "cell_type": "code", "source": [ "# Extract state groups from ModuleInfo\n", - "hidden_groups, hid_path_to_group = brainscale.find_hidden_groups_from_minfo(info)\n", + "hidden_groups, hid_path_to_group = braintrace.find_hidden_groups_from_minfo(info)\n", "\n", "print(f\"Number of state groups discovered: {len(hidden_groups)}\")\n", "print(f\"Number of state path to group mappings: {len(hid_path_to_group)}\")" @@ -337,7 +337,7 @@ "cell_type": "code", "source": [ "# Extract parameter-state relationships\n", - "hidden_param_op = brainscale.find_hidden_param_op_relations_from_minfo(info, hid_path_to_group)\n", + "hidden_param_op = braintrace.find_hidden_param_op_relations_from_minfo(info, hid_path_to_group)\n", "\n", "print(f\"Number of parameter-state relationships discovered: {len(hidden_param_op)}\")" ], @@ -408,7 +408,7 @@ "source": [ "### 3.4 Gradient Computation Optimization\n", "\n", - "Analysis of parameter-state relationships enables brainscale to:\n", + "Analysis of parameter-state relationships enables braintrace to:\n", "- **Precise Tracking**: Determine which states are affected by parameter 
updates\n", "- **Efficient Computation**: Only compute necessary gradient paths\n", "- **Memory Savings**: Avoid storing unnecessary intermediate gradients\n", @@ -417,7 +417,7 @@ "\n", "### 4.1 Perturbation Mechanism Principles\n", "\n", - "State perturbation is the core technology of brainscale for implementing efficient gradient computation. By applying small perturbations to hidden states $y = f(x, h+\\Delta)$, where $\\Delta \\to 0$, the system can:\n", + "State perturbation is the core technology of braintrace for implementing efficient gradient computation. By applying small perturbations to hidden states $y = f(x, h+\\Delta)$, where $\\Delta \\to 0$, the system can:\n", "- **Numerical Gradients**: Compute gradients through finite differences\n", "- **Automatic Differentiation**: Build efficient backpropagation graphs\n", "\n", @@ -435,7 +435,7 @@ "cell_type": "code", "source": [ "# Extract state perturbation information\n", - "hidden_perturb = brainscale.add_hidden_perturbation_from_minfo(info)\n", + "hidden_perturb = braintrace.add_hidden_perturbation_from_minfo(info)\n", "\n", "print(\"=== State Perturbation Information ===\")\n", "print(f\"Number of perturbation variables: {len(hidden_perturb.perturb_vars)}\")\n", @@ -530,14 +530,14 @@ "\n", "## Summary\n", "\n", - "This guide provides a comprehensive introduction to the core concepts and practical methods of IR analysis and optimization in the brainscale framework. Through deep understanding of model information extraction, state group analysis, parameter relationship analysis, and state perturbation mechanisms, developers can:\n", + "This guide provides a comprehensive introduction to the core concepts and practical methods of IR analysis and optimization in the braintrace framework. Through deep understanding of model information extraction, state group analysis, parameter relationship analysis, and state perturbation mechanisms, developers can:\n", "\n", "- **Improve Performance**: Significantly enhance model training and inference efficiency through IR optimization\n", "- **Enhance Understanding**: Gain deep insights into the internal structure and computational flow of neural network models\n", "- **Optimize Design**: Design more efficient model architectures based on analysis results\n", "- **Debugging Capabilities**: Quickly locate and solve problems in model training\n", "\n", - "The IR analysis capabilities of brainscale provide powerful tool support for efficient implementation of neural networks and serve as an important foundation for achieving high-performance online learning. Users can now customize IR analysis and optimization processes to meet specific application requirements." + "The IR analysis capabilities of braintrace provide powerful tool support for efficient implementation of neural networks and serve as an important foundation for achieving high-performance online learning. Users can now customize IR analysis and optimization processes to meet specific application requirements." 
], "id": "2a75e72100b8fa89" } diff --git a/docs/advanced/IR_analysis-zh.ipynb b/docs/advanced/IR_analysis-zh.ipynb index e39aea0..5503a21 100644 --- a/docs/advanced/IR_analysis-zh.ipynb +++ b/docs/advanced/IR_analysis-zh.ipynb @@ -6,9 +6,9 @@ "source": [ "# IR分析和优化\n", "\n", - "brainscale框架的核心能力在于将用户定义的神经网络模型转换为高效的中间表示(Intermediate Representation, IR),并在此基础上进行深度分析和优化。通过提取模型关键组成部分之间的信息流和依赖关系,brainscale能够生成高效的在线学习训练代码,实现神经网络的快速训练和推理。\n", + "braintrace框架的核心能力在于将用户定义的神经网络模型转换为高效的中间表示(Intermediate Representation, IR),并在此基础上进行深度分析和优化。通过提取模型关键组成部分之间的信息流和依赖关系,braintrace能够生成高效的在线学习训练代码,实现神经网络的快速训练和推理。\n", "\n", - "本指南将全面介绍brainscale中的IR分析和优化流程,包括:\n", + "本指南将全面介绍braintrace中的IR分析和优化流程,包括:\n", "- **模型信息提取**:获取模型的完整结构信息\n", "- **状态群分析**:识别相互依赖的状态变量集合\n", "- **参数-状态关系**:分析参数与隐藏状态的连接关系\n", @@ -29,7 +29,7 @@ }, "cell_type": "code", "source": [ - "import brainscale\n", + "import braintrace\n", "import brainstate\n", "import brainunit as u\n", "from pprint import pprint\n", @@ -44,7 +44,7 @@ { "metadata": {}, "cell_type": "markdown", - "source": "我们使用brainscale中用于测试的ALIF模型+STP突触模型作为示例。该模型的定义在`brainscale._etrace_model_test`中。", + "source": "我们使用braintrace中用于测试的ALIF模型+STP突触模型作为示例。该模型的定义在`braintrace._etrace_model_test`中。", "id": "608fb5690fc09df1" }, { @@ -53,7 +53,7 @@ "source": [ "### 示例模型\n", "\n", - "我们使用brainscale内置的ALIF(Adaptive Leaky Integrate-and-Fire)神经元模型结合STP(Short-Term Plasticity)突触模型作为演示案例:" + "我们使用braintrace内置的ALIF(Adaptive Leaky Integrate-and-Fire)神经元模型结合STP(Short-Term Plasticity)突触模型作为演示案例:" ], "id": "df5ff0c796a4de76" }, @@ -66,7 +66,7 @@ }, "cell_type": "code", "source": [ - "from brainscale._etrace_model_test import ALIF_STPExpCu_Dense_Layer\n", + "from braintrace._etrace_model_test import ALIF_STPExpCu_Dense_Layer\n", "\n", "# 创建网络实例\n", "n_in = 3 # 输入维度\n", @@ -90,7 +90,7 @@ "\n", "### 1.1 什么是ModuleInfo\n", "\n", - "`ModuleInfo`是brainscale对神经网络模型的完整描述,包含了模型的所有关键信息:\n", + "`ModuleInfo`是braintrace对神经网络模型的完整描述,包含了模型的所有关键信息:\n", "- **输入输出接口**:模型的数据流入和流出点\n", "- **状态变量**:神经元状态、突触状态等动态变量\n", "- **参数变量**:权重、偏置等可训练参数\n", @@ -110,7 +110,7 @@ "cell_type": "code", "source": [ "# 提取模型的完整信息\n", - "info = brainscale.extract_module_info(net, input_data)\n", + "info = braintrace.extract_module_info(net, input_data)\n", "print(\"模型信息提取完成\")\n", "print(f\"隐藏状态数量: {len(info.hidden_path_to_invar)}\")\n", "print(f\"编译后状态数量: {len(info.compiled_model_states)}\")" @@ -218,7 +218,7 @@ "\n", "### 2.1 状态群的概念\n", "\n", - "状态群(HiddenGroup)是brainscale中的重要概念,它将模型中具有以下特征的状态变量组织在一起:\n", + "状态群(HiddenGroup)是braintrace中的重要概念,它将模型中具有以下特征的状态变量组织在一起:\n", "- **相互依赖**:状态变量之间存在直接或间接的依赖关系\n", "- **逐元素运算**:变量间的操作主要是逐元素的数学运算\n", "- **同步更新**:这些状态在时间步内需要协调更新\n", @@ -237,7 +237,7 @@ "cell_type": "code", "source": [ "# 从ModuleInfo中提取状态群\n", - "hidden_groups, hid_path_to_group = brainscale.find_hidden_groups_from_minfo(info)\n", + "hidden_groups, hid_path_to_group = braintrace.find_hidden_groups_from_minfo(info)\n", "\n", "print(f\"发现状态群数量: {len(hidden_groups)}\")\n", "print(f\"状态路径到群组的映射数量: {len(hid_path_to_group)}\")" @@ -348,7 +348,7 @@ "cell_type": "code", "source": [ "# 提取参数-状态关系\n", - "hidden_param_op = brainscale.find_hidden_param_op_relations_from_minfo(info, hid_path_to_group)\n", + "hidden_param_op = braintrace.find_hidden_param_op_relations_from_minfo(info, hid_path_to_group)\n", "\n", "print(f\"发现参数-状态关系数量: {len(hidden_param_op)}\")" ], @@ -419,7 +419,7 @@ "source": [ "### 3.4 梯度计算优化\n", "\n", - "参数-状态关系的分析使得brainscale能够:\n", + "参数-状态关系的分析使得braintrace能够:\n", "- **精确追踪**:确定参数更新对哪些状态产生影响\n", "- **高效计算**:只计算必要的梯度路径\n", "- 
**内存节省**:避免存储不必要的中间梯度\n" @@ -434,7 +434,7 @@ "\n", "### 4.1 扰动机制原理\n", "\n", - "状态扰动是brainscale实现高效梯度计算的核心技术。通过对隐藏状态施加微小扰动 $y = f(x, h+\\Delta)$,其中$\\Delta \\to 0$,系统能够:\n", + "状态扰动是braintrace实现高效梯度计算的核心技术。通过对隐藏状态施加微小扰动 $y = f(x, h+\\Delta)$,其中$\\Delta \\to 0$,系统能够:\n", "- **数值梯度**:通过有限差分计算梯度\n", "- **自动微分**:构建高效的反向传播图\n", "\n", @@ -452,7 +452,7 @@ "cell_type": "code", "source": [ "# 提取状态扰动信息\n", - "hidden_perturb = brainscale.add_hidden_perturbation_from_minfo(info)\n", + "hidden_perturb = braintrace.add_hidden_perturbation_from_minfo(info)\n", "\n", "print(\"=== 状态扰动信息 ===\")\n", "print(f\"扰动变量数量: {len(hidden_perturb.perturb_vars)}\")\n", @@ -551,14 +551,14 @@ "source": [ "## 总结\n", "\n", - "本指南全面介绍了brainscale框架中IR分析和优化的核心概念和实践方法。通过深入理解模型信息提取、状态群分析、参数关系分析和状态扰动机制,开发者能够:\n", + "本指南全面介绍了braintrace框架中IR分析和优化的核心概念和实践方法。通过深入理解模型信息提取、状态群分析、参数关系分析和状态扰动机制,开发者能够:\n", "\n", "- **提升性能**:通过IR优化显著提高模型训练和推理效率\n", "- **增强理解**:深入了解神经网络模型的内部结构和计算流程\n", "- **优化设计**:基于分析结果设计更高效的模型架构\n", "- **调试能力**:快速定位和解决模型训练中的问题\n", "\n", - "brainscale的IR分析能力为神经网络的高效实现提供了强大的工具支持,是实现高性能在线学习的重要基础。用户已可以自定义IR分析和优化流程,以满足特定应用需求。" + "braintrace的IR分析能力为神经网络的高效实现提供了强大的工具支持,是实现高性能在线学习的重要基础。用户已可以自定义IR分析和优化流程,以满足特定应用需求。" ], "id": "6a5d217380009df3" } diff --git a/docs/advanced/limitations-en.ipynb b/docs/advanced/limitations-en.ipynb index d168af4..986391e 100644 --- a/docs/advanced/limitations-en.ipynb +++ b/docs/advanced/limitations-en.ipynb @@ -6,9 +6,9 @@ "source": [ "# Limitations\n", "\n", - "While **brainscale** provides powerful compilation tools capable of recognizing most JAX operators, several limitations remain. These limitations primarily stem from brainscale's inability to automatically infer the relationships between model states and parameters when processing conditional control flow and function transformations, preventing successful compilation into eligibility traces.\n", + "While **braintrace** provides powerful compilation tools capable of recognizing most JAX operators, several limitations remain. 
These limitations primarily stem from braintrace's inability to automatically infer the relationships between model states and parameters when processing conditional control flow and function transformations, preventing successful compilation into eligibility traces.\n", "\n", - "Currently, brainscale's **ETraceState**, **ETraceOp**, and **ETraceParam** constructs do not support the following JAX primitives:\n", + "Currently, braintrace's **ETraceState**, **ETraceOp**, and **ETraceParam** constructs do not support the following JAX primitives:\n", "\n", "- `jax.lax.cond`\n", "- `jax.lax.scan`\n", @@ -18,7 +18,7 @@ "\n", "We are actively working to address these limitations and plan to support these features in future releases.\n", "\n", - "**Important note:** These JAX primitives can still be used and compiled successfully by brainscale when they appear outside of eligibility trace constructs (i.e., when not wrapped within **ETraceState**, **ETraceOp**, or **ETraceParam** contexts).\n" + "**Important note:** These JAX primitives can still be used and compiled successfully by braintrace when they appear outside of eligibility trace constructs (i.e., when not wrapped within **ETraceState**, **ETraceOp**, or **ETraceParam** contexts).\n" ], "id": "2c39803f000da885" } diff --git a/docs/advanced/limitations-zh.ipynb b/docs/advanced/limitations-zh.ipynb index 0694a4f..837676c 100644 --- a/docs/advanced/limitations-zh.ipynb +++ b/docs/advanced/limitations-zh.ipynb @@ -6,9 +6,9 @@ "source": [ "# 局限性\n", "\n", - "尽管 ``brainscale`` 提供了强大的编译工具,能够识别大多数 JAX 算子,但仍然存在一些局限性。这些局限性的主要原因在于 ``brainscale`` 在处理条件控制和功能增强函数时,无法有效自动识别模型状态与参数之间的关系,导致无法编译为 eligibility trace。\n", + "尽管 ``braintrace`` 提供了强大的编译工具,能够识别大多数 JAX 算子,但仍然存在一些局限性。这些局限性的主要原因在于 ``braintrace`` 在处理条件控制和功能增强函数时,无法有效自动识别模型状态与参数之间的关系,导致无法编译为 eligibility trace。\n", "\n", - "目前,``brainscale``中的 ``ETraceState``、``ETraceOp``、 ``ETraceParam``等语法不支持以下 JAX 原语:\n", + "目前,``braintrace``中的 ``ETraceState``、``ETraceOp``、 ``ETraceParam``等语法不支持以下 JAX 原语:\n", "\n", "- `jax.lax.cond`\n", "- `jax.lax.scan`\n", @@ -18,7 +18,7 @@ "\n", "我们正在积极努力解决这些问题,以便在未来的版本中支持这些功能。\n", "\n", - "但是,如果在非 ``ETraceState``、``ETraceOp``、 ``ETraceParam``等语法的情况下使用上述 JAX 原语,``brainscale`` 仍然可以正常编译工作。\n" + "但是,如果在非 ``ETraceState``、``ETraceOp``、 ``ETraceParam``等语法的情况下使用上述 JAX 原语,``braintrace`` 仍然可以正常编译工作。\n" ], "id": "2c39803f000da885" } diff --git a/docs/apis/algorithms.rst b/docs/apis/algorithms.rst index fdde81a..c180727 100644 --- a/docs/apis/algorithms.rst +++ b/docs/apis/algorithms.rst @@ -1,7 +1,7 @@ Online Learning Algorithms ========================== -.. currentmodule:: brainscale +.. currentmodule:: braintrace .. contents:: :local: diff --git a/docs/apis/compiler.rst b/docs/apis/compiler.rst index 53a948b..eca633f 100644 --- a/docs/apis/compiler.rst +++ b/docs/apis/compiler.rst @@ -1,7 +1,7 @@ Online Learning Compiler ======================== -.. currentmodule:: brainscale +.. currentmodule:: braintrace .. contents:: :local: diff --git a/docs/apis/concepts.rst b/docs/apis/concepts.rst index 4e45d98..b111e53 100644 --- a/docs/apis/concepts.rst +++ b/docs/apis/concepts.rst @@ -1,7 +1,7 @@ Online Learning Concepts ======================== -.. currentmodule:: brainscale +.. currentmodule:: braintrace .. 
contents:: :local: diff --git a/docs/apis/nn.rst b/docs/apis/nn.rst index 7863cc1..611a579 100644 --- a/docs/apis/nn.rst +++ b/docs/apis/nn.rst @@ -1,8 +1,8 @@ -``brainscale.nn`` for neural network building +``braintrace.nn`` for neural network building ============================================= -.. currentmodule:: brainscale.nn -.. automodule:: brainscale.nn +.. currentmodule:: braintrace.nn +.. automodule:: braintrace.nn diff --git a/docs/conf.py b/docs/conf.py index 5b086f4..e0aa701 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -35,19 +35,19 @@ sys.path.insert(0, os.path.abspath('../')) -import brainscale +import braintrace import shutil shutil.copy('../changelog.md', './changelog.md') # -- Project information ----------------------------------------------------- -project = 'BrainScale' -copyright = '2024, BrainScale' -author = 'BrainScale Developer' +project = 'BrainTrace' +copyright = '2024, BrainTrace' +author = 'BrainTrace Developer' # The full version, including alpha/beta/rc tags -release = brainscale.__version__ +release = braintrace.__version__ # -- General configuration --------------------------------------------------- @@ -113,11 +113,11 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] html_theme = "sphinx_book_theme" -html_logo = "_static/brainscale.png" -html_title = "BrainScale" +html_logo = "_static/braintrace.png" +html_title = "BrainTrace" html_copy_source = True html_sourcelink_suffix = "" -html_favicon = "_static/brainscale.png" +html_favicon = "_static/braintrace.png" html_last_updated_fmt = "" # Add any paths that contain custom static files (such as style sheets) here, diff --git a/docs/examples/core_examples.rst b/docs/examples/core_examples.rst index e31ff1e..cf624ae 100644 --- a/docs/examples/core_examples.rst +++ b/docs/examples/core_examples.rst @@ -1,7 +1,7 @@ Core examples ============= -Core examples are hosted on the GitHub repository in the `examples `__ +Core examples are hosted on the GitHub repository in the `examples `__ directory. Each example is designed to be **self-contained and easily forkable**. diff --git a/docs/index.rst b/docs/index.rst index 4c1606e..0ac8fcd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,7 +1,7 @@ -``brainscale`` documentation +``braintrace`` documentation ============================ -`brainscale `_ is designed for the scalable online learning of biological neural networks. +`braintrace `_ is designed for the scalable online learning of biological neural networks. ---- @@ -11,26 +11,26 @@ Basic Usage ^^^^^^^^^^^ -Here we show how easy it is to use `brainscale` to build and train a simple SNN/RNN model. +Here we show how easy it is to use `braintrace` to build and train a simple SNN/RNN model. .. code-block:: - import brainscale + import braintrace import brainstate # define models as usual model = brainstate.nn.Sequential( - brainscale.nn.GRU(2, 2), - brainscale.nn.GRU(2, 1), + braintrace.nn.GRU(2, 2), + braintrace.nn.GRU(2, 1), ) # initialize the model brainstate.nn.init_all_states(model) # the only thing you need to do just two lines of code - model = brainscale.ParamDimVjpAlgorithm(model) + model = braintrace.ParamDimVjpAlgorithm(model) model.compile_graph(your_inputs) # train your model as usual @@ -48,19 +48,19 @@ Installation .. code-block:: bash - pip install -U brainscale[cpu] + pip install -U braintrace[cpu] .. tab-item:: GPU (CUDA 12.0) .. code-block:: bash - pip install -U brainscale[cuda12] + pip install -U braintrace[cuda12] .. tab-item:: TPU .. 
code-block:: bash - pip install -U brainscale[tpu] + pip install -U braintrace[tpu] ---- @@ -70,7 +70,7 @@ See also the ecosystem ^^^^^^^^^^^^^^^^^^^^^^ -``brainscale`` is of part of our `brain simulation ecosystem `_. +``braintrace`` is part of our `brain simulation ecosystem `_. ---- diff --git a/docs/quickstart/concepts-en.ipynb b/docs/quickstart/concepts-en.ipynb index 567c71c..2563227 100644 --- a/docs/quickstart/concepts-en.ipynb +++ b/docs/quickstart/concepts-en.ipynb @@ -18,18 +18,18 @@ }, "source": [ "\n", - "Welcome to `brainscale`!\n", + "Welcome to `braintrace`!\n", "\n", - "BrainScale is a Python library designed for implementing online learning in neural network models with dynamics. Online learning represents a learning paradigm that enables continuous parameter updates as models receive new data streams. This approach proves particularly valuable in real-world applications, including robotic control systems, agent decision-making processes, and large-scale data stream processing.\n", + "BrainTrace is a Python library designed for implementing online learning in neural network models with dynamics. Online learning represents a learning paradigm that enables continuous parameter updates as models receive new data streams. This approach proves particularly valuable in real-world applications, including robotic control systems, agent decision-making processes, and large-scale data stream processing.\n", "\n", "\n", - "In this section, I will introduce some of the key concepts that are fundamental to understanding and using online learning methods defined in ``brainscale`` . These concepts include:\n", + "In this section, I will introduce some of the key concepts that are fundamental to understanding and using online learning methods defined in ``braintrace``. These concepts include:\n", "\n", - "- Concepts related to build high-Level Neural Networks.\n", - "- Concepts related to customize neural network module: ``ETraceVar`` for hidden states, ``ETraceParam`` for weight parameters, and ``ETraceOp`` for input-to-hidden transition.\n", + "- Concepts related to building high-level neural networks.\n", + "- Concepts related to customizing neural network modules: ``ETraceState`` for hidden states, ``ETraceParam`` for weight parameters, and ``ETraceOp`` for the input-to-hidden transition.\n", "- Concepts for online learning algorithms ``ETraceAlgorithm``.\n", "\n", - "``brainscale`` is seamlessly integrated in the [brain dynamics programming ecosystem](https://brainmodeling.readthedocs.io/) centred on ``brainstate``. We strongly recommend that you first familiarise yourself with [basic usage of ``brainstate``](https://brainstate.readthedocs.io/), as this will help you better understand how ``brainscale`` works.\n" + "``braintrace`` is seamlessly integrated in the [brain dynamics programming ecosystem](https://brainmodeling.readthedocs.io/) centred on ``brainstate``. We strongly recommend that you first familiarise yourself with [basic usage of ``brainstate``](https://brainstate.readthedocs.io/), as this will help you better understand how ``braintrace`` works.\n" ] }, { @@ -45,9 +45,9 @@ "source": [ "import brainstate \n", "import brainunit as u\n", - "\n", - "import brainscale\n", - "import braintools" + "import braintrace\n", + "import braintools\n", + "import brainpy" ], "outputs": [], "execution_count": 1 }, { "cell_type": "markdown", "id": "260b588265bf6265", "metadata": {}, "source": [ - "## 1. Dynamical Models Supported in ``brainscale``\n", + "## 1. Dynamical Models Supported in ``braintrace``\n", "\n", - "``BrainScale`` does not support online learning for arbitrary dynamical models. 
The dynamical models currently supported by ``BrainScale`` exhibit a specific architectural constraint, as illustrated in the figure below, wherein the \"dynamics\" and \"interactions between dynamics\" are strictly segregated. Models adhering to this architecture can be decomposed into two primary components:\n", + "``BrainTrace`` does not support online learning for arbitrary dynamical models. The dynamical models currently supported by ``BrainTrace`` exhibit a specific architectural constraint, as illustrated in the figure below, wherein the \"dynamics\" and \"interactions between dynamics\" are strictly segregated. Models adhering to this architecture can be decomposed into two primary components:\n", "\n", "- **Dynamics**: This component characterizes the intrinsic dynamics of neurons, encompassing models such as the Leaky Integrate-and-Fire (LIF) neuron model, the FitzHugh-Nagumo model, and Long Short-Term Memory (LSTM) networks. The update of dynamics (hidden states) is implemented through strictly element-wise operations, although the model may incorporate multiple hidden states.\n", "\n", @@ -77,7 +77,7 @@ "collapsed": false }, "source": [ - "To elucidate the class of dynamical models supported by BrainScale, let us examine a fundamental Leaky Integrate-and-Fire (LIF) neural network model. The dynamics of this network are governed by the following differential equations:\n", + "To elucidate the class of dynamical models supported by BrainTrace, let us examine a fundamental Leaky Integrate-and-Fire (LIF) neural network model. The dynamics of this network are governed by the following differential equations:\n", "\n", "$$\n", "\\begin{aligned}\n", @@ -106,7 +106,7 @@ "\\end{aligned}\n", "$$\n", "\n", - "Notably, the dynamics of the LIF neurons are updated through element-wise operations, while the interaction component is implemented via matrix multiplication. All dynamical models supported by BrainScale can be decomposed into similar `dynamics` and `interaction` components. It is particularly worth noting that this architecture encompasses the majority of recurrent neural network models, thus enabling BrainScale to support online learning for a wide range of recurrent neural network architectures." + "Notably, the dynamics of the LIF neurons are updated through element-wise operations, while the interaction component is implemented via matrix multiplication. All dynamical models supported by BrainTrace can be decomposed into similar `dynamics` and `interaction` components. It is particularly worth noting that this architecture encompasses the majority of recurrent neural network models, thus enabling BrainTrace to support online learning for a wide range of recurrent neural network architectures." ] }, { @@ -116,13 +116,13 @@ "collapsed": false }, "source": [ - "## 2. `brainscale.nn`: Constructing Neural Networks with Online Learning Support\n", + "## 2. `braintrace.nn`: Constructing Neural Networks with Online Learning Support\n", "\n", - "The construction of neural network models supporting online learning in BrainScale follows identical conventions as those employed in `brainstate`. 
For comprehensive tutorials, please refer to the documentation on [Building Artificial Neural Networks](https://brainstate.readthedocs.io/tutorials/artificial_neural_networks-en.html) and [Building Spiking Neural Networks](https://brainstate.readthedocs.io/tutorials/spiking_neural_networks-en.html).\n", + "The construction of neural network models supporting online learning in BrainTrace follows identical conventions as those employed in `brainstate`. For comprehensive tutorials, please refer to the documentation on [Building Artificial Neural Networks](https://brainstate.readthedocs.io/tutorials/artificial_neural_networks-en.html) and [Building Spiking Neural Networks](https://brainstate.readthedocs.io/tutorials/spiking_neural_networks-en.html).\n", "\n", - "The sole distinction lies in the requirement to utilize components from the [`brainscale.nn` module](../apis/nn.rst) for model construction. These components represent extensions of `brainstate.nn` module, specifically engineered to provide modular units with online learning capabilities.\n", + "The sole distinction lies in the requirement to utilize components from the [`braintrace.nn` module](../apis/nn.rst) for model construction. These components represent extensions of `brainstate.nn` module, specifically engineered to provide modular units with online learning capabilities.\n", "\n", - "Below, we present a basic implementation demonstrating the construction of a Leaky Integrate-and-Fire (LIF) neural network using the `brainscale.nn` module." + "Below, we present a basic implementation demonstrating the construction of a Leaky Integrate-and-Fire (LIF) neural network using the `braintrace.nn` module." ] }, { @@ -149,7 +149,7 @@ " ):\n", " super().__init__()\n", "\n", - " # Using the LIF model in brainscale.nn\n", + " # Using the LIF model in braintrace.nn\n", " self.neu = brainpy.state.LIF(n_rec, tau=tau_mem, spk_fun=spk_fun, spk_reset=spk_reset, V_th=V_th)\n", "\n", " # Constructing input and recurrent connection weights\n", @@ -159,8 +159,8 @@ "\n", " # Using delta synaptic projections to construct input and recurrent connections\n", " self.syn = brainpy.state.DeltaProj(\n", - " # Using the Linear model in brainscale.nn\n", - " comm=brainscale.nn.Linear(n_in + n_rec, n_rec, w_init=w_init, b_init=braintools.init.ZeroInit(unit=u.mV)),\n", + " # Using the Linear model in braintrace.nn\n", + " comm=braintrace.nn.Linear(n_in + n_rec, n_rec, w_init=w_init, b_init=braintools.init.ZeroInit(unit=u.mV)),\n", " post=self.neu\n", " )\n", "\n", @@ -209,10 +209,10 @@ " # Building the GRU Layer\n", " self.layers = []\n", " for i in range(n_layer - 1):\n", - " # Using the GRUCell model within brainscale.nn\n", - " self.layers.append(brainscale.nn.GRUCell(n_in, n_rec))\n", + " # Using the GRUCell model within braintrace.nn\n", + " self.layers.append(braintrace.nn.GRUCell(n_in, n_rec))\n", " n_in = n_rec\n", - " self.layers.append(brainscale.nn.GRUCell(n_in, n_out))\n", + " self.layers.append(braintrace.nn.GRUCell(n_in, n_out))\n", "\n", " def update(self, x):\n", " # Updating the GRU Layer\n", @@ -229,7 +229,7 @@ "metadata": { "collapsed": false }, - "source": "As demonstrated, the process of constructing neural network models using the [`brainscale.nn` module](../apis/nn.rst) maintains complete procedural equivalence with the construction methodology employed in the [`brainstate.nn` module](https://brainstate.readthedocs.io/apis/nn.html). 
This architectural parallelism enables direct utilization of existing `brainstate` tutorials for developing neural network models with online learning capabilities." + "source": "As demonstrated, the process of constructing neural network models using the [`braintrace.nn` module](../apis/nn.rst) maintains complete procedural equivalence with the construction methodology employed in the [`brainstate.nn` module](https://brainstate.readthedocs.io/apis/nn.html). This architectural parallelism enables direct utilization of existing `brainstate` tutorials for developing neural network models with online learning capabilities." }, { "cell_type": "markdown", @@ -240,20 +240,20 @@ "source": [ "## 3. `ETraceState`, `ETraceParam`, and `ETraceOp`: Customizing Network Modules\n", "\n", - "While the `brainscale.nn` module provides fundamental network components, it does not encompass all possible network dynamics. Consequently, BrainScale implements a mechanism for customizing module development through three primary classes: `ETraceState`, `ETraceParam`, and `ETraceOp`.\n", + "While the `braintrace.nn` module provides fundamental network components, it does not encompass all possible network dynamics. Consequently, BrainTrace implements a mechanism for customizing module development through three primary classes: `ETraceState`, `ETraceParam`, and `ETraceOp`.\n", "\n", "**Core Components**\n", "\n", "- **`brainstate.HiddenState`**: Represents the hidden states $\\mathbf{h}$ within modules, defining dynamical states such as membrane potentials in LIF neurons or postsynaptic conductances in exponential synaptic models.\n", "\n", - "- **`brainscale.ETraceOp`**: Describe network connections, or how input data is used to compute postsynaptic currents based on model parameters, such as linear matrix multiplication, sparse matrix multiplication, and convolution operations.\n", + "- **`braintrace.ETraceOp`**: Describe network connections, or how input data is used to compute postsynaptic currents based on model parameters, such as linear matrix multiplication, sparse matrix multiplication, and convolution operations.\n", "\n", - "- **`brainscale.ETraceParam`**: Corresponds to model parameters $\\theta$ within modules, encompassing elements such as weight matrices for linear matrix multiplication and adaptive time constants in LIF neurons. All parameters requiring gradient updates during training must be defined within `ETraceParam`.\n", + "- **`braintrace.ETraceParam`**: Corresponds to model parameters $\\theta$ within modules, encompassing elements such as weight matrices for linear matrix multiplication and adaptive time constants in LIF neurons. All parameters requiring gradient updates during training must be defined within `ETraceParam`.\n", "\n", "\n", "**Foundational Framework**\n", "\n", - "These three components—`ETraceState`, `ETraceParam`, and `ETraceOp`—constitute the fundamental conceptual framework underlying neural network models with online learning capabilities in BrainScale.\n", + "These three components—`ETraceState`, `ETraceParam`, and `ETraceOp`—constitute the fundamental conceptual framework underlying neural network models with online learning capabilities in BrainTrace.\n", "\n", "In the following sections, we will present a series of illustrative examples demonstrating the practical implementation of custom network modules using `ETraceState`, `ETraceParam`, and `ETraceOp`." 
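Before those worked examples, a minimal sketch may help make `ETraceState` concrete. The `ExpSyn` module below, its `brainstate.nn.Module` base, and all signatures are illustrative assumptions rather than part of the `braintrace` API:

```python
import jax.numpy as jnp
import brainstate
import braintrace

class ExpSyn(brainstate.nn.Module):
    """A toy exponential synapse whose conductance is an eligibility-trace state."""

    def __init__(self, n: int, tau: float = 5.0):
        super().__init__()
        self.n = n
        self.tau = tau  # time constant, expressed in simulation steps

    def init_state(self, batch_size=None):
        shape = (self.n,) if batch_size is None else (batch_size, self.n)
        # Using ETraceState (rather than a plain brainstate.HiddenState) marks
        # `g` as a hidden state the online-learning compiler should trace.
        self.g = braintrace.ETraceState(jnp.zeros(shape))

    def update(self, x):
        # Strictly element-wise update, matching the "dynamics" constraint above.
        self.g.value = self.g.value * jnp.exp(-1.0 / self.tau) + x
        return self.g.value
```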
] @@ -361,7 +361,7 @@ "source": [ "In the code above, we implement the `LIF` model through inheritance from `brainpy.state.Neuron`. The class incorporates an `ETraceState` class variable `self.V` that characterizes the dynamical state of the membrane potential. The `init_state` method establishes the initial conditions for the membrane potential dynamics, while the `update` method implements the temporal evolution of these dynamics.\n", "\n", - "This implementation maintains substantial structural similarity with the `LIF` class definition in `brainstate`, with one crucial distinction: whereas `brainstate` employs `brainstate.HiddenState` to represent the membrane potential dynamics, `brainscale` utilizes `brainstate.ETraceState` to explicitly designate this dynamical state for online learning applications." + "This implementation maintains substantial structural similarity with the `LIF` class definition in `brainstate`, with one crucial distinction: whereas `brainstate` employs `brainstate.HiddenState` to represent the membrane potential dynamics, `braintrace` utilizes `braintrace.ETraceState` to explicitly designate this dynamical state for online learning applications." ] }, { @@ -371,7 +371,7 @@ "collapsed": false }, "source": [ - "Therefore, we say that `brainstate.HiddenState` can be conceptualized as the counterpart to `brainstate.HiddenState`, specifically designed for defining model states that require eligibility trace updates. Should model states be defined using `brainstate.HiddenState` rather than `brainstate.HiddenState`, the online learning compiler in `brainscale` will fail to recognize these states. This oversight results in the compiled online learning rules being ineffective for the affected states, potentially leading to erroneous or omitted gradient updates in the model.\n", + "Therefore, we say that `braintrace.ETraceState` can be conceptualized as the counterpart to `brainstate.HiddenState`, specifically designed for defining model states that require eligibility trace updates. Should model states be defined using `brainstate.HiddenState` rather than `braintrace.ETraceState`, the online learning compiler in `braintrace` will fail to recognize these states. This oversight results in the compiled online learning rules being ineffective for the affected states, potentially leading to erroneous or omitted gradient updates in the model.\n", "\n", "But we should still be aware that there are obvious differences between `ETraceState` and `HiddenState`:\n", - "- `brainstate.HiddenState`: Explicitly marks states for eligibility trace computation\n", + "- `braintrace.ETraceState`: Explicitly marks states for eligibility trace computation\n", @@ -400,13 +400,13 @@ "\n", "\n", "\n", - "BrainScale provides several built-in operators for common modeling operations, including linear matrix multiplication, sparse matrix multiplication, and convolution. These include:\n", + "BrainTrace provides several built-in operators for common modeling operations, including linear matrix multiplication, sparse matrix multiplication, and convolution. 
These include:\n", "\n", - "* [`brainscale.MatMulOp`](../apis/generated/brainscale.MatMulOp.rst): Standard matrix multiplication.\n", - "* [`brainscale.ConvOp`](../apis/generated/brainscale.ConvOp.rst): Standard convolution operation.\n", - "* [`brainscale.SpMatMulOp`](../apis/generated/brainscale.SpMatMulOp.rst): Sparse matrix-vector multiplication.\n", - "* [`brainscale.LoraOp`](../apis/generated/brainscale.LoraOp.rst): Low-Rank Adaptation (LoRA) operation.\n", - "* [`brainscale.ElemWiseOp`](../apis/generated/brainscale.ElemWiseOp.rst): Element-wise operations.\n" + "* [`braintrace.MatMulOp`](../apis/generated/braintrace.MatMulOp.rst): Standard matrix multiplication.\n", + "* [`braintrace.ConvOp`](../apis/generated/braintrace.ConvOp.rst): Standard convolution operation.\n", + "* [`braintrace.SpMatMulOp`](../apis/generated/braintrace.SpMatMulOp.rst): Sparse matrix-vector multiplication.\n", + "* [`braintrace.LoraOp`](../apis/generated/braintrace.LoraOp.rst): Low-Rank Adaptation (LoRA) operation.\n", + "* [`braintrace.ElemWiseOp`](../apis/generated/braintrace.ElemWiseOp.rst): Element-wise operations.\n" ] }, { @@ -422,10 +422,10 @@ "id": "1dc068e8", "metadata": {}, "source": [ - "`ETraceParam` in BrainScale is used to define trainable model parameters. It takes the following form:\n", + "`ETraceParam` in BrainTrace is used to define trainable model parameters. It takes the following form:\n", "\n", "```python\n", - "param = brainscale.ETraceParam(parameters, op)\n", + "param = braintrace.ETraceParam(parameters, op)\n", "```\n", "\n", "Here, `parameters` refers to the model parameters, and `op` is an instantiated `ETraceOp`. The typical usage pattern is:\n", @@ -449,14 +449,14 @@ "source": [ "def generate_weight(\n", " n_in, n_out, init: Callable = braintools.init.KaimingNormal()\n", - ") -> brainscale.ETraceParam:\n", + ") -> braintrace.ETraceParam:\n", " weight = init([n_in, n_out])\n", " bias = braintools.init.ZeroInit()([n_out])\n", " \n", " # Here is the most crucial step, we define an ETraceParam class to describe the weight matrix and bias vector\n", - " return brainscale.ETraceParam(\n", + " return braintrace.ETraceParam(\n", " {'weight': weight, 'bias': bias}, # model parameters\n", - " brainscale.MatMulOp() # operation\n", + " braintrace.MatMulOp() # operation\n", " )" ], "outputs": [], @@ -469,7 +469,7 @@ "source": [ "In the code above, we define a `generate_weight` function that produces weight matrices and bias vectors. This function returns an `ETraceParam` object that encapsulates these parameters.\n", "\n", - "`brainscale.ETraceParam` serves as the counterpart to `brainstate.ParamState`, specifically designed for model parameters requiring eligibility trace updates. When model parameters $\\theta$ are defined using `brainscale.ETraceParam`, the online learning compiler in `brainscale` implements temporally-dependent gradient updates according to the following formula:\n", + "`braintrace.ETraceParam` serves as the counterpart to `brainstate.ParamState`, specifically designed for model parameters requiring eligibility trace updates. 
When model parameters $\\theta$ are defined using `braintrace.ETraceParam`, the online learning compiler in `braintrace` implements temporally-dependent gradient updates according to the following formula:\n", "\n", "$$\n", "\\nabla_\\theta \\mathcal{L}=\\sum_{t} \\frac{\\partial \\mathcal{L}^{t}}{\\partial \\mathbf{h}^{t}} \\sum_{k=1}^t \\frac{\\partial \\mathbf{h}^t}{\\partial \\boldsymbol{\\theta}^k},\n", @@ -483,7 +483,7 @@ "\\nabla_\\theta \\mathcal{L}=\\sum_{t} \\frac{\\partial \\mathcal{L}^{t}}{\\partial \\mathbf{h}^{t}} \\frac{\\partial \\mathbf{h}^t}{\\partial \\boldsymbol{\\theta}^t}.\n", "$$\n", "\n", - "This implementation distinction signifies that in `brainscale`'s online learning framework, parameters defined as `brainstate.ParamState` are treated as those not requiring eligibility trace updates, thereby forfeiting the ability to compute gradients with temporal dependencies. This architectural design enhances the flexibility of parameter update patterns, thereby increasing the customizability of gradient computation mechanisms." + "This implementation distinction signifies that in `braintrace`'s online learning framework, parameters defined as `brainstate.ParamState` are treated as those not requiring eligibility trace updates, thereby forfeiting the ability to compute gradients with temporal dependencies. This architectural design enhances the flexibility of parameter update patterns, thereby increasing the customizability of gradient computation mechanisms." ] }, { @@ -528,10 +528,10 @@ " weight = braintools.init.param(w_init, [self.in_size[-1], self.out_size[-1]], allow_none=False)\n", " \n", " # operation\n", - " op = brainscale.MatMulOp()\n", + " op = braintrace.MatMulOp()\n", " \n", " # Here is the most crucial step, we define an ETraceParam class to describe the weight matrix and operations\n", - " self.weight_op = brainscale.ETraceParam(weight, op)\n", + " self.weight_op = braintrace.ETraceParam(weight, op)\n", "\n", " def update(self, x):\n", " # Operation of ETraceParam\n", @@ -561,11 +561,11 @@ "source": [ "## 4. `ETraceAlgorithm`: Online Learning Algorithms\n", "\n", - "`ETraceAlgorithm` represents another fundamental concept in the BrainScale framework, defining both the eligibility trace update process during model state evolution and the gradient update rules for model parameters. Implemented as an abstract class, `ETraceAlgorithm` serves as a specialized framework for describing various forms of online learning algorithms within BrainScale.\n", + "`ETraceAlgorithm` represents another fundamental concept in the BrainTrace framework, defining both the eligibility trace update process during model state evolution and the gradient update rules for model parameters. Implemented as an abstract class, `ETraceAlgorithm` serves as a specialized framework for describing various forms of online learning algorithms within BrainTrace.\n", "\n", - "The algorithmic support provided by `brainscale.ETraceAlgorithm` is founded upon the three fundamental concepts previously introduced: `ETraceState`, `ETraceParam`, and `ETraceOp`. \n", + "The algorithmic support provided by `braintrace.ETraceAlgorithm` is founded upon the three fundamental concepts previously introduced: `ETraceState`, `ETraceParam`, and `ETraceOp`. \n", "\n", - "`brainscale.ETraceAlgorithm` implements a flexible online learning compiler that enables online learning capabilities for any neural network model constructed using these three foundational concepts." 
+ "`braintrace.ETraceAlgorithm` implements a flexible online learning compiler that enables online learning capabilities for any neural network model constructed using these three foundational concepts." ] }, { @@ -575,21 +575,21 @@ "collapsed": false }, "source": [ - "Specifically, BrainScale currently supports the following online learning algorithms:\n", + "Specifically, BrainTrace currently supports the following online learning algorithms:\n", "\n", - "1. [`brainscale.IODimVjpAlgorithm`](../apis/generated/brainscale.IODimVjpAlgorithm.rst) or [`brainscale.ES_D_RTRL`](../apis/generated/brainscale.ES_D_RTRL.rst)\n", + "1. [`braintrace.IODimVjpAlgorithm`](../apis/generated/braintrace.IODimVjpAlgorithm.rst) or [`braintrace.ES_D_RTRL`](../apis/generated/braintrace.ES_D_RTRL.rst)\n", " - Implements the ES-D-RTRL algorithm for online learning\n", " - Achieves $O(N)$ memory and computational complexity for online gradient computation\n", " - Optimized for large-scale spiking neural network models\n", " - Detailed algorithm specifications are available in [our paper](https://doi.org/10.1101/2024.09.24.614728)\n", "\n", - "2. [`brainscale.ParamDimVjpAlgorithm`](../apis/generated/brainscale.ParamDimVjpAlgorithm.rst) or [`brainscale.D_RTRL`](../apis/generated/brainscale.D_RTRL.rst)\n", + "2. [`braintrace.ParamDimVjpAlgorithm`](../apis/generated/braintrace.ParamDimVjpAlgorithm.rst) or [`braintrace.D_RTRL`](../apis/generated/braintrace.D_RTRL.rst)\n", " - Utilizes the D-RTRL algorithm for online learning\n", " - Features $O(N^2)$ memory and computational complexity for online gradient computation\n", " - Applicable to both recurrent neural networks and spiking neural network models\n", " - Complete algorithmic details are documented in [our paper](https://doi.org/10.1101/2024.09.24.614728)\n", "\n", - "3. [`brainscale.HybridDimVjpAlgorithm`](../apis/generated/brainscale.HybridDimVjpAlgorithm.rst)\n", + "3. [`braintrace.HybridDimVjpAlgorithm`](../apis/generated/braintrace.HybridDimVjpAlgorithm.rst)\n", " - Implements selective application of ES-D-RTRL or D-RTRL algorithms for parameter updates\n", " - Preferentially employs D-RTRL for convolutional layers and highly sparse connections\n", " - Optimizes memory and computational complexity of parameter updates through adaptive algorithm selection\n", @@ -604,7 +604,7 @@ "collapsed": false }, "source": [ - "In the following demonstration, we will illustrate the process of constructing a neural network model with online learning capabilities using `brainscale.ETraceAlgorithm`. This example will serve to exemplify the practical implementation of the concepts discussed above." + "In the following demonstration, we will illustrate the process of constructing a neural network model with online learning capabilities using `braintrace.ETraceAlgorithm`. This example will serve to exemplify the practical implementation of the concepts discussed above." ] }, { @@ -625,7 +625,7 @@ " brainstate.nn.init_all_states(model)\n", " \n", " # The model is fed into an online learning algorithm with a view to online learning\n", - " model = brainscale.IODimVjpAlgorithm(model, decay_or_rank=0.99)\n", + " model = braintrace.IODimVjpAlgorithm(model, decay_or_rank=0.99)\n", " \n", " # Compile the model's eligibility trace based on one input data. \n", " # Thereafter, the model is called to update not only the state \n", @@ -725,12 +725,12 @@ "source": [ "## 5. 
Conclusion\n", "\n", - "`BrainScale` provides a comprehensive and elegant framework for online learning, with its core concepts organized into the following hierarchical layers:\n", + "`BrainTrace` provides a comprehensive and elegant framework for online learning, with its core concepts organized into the following hierarchical layers:\n", "\n", "1. Infrastructure Layer\n", " - Supports specific dynamical model architectures with strict separation between \"dynamics\" and \"interactions\"\n", " - Built upon the `BrainState` ecosystem, maintaining full compatibility with its programming paradigm\n", - " - Provides ready-to-use neural network components through the `brainscale.nn` module\n", + " - Provides ready-to-use neural network components through the `braintrace.nn` module\n", "\n", "2. Core Concepts Layer\n", " - `ETraceState`: Designates hidden states requiring eligibility trace updates\n", @@ -748,12 +748,12 @@ " - Model compilation generating eligibility trace computation graphs\n", " - Concurrent updates of model states and eligibility traces through forward propagation\n", "\n", - "BrainScale's distinctive architecture encapsulates complex online learning algorithms behind concise interfaces while providing flexible customization mechanisms. This design philosophy achieves a balance between:\n", + "BrainTrace's distinctive architecture encapsulates complex online learning algorithms behind concise interfaces while providing flexible customization mechanisms. This design philosophy achieves a balance between:\n", "- High performance and usability\n", "- Seamless integration with the existing BrainState ecosystem\n", "- Intuitive and efficient construction of online learning neural networks\n", "\n", - "Through this architectural approach, BrainScale transforms the development and training of online learning neural networks into an intuitive and efficient process, providing a powerful toolkit for neural computation research." + "Through this architectural approach, BrainTrace transforms the development and training of online learning neural networks into an intuitive and efficient process, providing a powerful toolkit for neural computation research." 
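Condensing the workflow that this conclusion summarizes into one hedged sketch (the layer sizes and the sample input are hypothetical; `init_all_states`, `IODimVjpAlgorithm(..., decay_or_rank=0.99)`, and `compile_graph` are the calls shown in the cells above):

```python
import brainstate
import braintrace

# Any model built from braintrace.nn components works here; a single GRU
# cell keeps the sketch small (sizes are hypothetical).
net = braintrace.nn.GRUCell(10, 20)
brainstate.nn.init_all_states(net)

# Wrap the model in an eligibility-trace online-learning algorithm.
model = braintrace.IODimVjpAlgorithm(net, decay_or_rank=0.99)

# Compile the eligibility-trace graph from one sample input; afterwards,
# each call advances both the hidden states and the eligibility traces.
x = brainstate.random.rand(10)
model.compile_graph(x)
out = model(x)
```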
] } ], diff --git a/docs/quickstart/concepts-zh.ipynb b/docs/quickstart/concepts-zh.ipynb index 0344909..e3a44f4 100644 --- a/docs/quickstart/concepts-zh.ipynb +++ b/docs/quickstart/concepts-zh.ipynb @@ -17,17 +17,17 @@ "collapsed": false }, "source": [ - "欢迎来到``brainscale``的世界!\n", + "欢迎来到``braintrace``的世界!\n", "\n", - "``brainscale``是一个支持动力学神经网络模型的在线学习Python库。在线学习(online learning)是一种学习范式,它允许模型在不断地接收新数据的同时,不断地更新自己的参数。这种学习方式在许多现实世界的应用中非常有用,比如在机器人控制、智能体的决策制定、以及大规模数据流的处理中。\n", + "``braintrace``是一个支持动力学神经网络模型的在线学习Python库。在线学习(online learning)是一种学习范式,它允许模型在不断地接收新数据的同时,不断地更新自己的参数。这种学习方式在许多现实世界的应用中非常有用,比如在机器人控制、智能体的决策制定、以及大规模数据流的处理中。\n", "\n", - "在这个章节,我将会介绍一些关键概念,这些概念是理解和使用``brainscale``在线学习的基础。这些概念包括:\n", + "在这个章节,我将会介绍一些关键概念,这些概念是理解和使用``braintrace``在线学习的基础。这些概念包括:\n", "\n", "- 如何构建支持在线学习的高层次神经网络模型。\n", "- 用于定制化网络模块的 模型状态``ETraceState``、模型参数``ETraceParam``和 模型交互算子``ETraceOp``。\n", "- 在线学习算法 ``ETraceAlgorithm``。\n", "\n", - "``brainscale``精密地整合在以``brainstate``为中心的[脑动力学编程生态系统](https://brainmodeling.readthedocs.io/)中。我们强烈建议您首先熟悉[``brainstate``的基本用法](https://brainstate.readthedocs.io/),以此能帮助您更好地理解``brainscale``的工作原理。" + "``braintrace``精密地整合在以``brainstate``为中心的[脑动力学编程生态系统](https://brainmodeling.readthedocs.io/)中。我们强烈建议您首先熟悉[``brainstate``的基本用法](https://brainstate.readthedocs.io/),以此能帮助您更好地理解``braintrace``的工作原理。" ] }, { @@ -41,10 +41,12 @@ } }, "source": [ + "import brainpy\n", + "import braintools\n", "import brainstate\n", "import brainunit as u\n", "\n", - "import brainscale" + "import braintrace" ], "outputs": [], "execution_count": 1 @@ -54,9 +56,9 @@ "id": "3b5da7326ca97f38", "metadata": {}, "source": [ - "## 1. ``brainscale``支持的动力学模型\n", + "## 1. ``braintrace``支持的动力学模型\n", "\n", - "``brainscale``并不支持所有类型的动力学模型在线学习。它当前支持的动力学模型具有特定的结构,如下图所示,其中“动力学(dynamics)”与“动力学之间的交互(interaction)”是严格分开的。这种结构的模型可以分解为两个主要部分:\n", + "``braintrace``并不支持所有类型的动力学模型在线学习。它当前支持的动力学模型具有特定的结构,如下图所示,其中“动力学(dynamics)”与“动力学之间的交互(interaction)”是严格分开的。这种结构的模型可以分解为两个主要部分:\n", "\n", "- **动力学部分**:这一部分描述了神经元内在的动力学行为,例如 LIF 神经元模型、FitzHugh-Nagumo 模型以及长短期记忆网络(LSTM)。动力学状态(hidden states)的更新是严格逐元素运算(element-wise operations),但模型可以包含多个动力学状态。\n", "- **交互部分**:这一部分描述了神经元之间的交互,例如权重矩阵和连接矩阵。模型动力学之间的交互可以通过标准的矩阵乘法、卷积操作或稀疏操作等方式实现。\n", @@ -71,7 +73,7 @@ "id": "21a8bbad22acafdd", "metadata": {}, "source": [ - "让我们通过一个简单的网络模型示例来阐明``brainscale``所支持的动力学模型。我们考虑一个基本的LIF神经元网络,其动力学由以下微分方程描述:\n", + "让我们通过一个简单的网络模型示例来阐明``braintrace``所支持的动力学模型。我们考虑一个基本的LIF神经元网络,其动力学由以下微分方程描述:\n", "\n", "$$\n", "\\begin{aligned}\n", @@ -100,7 +102,7 @@ "\\end{aligned}\n", "$$\n", "\n", - "可以看到,LIF神经元的动力学更新是逐元素进行的,而交互部分的更新则通过矩阵乘法实现。所有`brainscale`支持的动力学模型均可分解为这样的动力学部分和交互部分。值得注意的是,绝大多数循环神经网络模型都符合这一结构,因此`brainscale`能够支持大部分循环神经网络模型的在线学习。\n" + "可以看到,LIF神经元的动力学更新是逐元素进行的,而交互部分的更新则通过矩阵乘法实现。所有`braintrace`支持的动力学模型均可分解为这样的动力学部分和交互部分。值得注意的是,绝大多数循环神经网络模型都符合这一结构,因此`braintrace`能够支持大部分循环神经网络模型的在线学习。\n" ] }, { @@ -108,13 +110,13 @@ "id": "d0344ca465c9401e", "metadata": {}, "source": [ - "## 2. ``brainscale.nn``:构建支持在线学习的神经网络\n", + "## 2. 
``braintrace.nn``:构建支持在线学习的神经网络\n", "\n", - "在``brainscale``中,我们可以使用与``brainstate``完全相同的语法来构建支持在线学习的神经网络模型。有关详细教程,请参考 [构建人工神经网络](https://brainstate.readthedocs.io/tutorials/artificial_neural_networks-zh.html) 和 [构建脉冲神经网络](https://brainstate.readthedocs.io/tutorials/spiking_neural_networks-zh.html)。\n", + "在``braintrace``中,我们可以使用与``brainstate``完全相同的语法来构建支持在线学习的神经网络模型。有关详细教程,请参考 [构建人工神经网络](https://brainstate.readthedocs.io/tutorials/artificial_neural_networks-zh.html) 和 [构建脉冲神经网络](https://brainstate.readthedocs.io/tutorials/spiking_neural_networks-zh.html)。\n", "\n", - "唯一的区别在于,我们需要使用[``brainscale.nn``模块](../apis/nn.rst)中的组件来构建神经网络模型。这些组件是``brainstate.nn``的扩展,专门设计用于支持在线学习的单元模块。\n", + "唯一的区别在于,我们需要使用[``braintrace.nn``模块](../apis/nn.rst)中的组件来构建神经网络模型。这些组件是``brainstate.nn``的扩展,专门设计用于支持在线学习的单元模块。\n", "\n", - "以下是一个简单示例,展示如何使用``brainscale.nn``模块构建一个基本的LIF神经元网络。" + "以下是一个简单示例,展示如何使用``braintrace.nn``模块构建一个基本的LIF神经元网络。" ] }, { @@ -141,7 +143,7 @@ " ):\n", " super().__init__()\n", "\n", - " # 使用 brainscale.nn 内的 LIF 模型\n", + " # 使用 brainpy.state 内的 LIF 模型\n", " self.neu = brainpy.state.LIF(n_rec, tau=tau_mem, spk_fun=spk_fun, spk_reset=spk_reset, V_th=V_th)\n", "\n", " # 构建输入和循环连接权重\n", @@ -151,8 +153,8 @@ "\n", " # 使用 delta 突触投射來构建输入和循环连接\n", " self.syn = brainpy.state.DeltaProj(\n", - " # 使用 brainscale.nn 内的 Linear 模型\n", - " comm=brainscale.nn.Linear(n_in + n_rec, n_rec, w_init=w_init, b_init=braintools.init.ZeroInit(unit=u.mV)),\n", + " # 使用 braintrace.nn 内的 Linear 模型\n", + " comm=braintrace.nn.Linear(n_in + n_rec, n_rec, w_init=w_init, b_init=braintools.init.ZeroInit(unit=u.mV)),\n", " post=self.neu\n", " )\n", "\n", @@ -200,10 +202,10 @@ " # 构建 GRU 层\n", " self.layers = []\n", " for i in range(n_layer - 1):\n", - " # 使用 brainscale.nn 内的 GRUCell 模型\n", - " self.layers.append(brainscale.nn.GRUCell(n_in, n_rec))\n", + " # 使用 braintrace.nn 内的 GRUCell 模型\n", + " self.layers.append(braintrace.nn.GRUCell(n_in, n_rec))\n", " n_in = n_rec\n", - " self.layers.append(brainscale.nn.GRUCell(n_in, n_out))\n", + " self.layers.append(braintrace.nn.GRUCell(n_in, n_out))\n", "\n", " def update(self, x):\n", " # 更新 GRU 层\n", @@ -220,7 +222,7 @@ "metadata": { "collapsed": false }, - "source": "可以看到,基于[``brainscale.nn``模块](../apis/nn.rst)构建神经网络模型的过程与基于[``brainstate.nn``模块](https://brainstate.readthedocs.io/apis/nn.html)的构建过程完全相同。这意味着,您可以直接利用``brainstate``的教程来构建支持在线学习的神经网络模型。" + "source": "可以看到,基于[``braintrace.nn``模块](../apis/nn.rst)构建神经网络模型的过程与基于[``brainstate.nn``模块](https://brainstate.readthedocs.io/apis/nn.html)的构建过程完全相同。这意味着,您可以直接利用``brainstate``的教程来构建支持在线学习的神经网络模型。" }, { "cell_type": "markdown", @@ -231,14 +233,14 @@ "source": [ "## 3. 
``ETraceState``、``ETraceParam``和``ETraceOp``:定制化网络模块\n", "\n", - "尽管``brainscale.nn``模块提供了一些基本的网络模块,但它并未涵盖所有可能的网络动力学。因此,我们需要一种机制,允许用户定制网络模块。在``brainscale``中,我们提供了``ETraceState``、``ETraceParam``和``ETraceOp``三个类,供用户进行模块定制。\n", + "尽管``braintrace.nn``模块提供了一些基本的网络模块,但它并未涵盖所有可能的网络动力学。因此,我们需要一种机制,允许用户定制网络模块。在``braintrace``中,我们提供了``ETraceState``、``ETraceParam``和``ETraceOp``三个类,供用户进行模块定制。\n", "\n", - "- **``brainstate.HiddenState``**:代表模块中的模型状态$\mathbf{h}$,用于定义模型的动力学状态,例如LIF神经元的膜电位或指数突触模型的突触后电导。\n", + "- **``braintrace.ETraceState``**:代表模块中的模型状态$\mathbf{h}$,用于定义模型的动力学状态,例如LIF神经元的膜电位或指数突触模型的突触后电导。\n", - "- **``brainscale.ETraceOp``**:用于描述网络的连接,或者输入数据如何基于模型参数计算突触后电流的操作,如线性矩阵乘法、稀疏矩阵乘法和卷积操作。\n", + "- **``braintrace.ETraceOp``**:用于描述网络的连接,或者输入数据如何基于模型参数计算突触后电流的操作,如线性矩阵乘法、稀疏矩阵乘法和卷积操作。\n", - "- **``brainscale.ETraceParam``**:对应模块中的模型参数$\theta$,用于定义模型参数,例如线性矩阵乘法的权重矩阵,还可以用于LIF神经元中自适应学习的时间常数等。所有需要在训练过程中进行梯度更新的参数都应在``ETraceParam``中定义。\n", + "- **``braintrace.ETraceParam``**:对应模块中的模型参数$\theta$,用于定义模型参数,例如线性矩阵乘法的权重矩阵,还可以用于LIF神经元中自适应学习的时间常数等。所有需要在训练过程中进行梯度更新的参数都应在``ETraceParam``中定义。\n", "\n", "\n", - "``ETraceState``、``ETraceParam``和``ETraceOp``是``brainscale``中的三个基本概念,构成了支持在线学习的神经网络模型的基础。\n", + "``ETraceState``、``ETraceParam``和``ETraceOp``是``braintrace``中的三个基本概念,构成了支持在线学习的神经网络模型的基础。\n", "\n", "接下来,让我们通过一系列简单示例来说明如何使用``ETraceState``、``ETraceParam``和``ETraceOp``进行网络模块的定制。" ] @@ -343,7 +345,7 @@ "collapsed": false }, "source": [ - "在上面的代码中,我们继承``brainpy.state.Neuron``定义了这个``LIF``模型。这个类包含了一个``ETraceState``类变量``self.V``,用于描述膜电位的动力学状态。在``init_state``方法中,我们初始化了膜电位的动力学状态。在``update``方法中,我们更新了膜电位的动力学状态。实际上,这个类的定义与``brainstate``中的``LIF``类的定义基本上是一模一样的,唯一不同的地方在于``brainstate``使用``brainstate.HiddenState``来描述膜电位的动力学状态,而``brainscale``使用``brainstate.ETraceState``来标记该膜电位的动力学状态是需要用于在线学习的。" + "在上面的代码中,我们继承``brainpy.state.Neuron``定义了这个``LIF``模型。这个类包含了一个``ETraceState``类变量``self.V``,用于描述膜电位的动力学状态。在``init_state``方法中,我们初始化了膜电位的动力学状态。在``update``方法中,我们更新了膜电位的动力学状态。实际上,这个类的定义与``brainstate``中的``LIF``类的定义基本上是一模一样的,唯一不同的地方在于``brainstate``使用``brainstate.HiddenState``来描述膜电位的动力学状态,而``braintrace``使用``braintrace.ETraceState``来标记该膜电位的动力学状态是需要用于在线学习的。" ] }, { @@ -355,7 +357,7 @@ "source": [ - "因此,我们可以将``brainstate.HiddenState``视为与``brainstate.HiddenState``相对应的概念,专门用于定义需要进行资格迹(eligibility trace)更新的模型状态。\n", + "因此,我们可以将``braintrace.ETraceState``视为与``brainstate.HiddenState``相对应的概念,专门用于定义需要进行资格迹(eligibility trace)更新的模型状态。\n", "\n", - "如果在程序中将模型状态定义为``brainstate.HiddenState``而非``brainstate.HiddenState``,则``brainscale``的在线学习编译器将无法识别该状态,导致编译后的在线学习规则不再作用于该状态,从而引发模型的梯度更新错误或遗漏。" + "如果在程序中将模型状态定义为``brainstate.HiddenState``而非``braintrace.ETraceState``,则``braintrace``的在线学习编译器将无法识别该状态,导致编译后的在线学习规则不再作用于该状态,从而引发模型的梯度更新错误或遗漏。" ] }, { @@ -382,13 +384,13 @@ "source": [ "\n", "\n", - "BrainScale 内置了许多常见的模型操作,包括线性矩阵乘法、稀疏矩阵乘法、卷积操作,包括:\n", + "BrainTrace 内置了许多常见的模型操作,包括线性矩阵乘法、稀疏矩阵乘法和卷积操作等,具体包括:\n", "\n", - "- [`brainscale.MatMulOp`](../apis/generated/brainscale.MatMulOp.rst):标准的矩阵乘法操作。\n", - "- [`brainscale.ConvOp`](../apis/generated/brainscale.ConvOp.rst):标准的卷积操作。\n", - "- [`brainscale.SpMatMulOp`](../apis/generated/brainscale.SpMatMulOp.rst):稀疏矩阵乘法操作。\n", - "- [`brainscale.LoraOp`](../apis/generated/brainscale.LoraOp.rst):低秩适应(LoRA)操作。\n", - "- [`brainscale.ElemWiseOp`](../apis/generated/brainscale.ElemWiseOp.rst):逐元素操作。\n", + "- [`braintrace.MatMulOp`](../apis/generated/braintrace.MatMulOp.rst):标准的矩阵乘法操作。\n", + "- [`braintrace.ConvOp`](../apis/generated/braintrace.ConvOp.rst):标准的卷积操作。\n", + "- [`braintrace.SpMatMulOp`](../apis/generated/braintrace.SpMatMulOp.rst):稀疏矩阵乘法操作。\n", + "- [`braintrace.LoraOp`](../apis/generated/braintrace.LoraOp.rst):低秩适应(LoRA)操作。\n", + "- [`braintrace.ElemWiseOp`](../apis/generated/braintrace.ElemWiseOp.rst):逐元素操作。\n", "\n", "\n" ], @@ -403,10 +405,10 @@ "source": 
[ "### 3.3 ``ETraceParam``: 模型参数\n", "\n", - "``ETraceParam``在brainscale中可用于定义需要训练的模型参数,它接收如下参数格式:\n", + "``ETraceParam``在braintrace中可用于定义需要训练的模型参数,它接收如下参数格式:\n", "\n", "```python\n", - "param = brainscale.ETraceParam(parameters, op)\n", + "param = braintrace.ETraceParam(parameters, op)\n", "```\n", "\n", "其中,`parameters`是模型参数,`op`是一个实例化的``ETraceOp``。其基本用法为:\n", @@ -430,14 +432,14 @@ "source": [ "def generate_weight(\n", " n_in, n_out, init: Callable = braintools.init.KaimingNormal()\n", - ") -> brainscale.ETraceParam:\n", + ") -> braintrace.ETraceParam:\n", " weight = init([n_in, n_out])\n", " bias = braintools.init.ZeroInit()([n_out])\n", "\n", " # 这里是最关键的一步,我们定义了一个 ETraceParam 类,用于描述权重矩阵和偏置向量\n", - " return brainscale.ETraceParam(\n", + " return braintrace.ETraceParam(\n", " {'weight': weight, 'bias': bias}, # 模型参数\n", - " brainscale.MatMulOp() # 操作\n", + " braintrace.MatMulOp() # 操作\n", " )" ], "outputs": [], @@ -450,7 +452,7 @@ "source": [ "在上述代码中,我们定义了一个``generate_weight``函数,用于生成权重矩阵和偏置向量。该函数返回一个``ETraceParam``对象,以描述权重矩阵和偏置向量。\n", "\n", - "``brainscale.ETraceParam``是与``brainstate.ParamState``相对应的概念,专门用于定义需要进行资格迹(eligibility trace)更新的模型参数。如果我们在程序中将模型参数$\\theta$定义为``brainscale.ETraceParam``,那么``brainscale``的在线学习编译器将对该参数进行具有时序依赖的梯度更新,计算公式为:\n", + "``braintrace.ETraceParam``是与``brainstate.ParamState``相对应的概念,专门用于定义需要进行资格迹(eligibility trace)更新的模型参数。如果我们在程序中将模型参数$\\theta$定义为``braintrace.ETraceParam``,那么``braintrace``的在线学习编译器将对该参数进行具有时序依赖的梯度更新,计算公式为:\n", "\n", "$$\n", "\\nabla_\\theta \\mathcal{L}=\\sum_{t} \\frac{\\partial \\mathcal{L}^{t}}{\\partial \\mathbf{h}^{t}} \\sum_{k=1}^t \\frac{\\partial \\mathbf{h}^t}{\\partial \\boldsymbol{\\theta}^k},\n", @@ -458,13 +460,13 @@ "\n", "其中,$\\boldsymbol{\\theta}^k$是在第$k$时刻使用的权重$\\boldsymbol{\\theta}$。\n", "\n", - "相反,如果我们将模型参数$\\theta$定义为``brainstate.ParamState``,那么``brainscale``的在线学习编译器仅会计算当前时刻损失函数对权重的偏导值,即:\n", + "相反,如果我们将模型参数$\\theta$定义为``brainstate.ParamState``,那么``braintrace``的在线学习编译器仅会计算当前时刻损失函数对权重的偏导值,即:\n", "\n", "$$\n", "\\nabla_\\theta \\mathcal{L}=\\sum_{t} \\frac{\\partial \\mathcal{L}^{t}}{\\partial \\mathbf{h}^{t}} \\frac{\\partial \\mathbf{h}^t}{\\partial \\boldsymbol{\\theta}^t}.\n", "$$\n", "\n", - "这意味着,在``brainscale``的在线学习中,``brainstate.ParamState``被视为不需要进行eligibility trace更新的模型参数,因此失去了对时序依赖信息的梯度计算能力。这样的设计使得模型参数的更新模式更加灵活,从而增加了梯度计算的可定制性。\n" + "这意味着,在``braintrace``的在线学习中,``brainstate.ParamState``被视为不需要进行eligibility trace更新的模型参数,因此失去了对时序依赖信息的梯度计算能力。这样的设计使得模型参数的更新模式更加灵活,从而增加了梯度计算的可定制性。\n" ] }, { @@ -511,10 +513,10 @@ " weight = braintools.init.param(w_init, [self.in_size[-1], self.out_size[-1]], allow_none=False)\n", "\n", " # operation\n", - " op = brainscale.MatMulOp()\n", + " op = braintrace.MatMulOp()\n", "\n", " # 这里是最关键的一步,我们定义了一个 ETraceParam 类,用于描述权重矩阵和操作\n", - " self.weight_op = brainscale.ETraceParam(weight, op)\n", + " self.weight_op = braintrace.ETraceParam(weight, op)\n", "\n", " def update(self, x):\n", " # ETraceParam 的操作\n", @@ -544,9 +546,9 @@ "source": [ "## 4. 
``ETraceAlgorithm``:在线学习算法\n", "\n", - "``ETraceAlgorithm``是``brainscale``中的另一个重要概念,它定义了模型的状态更新过程中如何更新eligibility trace,以及定义了模型参数的梯度更新规则。``ETraceAlgorithm``是一个抽象类,专门用于描述``brainscale``内各种形式的在线学习算法。\n", + "``ETraceAlgorithm``是``braintrace``中的另一个重要概念,它定义了模型的状态更新过程中如何更新eligibility trace,以及定义了模型参数的梯度更新规则。``ETraceAlgorithm``是一个抽象类,专门用于描述``braintrace``内各种形式的在线学习算法。\n", "\n", - "``brainscale.ETraceAlgorithm``中提供的算法支持,是基于上面提供的``ETraceState``、``ETraceParam``和``ETraceOp``三个基本概念。``brainscale.ETraceAlgorithm``提供了一种灵活的在线学习编译器。它可以支持使用上述三个概念构建的任意神经网络模型的在线学习。" + "``braintrace.ETraceAlgorithm``中提供的算法支持,是基于上面提供的``ETraceState``、``ETraceParam``和``ETraceOp``三个基本概念。``braintrace.ETraceAlgorithm``提供了一种灵活的在线学习编译器。它可以支持使用上述三个概念构建的任意神经网络模型的在线学习。" ] }, { @@ -557,13 +559,13 @@ }, "source": [ "\n", - "具体来说,目前 ``brainscale`` 支持的在线学习算法有:\n", + "具体来说,目前 ``braintrace`` 支持的在线学习算法有:\n", "\n", - "- [`brainscale.IODimVjpAlgorithm`](../apis/generated/brainscale.IODimVjpAlgorithm.rst) 或 [`brainscale.ES_D_RTRL`](../apis/generated/brainscale.ES_D_RTRL.rst):该算法使用 ES-D-RTRL 算法进行在线学习,支持$O(N)$复杂度的在线梯度计算,适用于大规模脉冲神经网络模型的在线学习。具体算法细节可以参考[我们的论文](https://doi.org/10.1101/2024.09.24.614728)。\n", + "- [`braintrace.IODimVjpAlgorithm`](../apis/generated/braintrace.IODimVjpAlgorithm.rst) 或 [`braintrace.ES_D_RTRL`](../apis/generated/braintrace.ES_D_RTRL.rst):该算法使用 ES-D-RTRL 算法进行在线学习,支持$O(N)$复杂度的在线梯度计算,适用于大规模脉冲神经网络模型的在线学习。具体算法细节可以参考[我们的论文](https://doi.org/10.1101/2024.09.24.614728)。\n", "\n", - "- [`brainscale.ParamDimVjpAlgorithm`](../apis/generated/brainscale.ParamDimVjpAlgorithm.rst) 或 [`brainscale.D_RTRL`](../apis/generated/brainscale.D_RTRL.rst):该算法使用 D-RTRL 算法进行在线学习,支持$O(N^2)$复杂度的在线梯度计算,适用于循环神经网络模型和脉冲神经网络模型的在线学习。具体算法细节可以参考[我们的论文](https://doi.org/10.1101/2024.09.24.614728)。\n", + "- [`braintrace.ParamDimVjpAlgorithm`](../apis/generated/braintrace.ParamDimVjpAlgorithm.rst) 或 [`braintrace.D_RTRL`](../apis/generated/braintrace.D_RTRL.rst):该算法使用 D-RTRL 算法进行在线学习,支持$O(N^2)$复杂度的在线梯度计算,适用于循环神经网络模型和脉冲神经网络模型的在线学习。具体算法细节可以参考[我们的论文](https://doi.org/10.1101/2024.09.24.614728)。\n", "\n", - "- [`brainscale.HybridDimVjpAlgorithm`](../apis/generated/brainscale.HybridDimVjpAlgorithm.rst):该算法选择性地使用 ES-D-RTRL 或 D-RTRL 算法对模型参数进行在线学习。对于卷积层和高度稀疏连接的层,该算法有更大的倾向使用 D-RTRL 算法进行在线学习,以减少在线学习参数更新所需的计算复杂度。\n", + "- [`braintrace.HybridDimVjpAlgorithm`](../apis/generated/braintrace.HybridDimVjpAlgorithm.rst):该算法选择性地使用 ES-D-RTRL 或 D-RTRL 算法对模型参数进行在线学习。对于卷积层和高度稀疏连接的层,该算法有更大的倾向使用 D-RTRL 算法进行在线学习,以减少在线学习参数更新所需的计算复杂度。\n", "\n", "- 未来,我们将会支持更多的在线学习算法,以满足更多的应用场景。\n", "\n" @@ -577,7 +579,7 @@ }, "source": [ "\n", - "在下面的例子中,我们将展示如何使用``brainscale.ETraceAlgorithm``来构建一个支持在线学习的神经网络模型。\n" + "在下面的例子中,我们将展示如何使用``braintrace.ETraceAlgorithm``来构建一个支持在线学习的神经网络模型。\n" ] }, { @@ -597,7 +599,7 @@ " brainstate.nn.init_all_states(model)\n", "\n", " # 将该模型输入到在线学习算法中,以期进行在线学习\n", - " model = brainscale.IODimVjpAlgorithm(model, decay_or_rank=0.99)\n", + " model = braintrace.IODimVjpAlgorithm(model, decay_or_rank=0.99)\n", "\n", " # 根据一个输入数据编译模型的eligibility trace,\n", " # 此后,调用该模型不仅更新模型的状态,还会更新模型的eligibility trace\n", @@ -694,12 +696,12 @@ "source": [ "## 5. 总结\n", "\n", - "总的来说,`brainscale`为在线学习提供了一个完整而优雅的框架体系,其核心概念可以总结为以下几个层次:\n", + "总的来说,`braintrace`为在线学习提供了一个完整而优雅的框架体系,其核心概念可以总结为以下几个层次:\n", "\n", "1. **基础架构层**\n", " - 支持特定结构的动力学模型,将\"动力学\"和\"交互\"严格分离\n", " - 基于`brainstate`生态系统,完全兼容其编程范式\n", - " - 通过`brainscale.nn`模块提供开箱即用的神经网络组件\n", + " - 通过`braintrace.nn`模块提供开箱即用的神经网络组件\n", "\n", "2. 
**核心概念层**\n", " - `ETraceState`:标记需要进行eligibility trace更新的动力学状态\n", @@ -717,7 +719,7 @@ " - 编译模型生成eligibility trace计算图\n", " - 通过前向传播同时更新模型状态和eligibility trace\n", "\n", - "这个框架的独特之处在于:它将复杂的在线学习算法封装在简洁的接口后,提供灵活的定制化机制,既保持高性能,又确保易用性。同时,它与现有的 `brainstate` 生态系统无缝集成。通过这样的设计,`brainscale` 使构建和训练在线学习神经网络变得直观而高效,为神经计算研究提供了强大的工具。\n" + "这个框架的独特之处在于:它将复杂的在线学习算法封装在简洁的接口后,提供灵活的定制化机制,既保持高性能,又确保易用性。同时,它与现有的 `brainstate` 生态系统无缝集成。通过这样的设计,`braintrace` 使构建和训练在线学习神经网络变得直观而高效,为神经计算研究提供了强大的工具。\n" ] } ], diff --git a/docs/quickstart/rnn_online_learning-en.ipynb b/docs/quickstart/rnn_online_learning-en.ipynb index b481f21..c780abf 100644 --- a/docs/quickstart/rnn_online_learning-en.ipynb +++ b/docs/quickstart/rnn_online_learning-en.ipynb @@ -6,9 +6,9 @@ "# RNN Online Learning\n", "\n", "\n", - "In the chapter on [Key Concepts](./concepts-en.ipynb), we introduced the fundamentals of online learning with `brainscale`. In this section, we will discuss how to implement online learning for Recurrent Neural Networks (RNNs) based on `brainscale`.\n", + "In the chapter on [Key Concepts](./concepts-en.ipynb), we introduced the fundamentals of online learning with `braintrace`. In this section, we will discuss how to implement online learning for Recurrent Neural Networks (RNNs) based on `braintrace`.\n", "\n", - "Rate-based RNNs are more widely used in contemporary deep learning tasks compared to Spiking Neural Networks (SNNs). In these networks, the output of neurons is represented as continuous floating-point values, rather than discrete spikes as in SNNs. The `ParamDimVjpAlgorithm` provided by `brainscale` can be employed very efficiently to support online learning for RNNs." + "Rate-based RNNs are more widely used in contemporary deep learning tasks compared to Spiking Neural Networks (SNNs). In these networks, the output of neurons is represented as continuous floating-point values, rather than discrete spikes as in SNNs. The `ParamDimVjpAlgorithm` provided by `braintrace` can be employed very efficiently to support online learning for RNNs." ], "metadata": { "collapsed": false @@ -19,7 +19,9 @@ "cell_type": "code", "source": [ "import brainstate\n", - "import brainscale" + "import brainpy\n", + "import braintools\n", + "import braintrace" ], "metadata": { "collapsed": false, @@ -160,7 +162,7 @@ "\n", "GRUs simplify the model by reducing the number of gates, allowing for faster training while achieving comparable performance to LSTMs on many tasks.\n", "\n", - "It can be observed that this mathematical formulation of the RNN model perfectly satisfies the separation of \"dynamics\" and \"dynamic interactions\" as discussed in the [Key Concepts](./concepts-en.ipynb). Consequently, the online learning system provided by `brainscale` can effectively support online learning for RNN models.\n" + "It can be observed that this mathematical formulation of the RNN model perfectly satisfies the separation of \"dynamics\" and \"dynamic interactions\" as discussed in the [Key Concepts](./concepts-en.ipynb). Consequently, the online learning system provided by `braintrace` can effectively support online learning for RNN models.\n" ], "metadata": { "collapsed": false @@ -170,17 +172,17 @@ { "cell_type": "markdown", "source": [ - "## 2. RNN Models Supported by `brainscale`\n", + "## 2. RNN Models Supported by `braintrace`\n", "\n", - "`brainscale` does not support online learning for all RNN models. 
For instance, for the simplest Elman RNNs, the update rule is given by:\n", + "`braintrace` does not support online learning for all RNN models. For instance, for the simplest Elman RNNs, the update rule is given by:\n", "\n", "$$\n", "h_t = f(W_hh_{t-1} + W_xx_t + b_h)\n", "$$\n", "\n", - "`brainscale` does not facilitate online learning for this model. However, for more complex RNN models such as LSTM and GRU, `brainscale` is particularly well-suited for online learning. The primary reason is that the state updates in models like LSTM and GRU are implemented through gating mechanisms, which lead to rich intrinsic dynamics of state variables. As demonstrated in [our paper](https://doi.org/10.1101/2024.09.24.614728), the learning of temporal dependencies in `brainscale` is achieved through the dynamical updates of these element-wise operational state variables.\n", + "`braintrace` does not facilitate online learning for this model. However, for more complex RNN models such as LSTM and GRU, `braintrace` is particularly well-suited for online learning. The primary reason is that the state updates in models like LSTM and GRU are implemented through gating mechanisms, which lead to rich intrinsic dynamics of state variables. As demonstrated in [our paper](https://doi.org/10.1101/2024.09.24.614728), the learning of temporal dependencies in `braintrace` is achieved through the dynamical updates of these element-wise operational state variables.\n", "\n", - "Currently, we can utilize the `ParamDimVjpAlgorithm` for online learning of any RNN model. Let us illustrate how to use `brainscale` for online learning of RNNs with a simple example.\n" + "Currently, we can utilize the `ParamDimVjpAlgorithm` for online learning of any RNN model. Let us illustrate how to use `braintrace` for online learning of RNNs with a simple example.\n" ], "metadata": { "collapsed": false @@ -278,7 +280,7 @@ "source": [ "### 3.2 Definition of the GRU Model\n", "\n", - "We use a GRU model to address the copying task. We can define the GRU model using the `GRUCell` provided by `brainscale`. Below is a simple example of a GRU model definition:" + "We use a GRU model to address the copying task. We can define the GRU model using the `GRUCell` provided by `braintrace`. Below is a simple example of a GRU model definition:" ], "metadata": { "collapsed": false @@ -295,12 +297,12 @@ " # Building a GRU Multilayer Network\n", " layers = []\n", " for _ in range(n_layer):\n", - " layers.append(brainscale.nn.GRUCell(n_in, n_rec))\n", + " layers.append(braintrace.nn.GRUCell(n_in, n_rec))\n", " n_in = n_rec\n", " self.layer = brainstate.nn.Sequential(*layers)\n", "\n", " # Building the Output Layer\n", - " self.readout = brainscale.nn.Linear(n_rec, n_out)\n", + " self.readout = braintrace.nn.Linear(n_rec, n_out)\n", "\n", " def update(self, x):\n", " return self.readout(self.layer(x))" @@ -321,7 +323,7 @@ "source": [ "### 3.3 Online Learning\n", "\n", - "Next, we will first create an abstract `Trainer` class that enables model training for the copying task based on specified parameters. Then, we will implement two concrete trainers: `OnlineTrainer` and `BPTTTrainer`. The `OnlineTrainer` utilizes the online learning algorithm provided by `brainscale`, while the `BPTTTrainer` employs the Backpropagation Through Time (BPTT) algorithm for training." + "Next, we will first create an abstract `Trainer` class that enables model training for the copying task based on specified parameters. 
Then, we will implement two concrete trainers: `OnlineTrainer` and `BPTTTrainer`. The `OnlineTrainer` utilizes the online learning algorithm provided by `braintrace`, while the `BPTTTrainer` employs the Backpropagation Through Time (BPTT) algorithm for training." ], "metadata": { "collapsed": false @@ -412,7 +414,7 @@ "\n", " # Initialising an online learning model\n", " # Here, we need to use mode to specify that the dataset to be used is one with batch dimensions\n", - " model = brainscale.ParamDimVjpAlgorithm(self.target, mode=brainstate.mixin.Batching())\n", + " model = braintrace.ParamDimVjpAlgorithm(self.target, mode=brainstate.mixin.Batching())\n", "\n", " # Using a sample data compilation for online learning eligibility trace\n", " model.compile_graph(inputs[0])\n", @@ -467,7 +469,7 @@ "cell_type": "markdown", "source": [ "\n", - "In the code above, we define an `OnlineTrainer` class that inherits from the `Trainer` class. Within this class, we utilize the `brainscale.ParamDimVjpAlgorithm` for online learning. At each time step, we compute the gradient of the loss function and use a gradient descent algorithm to update the model parameters." + "In the code above, we define an `OnlineTrainer` class that inherits from the `Trainer` class. Within this class, we utilize the `braintrace.ParamDimVjpAlgorithm` for online learning. At each time step, we compute the gradient of the loss function and use a gradient descent algorithm to update the model parameters." ], "metadata": { "collapsed": false @@ -569,17 +571,17 @@ "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/1000 [00:00 0.5\n", " )\n", ")" @@ -207,12 +207,12 @@ "source": [ "# Apply a function to the weights\n", "\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(4, 5),\n", " 'bias': brainstate.random.rand(5)\n", " },\n", - " brainscale.MatMulOp(\n", + " braintrace.MatMulOp(\n", " weight_fn=jnp.abs # Ensures weights are positive\n", " )\n", ")" @@ -265,12 +265,12 @@ "cell_type": "code", "source": [ "# 同时使用掩码和权重函数\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(4, 5),\n", " 'bias': brainstate.random.rand(5)\n", " },\n", - " brainscale.MatMulOp(\n", + " braintrace.MatMulOp(\n", " weight_fn=jnp.abs,\n", " weight_mask=brainstate.random.rand(4, 5) > 0.5\n", " )\n", @@ -311,9 +311,9 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "### `brainscale.ConvOp`: The Convolution Operator\n", + "### `braintrace.ConvOp`: The Convolution Operator\n", "\n", - "The [`brainscale.ConvOp`](../apis/generated/brainscale.ConvOp.rst) provides general-purpose convolution operations suitable for models like CNNs.\n", + "The [`braintrace.ConvOp`](../apis/generated/braintrace.ConvOp.rst) provides general-purpose convolution operations suitable for models like CNNs.\n", "\n", "**Dimensionality Support**:\n", "A key feature of `ConvOp` is its ability to adapt to different dimensions. 
By specifying the `xinfo` parameter (a `jax.ShapeDtypeStruct` object), it can automatically infer and execute 1D, 2D, or 3D convolutions:\n", @@ -342,12 +342,12 @@ "cell_type": "code", "source": [ "# Example of a 2D convolution\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(3, 3),\n", " 'bias': jnp.zeros(16)\n", " },\n", - " brainscale.ConvOp(\n", + " braintrace.ConvOp(\n", " xinfo=jax.ShapeDtypeStruct((32, 3, 3), jnp.float32), # (height, width, channels)\n", " window_strides=[1, 1],\n", " padding='SAME',\n", @@ -417,12 +417,12 @@ }, "cell_type": "code", "source": [ - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(3, 3),\n", " 'bias': jnp.zeros(16)\n", " },\n", - " brainscale.ConvOp(\n", + " braintrace.ConvOp(\n", " xinfo=jax.ShapeDtypeStruct((32, 3, 3), jnp.float32),\n", " window_strides=[1, 1],\n", " padding='SAME',\n", @@ -490,12 +490,12 @@ "cell_type": "code", "source": [ "# 以2D卷积为例\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(3, 3),\n", " 'bias': jnp.zeros(16)\n", " },\n", - " brainscale.ConvOp(\n", + " braintrace.ConvOp(\n", " xinfo=jax.ShapeDtypeStruct((32, 3, 3), jnp.float32),\n", " window_strides=[1, 1],\n", " padding='SAME',\n", @@ -561,12 +561,12 @@ "cell_type": "code", "source": [ "# 以2D卷积为例\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(3, 3),\n", " 'bias': jnp.zeros(16)\n", " },\n", - " brainscale.ConvOp(\n", + " braintrace.ConvOp(\n", " xinfo=jax.ShapeDtypeStruct((32, 3, 3), jnp.float32),\n", " window_strides=(1, 1),\n", " padding='SAME',\n", @@ -616,11 +616,11 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "### `brainscale.SpMatMulOp`: The Sparse Matrix Multiplication Operator\n", + "### `braintrace.SpMatMulOp`: The Sparse Matrix Multiplication Operator\n", "\n", - "The [`brainscale.SpMatMulOp`](../apis/generated/brainscale.SpMatMulOp.rst) operator supports sparse matrix multiplication operations, suitable for scenarios like Graph Neural Networks (GNNs). It takes feature maps $x$ and parameters $w$ as input, and outputs the sparse matrix multiplication result $y$.\n", + "The [`braintrace.SpMatMulOp`](../apis/generated/braintrace.SpMatMulOp.rst) operator supports sparse matrix multiplication operations, suitable for scenarios like Graph Neural Networks (GNNs). It takes feature maps $x$ and parameters $w$ as input, and outputs the sparse matrix multiplication result $y$.\n", "\n", - "`brainscale.SpMatMulOp` performs similar operations to `brainscale.MatMulOp` for matrix multiplication:\n", + "`braintrace.SpMatMulOp` performs similar operations to `braintrace.MatMulOp` for matrix multiplication:\n", "\n", "$$y = x @ \\text{param['weight']} + \\text{param['bias']}$$\n", "\n", @@ -630,7 +630,7 @@ "- ``brainevent.CSC``: Compressed Sparse Column matrix.\n", "- ``brainevent.COO``: Coordinate Format sparse matrix.\n", "\n", - "`brainscale.SpMatMulOp` supports the following operations:\n", + "`braintrace.SpMatMulOp` supports the following operations:\n", "\n", "**1. 
Standard matrix multiplication**:\n", "\n", @@ -669,9 +669,9 @@ }, "cell_type": "code", "source": [ - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {'weight': brainstate.random.rand(100)},\n", - " brainscale.SpMatMulOp(csr)\n", + " braintrace.SpMatMulOp(csr)\n", ")" ], "id": "39d0558b3f089277", @@ -719,9 +719,9 @@ }, "cell_type": "code", "source": [ - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {'weight': brainstate.random.rand(100)},\n", - " brainscale.SpMatMulOp(csr, weight_fn=jnp.abs)\n", + " braintrace.SpMatMulOp(csr, weight_fn=jnp.abs)\n", ")" ], "id": "b596da0930657089", @@ -754,9 +754,9 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "### `brainscale.ElemWiseOp`: Element-wise Operation Operator\n", + "### `braintrace.ElemWiseOp`: Element-wise Operation Operator\n", "\n", - "[`brainscale.ElemWiseOp`](../apis/generated/brainscale.ElemWiseOp.rst) provides a concise way to apply element-wise function transformations to parameters. It doesn't directly process pre-synaptic input $x$, but operates directly on its own parameters $w$.\n", + "[`braintrace.ElemWiseOp`](../apis/generated/braintrace.ElemWiseOp.rst) provides a concise way to apply element-wise function transformations to parameters. It doesn't directly process pre-synaptic input $x$, but operates directly on its own parameters $w$.\n", "\n", "**Core Operation**:\n", "\n", @@ -777,9 +777,9 @@ }, "cell_type": "code", "source": [ - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " brainstate.random.rand(4),\n", - " brainscale.ElemWiseOp(jnp.abs) # Absolute value operation\n", + " braintrace.ElemWiseOp(jnp.abs) # Absolute value operation\n", ")" ], "id": "81c5e77aa9ed9030", @@ -814,9 +814,9 @@ }, "cell_type": "code", "source": [ - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " brainstate.random.rand(4),\n", - " brainscale.ElemWiseOp(jnp.exp) # Exponential operation\n", + " braintrace.ElemWiseOp(jnp.exp) # Exponential operation\n", ")" ], "id": "4bc5020e648a3bb2", @@ -853,9 +853,9 @@ "source": [ "# Using custom lambda function\n", "\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " brainstate.random.rand(4),\n", - " brainscale.ElemWiseOp(lambda x: x ** 2 + 1.) # Custom function\n", + " braintrace.ElemWiseOp(lambda x: x ** 2 + 1.) # Custom function\n", ")" ], "id": "c8c3c04eabe07a8", @@ -887,7 +887,7 @@ "source": [ "## Custom Eligibility Trace Operators\n", "\n", - "Although `brainscale` provides a comprehensive suite of built-in operators, research and applications often require exploration of novel neural network layers or synaptic plasticity rules. For this purpose, `brainscale` allows users to easily create custom operators by inheriting from the `brainscale.ETraceOp` base class.\n", + "Although `braintrace` provides a comprehensive suite of built-in operators, research and applications often require exploration of novel neural network layers or synaptic plasticity rules. 
For this purpose, `braintrace` allows users to easily create custom operators by inheriting from the `braintrace.ETraceOp` base class.\n", "\n", "Customizing an operator involves understanding and implementing two core methods: `xw_to_y` and `yw_to_w`.\n", "\n", @@ -928,7 +928,7 @@ }, "cell_type": "code", "source": [ - "class CustomizedMatMul(brainscale.ETraceOp):\n", + "class CustomizedMatMul(braintrace.ETraceOp):\n", " \"\"\"\n", " A custom matrix multiplication eligibility trace operator.\n", " It implements the computation y = x @ w['weight'] + w['bias'].\n", @@ -971,7 +971,7 @@ "\n", "### Using Custom Operators\n", "\n", - "Once defined, `CustomizedMatMul` can be used like any built-in operator, combined with `ETraceParam`, and seamlessly integrated into `brainscale`'s computational graph.\n" + "Once defined, `CustomizedMatMul` can be used like any built-in operator, combined with `ETraceParam`, and seamlessly integrated into `braintrace`'s computational graph.\n" ], "id": "3dc38faec8488857" }, @@ -988,7 +988,7 @@ "my_op = CustomizedMatMul()\n", "\n", "# 2. Use ETraceParam to associate operator with specific parameters\n", - "param = brainscale.ETraceParam(\n", + "param = braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(4, 5), # D_in=4, D_out=5\n", " 'bias': brainstate.random.rand(5)\n", @@ -1000,7 +1000,7 @@ "# Create some mock input data\n", "dummy_input = brainstate.random.rand(1, 4) # Batch=1, D_in=4\n", "\n", - "# brainscale's runner will automatically call op.xw_to_y(dummy_input, param.value)\n", + "# braintrace's runner will automatically call op.xw_to_y(dummy_input, param.value)\n", "# We can manually call to verify\n", "output = my_op.xw_to_y(dummy_input, param.value)\n", "\n", diff --git a/docs/tutorial/etraceop-zh.ipynb b/docs/tutorial/etraceop-zh.ipynb index ccbd69c..08f284d 100644 --- a/docs/tutorial/etraceop-zh.ipynb +++ b/docs/tutorial/etraceop-zh.ipynb @@ -10,7 +10,7 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "在`brainscale`框架中,资格迹算子 (`ETraceOp`) 扮演着连接神经网络中神经元群体、定义突触交互的核心角色。它的主要职责是根据模型的输入(突触前活动)和参数(如突触权重),精确计算出突触后电流。更重要的是,`ETraceOp` 原生支持基于资格迹(Eligibility Trace)的学习机制,这是一种模拟生物神经系统中时间信用分配(temporal credit assignment)的关键过程,使得模型能够根据延迟的奖励或误差信号来更新连接权重。\n", + "在`braintrace`框架中,资格迹算子 (`ETraceOp`) 扮演着连接神经网络中神经元群体、定义突触交互的核心角色。它的主要职责是根据模型的输入(突触前活动)和参数(如突触权重),精确计算出突触后电流。更重要的是,`ETraceOp` 原生支持基于资格迹(Eligibility Trace)的学习机制,这是一种模拟生物神经系统中时间信用分配(temporal credit assignment)的关键过程,使得模型能够根据延迟的奖励或误差信号来更新连接权重。\n", "\n", "`ETraceOp` 的设计哲学是将计算逻辑(算子本身)与可训练参数(`ETraceParam`)解耦,从而提供了极大的灵活性和可扩展性。" ], @@ -30,7 +30,7 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "import brainscale" + "import braintrace" ], "id": "9ac44f0bb1d51cd9", "outputs": [], @@ -42,18 +42,18 @@ "source": [ "## 内置的资格迹算子\n", "\n", - "`brainscale` 提供了一系列功能强大且预先配置好的资格迹算子,能够满足绝大多数常见的神经网络建模需求。这些算子与模型参数容器 `brainscale.ETraceParam` 配合使用,构成了神经网络的构建模块。\n", + "`braintrace` 提供了一系列功能强大且预先配置好的资格迹算子,能够满足绝大多数常见的神经网络建模需求。这些算子与模型参数容器 `braintrace.ETraceParam` 配合使用,构成了神经网络的构建模块。\n", "\n", "主要内置算子包括:\n", "\n", "\n", - "- [`brainscale.MatMulOp`](../apis/generated/brainscale.MatMulOp.rst): 实现标准的矩阵乘法,是构建全连接层(Dense Layer)的基础。\n", - "- [`brainscale.ConvOp`](../apis/generated/brainscale.ConvOp.rst): 实现卷积操作,支持1D、2D和3D卷积,是构建卷积神经网络(CNN)的核心。\n", - "- [`brainscale.SpMatMulOp`](../apis/generated/brainscale.SpMatMulOp.rst): 专为稀疏连接设计,实现了稀疏矩阵乘法,在图神经网络(GNN)和需要高效表示大规模稀疏连接的生物可塑性模型中尤为重要。\n", - "- [`brainscale.ElemWiseOp`](../apis/generated/brainscale.ElemWiseOp.rst): 执行元素级别的数学运算,常用于实现激活函数、缩放或其他自定义的逐元素变换。\n", - "- 
[`brainscale.LoraOp`](../apis/generated/brainscale.LoraOp.rst): 实现低秩适应(Low-Rank Adaptation)技术,这是一种高效微调大型预训练模型的方法。\n", + "- [`braintrace.MatMulOp`](../apis/generated/braintrace.MatMulOp.rst): 实现标准的矩阵乘法,是构建全连接层(Dense Layer)的基础。\n", + "- [`braintrace.ConvOp`](../apis/generated/braintrace.ConvOp.rst): 实现卷积操作,支持1D、2D和3D卷积,是构建卷积神经网络(CNN)的核心。\n", + "- [`braintrace.SpMatMulOp`](../apis/generated/braintrace.SpMatMulOp.rst): 专为稀疏连接设计,实现了稀疏矩阵乘法,在图神经网络(GNN)和需要高效表示大规模稀疏连接的生物可塑性模型中尤为重要。\n", + "- [`braintrace.ElemWiseOp`](../apis/generated/braintrace.ElemWiseOp.rst): 执行元素级别的数学运算,常用于实现激活函数、缩放或其他自定义的逐元素变换。\n", + "- [`braintrace.LoraOp`](../apis/generated/braintrace.LoraOp.rst): 实现低秩适应(Low-Rank Adaptation)技术,这是一种高效微调大型预训练模型的方法。\n", "\n", - "这些资格迹算子通常需要配合模型参数`brainscale.ETraceParam`一起使用。" + "这些资格迹算子通常需要配合模型参数`braintrace.ETraceParam`一起使用。" ], "id": "de5e69fd2fd011fc" }, @@ -61,9 +61,9 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "### `brainscale.MatMulOp` 矩阵乘法算子\n", + "### `braintrace.MatMulOp` 矩阵乘法算子\n", "\n", - "[`brainscale.MatMulOp`](../apis/generated/brainscale.MatMulOp.rst) 是最基础的算子,支持矩阵乘法操作,适用于全连接层等场景。\n", + "[`braintrace.MatMulOp`](../apis/generated/braintrace.MatMulOp.rst) 是最基础的算子,支持矩阵乘法操作,适用于全连接层等场景。\n", "\n", "**基本操作**:\n", "- 输入:矩阵 $x \\in \\mathbb{R}^{B \\times D_{in}}$\n", @@ -89,12 +89,12 @@ "source": [ "# 标准矩阵乘法\n", "\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(4, 5),\n", " 'bias': brainstate.random.rand(5)\n", " },\n", - " brainscale.MatMulOp()\n", + " braintrace.MatMulOp()\n", ")" ], "id": "7cfcfd6102942c5c", @@ -146,12 +146,12 @@ "source": [ "# 带掩码的矩阵乘法(实现稀疏连接)\n", "\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(4, 5),\n", " 'bias': brainstate.random.rand(5)\n", " },\n", - " brainscale.MatMulOp(\n", + " braintrace.MatMulOp(\n", " weight_mask=brainstate.random.rand(4, 5) > 0.5\n", " )\n", ")" @@ -208,12 +208,12 @@ "source": [ "# 对权重应用函数变换\n", "\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(4, 5),\n", " 'bias': brainstate.random.rand(5)\n", " },\n", - " brainscale.MatMulOp(\n", + " braintrace.MatMulOp(\n", " weight_fn=jnp.abs # 确保权重为正\n", " )\n", ")" @@ -266,12 +266,12 @@ "cell_type": "code", "source": [ "# 同时使用掩码和权重函数\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(4, 5),\n", " 'bias': brainstate.random.rand(5)\n", " },\n", - " brainscale.MatMulOp(\n", + " braintrace.MatMulOp(\n", " weight_fn=jnp.abs,\n", " weight_mask=brainstate.random.rand(4, 5) > 0.5\n", " )\n", @@ -312,9 +312,9 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "### ``brainscale.ConvOp`` 卷积算子\n", + "### ``braintrace.ConvOp`` 卷积算子\n", "\n", - "[`brainscale.ConvOp`](../apis/generated/brainscale.ConvOp.rst) 算子支持一般性的卷积操作,适用于卷积神经网络(CNN)等场景。它的输入是特征图$x$和参数$w$,输出是卷积结果$y$。\n", + "[`braintrace.ConvOp`](../apis/generated/braintrace.ConvOp.rst) 算子支持一般性的卷积操作,适用于卷积神经网络(CNN)等场景。它的输入是特征图$x$和参数$w$,输出是卷积结果$y$。\n", "\n", "- 输入$x$是一个矩阵。\n", "- 参数$w$是一个字典,涵盖了权重矩阵字段``weight``和偏置向量字段``bias``。这个算子可以用于实现全连接层的前向传播。\n", @@ -322,7 +322,7 @@ "\n", "**维度支持**:\n", "\n", - "`brainscale.ConvOp`支持1D、2D、3D卷积等多种形式的卷积操作。通过 `xinfo` 参数(一个`jax.ShapeDtypeStruct`对象),它可以自动推断并执行1D、2D或3D卷积。比如,\n", + "`braintrace.ConvOp`支持1D、2D、3D卷积等多种形式的卷积操作。通过 `xinfo` 参数(一个`jax.ShapeDtypeStruct`对象),它可以自动推断并执行1D、2D或3D卷积。比如,\n", "\n", "- **1D卷积**:当 `xinfo=jax.ShapeDtypeStruct((32, 3), jnp.float32)` 时,表示输入是一个形状为 `(32, 
3)` 的2维张量(通道数为3,长度均为32),此时卷积是1D卷积。\n", "- **2D卷积**:当 `xinfo=jax.ShapeDtypeStruct((32, 32, 3), jnp.float32)` 时,表示输入是一个形状为 `(32, 32, 3)` 的3维张量(通道数为3,高度和宽度均为32),此时卷积是2D卷积。\n", @@ -353,12 +353,12 @@ "cell_type": "code", "source": [ "# 以2D卷积为例\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(3, 3),\n", " 'bias': jnp.zeros(16)\n", " },\n", - " brainscale.ConvOp(\n", + " braintrace.ConvOp(\n", " xinfo=jax.ShapeDtypeStruct((32, 3, 3), jnp.float32), # (height, width, channels)\n", " window_strides=[1, 1],\n", " padding='SAME',\n", @@ -423,12 +423,12 @@ "cell_type": "code", "source": [ "# 以2D卷积为例\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(3, 3),\n", " 'bias': jnp.zeros(16)\n", " },\n", - " brainscale.ConvOp(\n", + " braintrace.ConvOp(\n", " xinfo=jax.ShapeDtypeStruct((32, 3, 3), jnp.float32),\n", " window_strides=[1, 1],\n", " padding='SAME',\n", @@ -496,12 +496,12 @@ "cell_type": "code", "source": [ "# 以2D卷积为例\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(3, 3),\n", " 'bias': jnp.zeros(16)\n", " },\n", - " brainscale.ConvOp(\n", + " braintrace.ConvOp(\n", " xinfo=jax.ShapeDtypeStruct((32, 3, 3), jnp.float32),\n", " window_strides=[1, 1],\n", " padding='SAME',\n", @@ -567,12 +567,12 @@ "cell_type": "code", "source": [ "# 以2D卷积为例\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(3, 3),\n", " 'bias': jnp.zeros(16)\n", " },\n", - " brainscale.ConvOp(\n", + " braintrace.ConvOp(\n", " xinfo=jax.ShapeDtypeStruct((32, 3, 3), jnp.float32),\n", " window_strides=[1, 1],\n", " padding='SAME',\n", @@ -625,11 +625,11 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "### ``brainscale.SpMatMulOp`` 稀疏矩阵乘法算子\n", + "### ``braintrace.SpMatMulOp`` 稀疏矩阵乘法算子\n", "\n", - "[`brainscale.SpMatMulOp`](../apis/generated/brainscale.SpMatMulOp.rst) 算子支持稀疏矩阵乘法操作,适用于图神经网络(GNN)等场景。它的输入是特征图$x$和参数$w$,输出是稀疏矩阵乘法结果$y$。\n", + "[`braintrace.SpMatMulOp`](../apis/generated/braintrace.SpMatMulOp.rst) 算子支持稀疏矩阵乘法操作,适用于图神经网络(GNN)等场景。它的输入是特征图$x$和参数$w$,输出是稀疏矩阵乘法结果$y$。\n", "\n", - "`brainscale.SpMatMulOp` 与 `brainscale.MatMulOp` 做类似的操作,用于进行矩阵乘法操作:\n", + "`braintrace.SpMatMulOp` 与 `braintrace.MatMulOp` 做类似的操作,用于进行矩阵乘法操作:\n", "\n", "$$\n", "y = x @ \\mathrm{param['weight']} + \\mathrm{param['bias']}\n", @@ -647,7 +647,7 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "`brainscale.SpMatMulOp`支持如下操作:\n", + "`braintrace.SpMatMulOp`支持如下操作:\n", "\n", "**1. 
标准矩阵乘法:**\n", "\n", @@ -684,9 +684,9 @@ }, "cell_type": "code", "source": [ - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {'weight': brainstate.random.rand(100)},\n", - " brainscale.SpMatMulOp(csr)\n", + " braintrace.SpMatMulOp(csr)\n", ")" ], "id": "39d0558b3f089277", @@ -734,9 +734,9 @@ }, "cell_type": "code", "source": [ - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " {'weight': brainstate.random.rand(100)},\n", - " brainscale.SpMatMulOp(csr, weight_fn=jnp.abs)\n", + " braintrace.SpMatMulOp(csr, weight_fn=jnp.abs)\n", ")" ], "id": "b596da0930657089", @@ -769,9 +769,9 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "### ``brainscale.ElemWiseOp`` 元素级操作算子\n", + "### ``braintrace.ElemWiseOp`` 元素级操作算子\n", "\n", - "[`brainscale.ElemWiseOp`](../apis/generated/brainscale.ElemWiseOp.rst) 提供了一种简洁的方式来对参数进行逐元素的函数变换。它不直接处理来自突触前的输入 $x$,而是直接作用于其自身的参数 $w$。\n", + "[`braintrace.ElemWiseOp`](../apis/generated/braintrace.ElemWiseOp.rst) 提供了一种简洁的方式来对参数进行逐元素的函数变换。它不直接处理来自突触前的输入 $x$,而是直接作用于其自身的参数 $w$。\n", "\n", "**核心运算**:\n", "$$y = f(w)$$\n", @@ -790,9 +790,9 @@ }, "cell_type": "code", "source": [ - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " brainstate.random.rand(4),\n", - " brainscale.ElemWiseOp(jnp.abs) # 绝对值操作\n", + " braintrace.ElemWiseOp(jnp.abs) # 绝对值操作\n", ")" ], "id": "81c5e77aa9ed9030", @@ -827,9 +827,9 @@ }, "cell_type": "code", "source": [ - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " brainstate.random.rand(4),\n", - " brainscale.ElemWiseOp(jnp.exp) # 指数操作\n", + " braintrace.ElemWiseOp(jnp.exp) # 指数操作\n", ")" ], "id": "4bc5020e648a3bb2", @@ -866,9 +866,9 @@ "source": [ "# 使用自定义的lambda函数\n", "\n", - "brainscale.ETraceParam(\n", + "braintrace.ETraceParam(\n", " brainstate.random.rand(4),\n", - " brainscale.ElemWiseOp(lambda x: x ** 2 + 1.) # 自定义函数\n", + " braintrace.ElemWiseOp(lambda x: x ** 2 + 1.) # 自定义函数\n", ")" ], "id": "c8c3c04eabe07a8", @@ -908,7 +908,7 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "尽管`brainscale`提供了一套完备的内置算子,但研究和应用中常常需要探索新颖的神经网络层或突触可塑性规则。为此,`brainscale`允许用户通过继承`brainscale.ETraceOp`基类来轻松创建自定义算子。\n", + "尽管`braintrace`提供了一套完备的内置算子,但研究和应用中常常需要探索新颖的神经网络层或突触可塑性规则。为此,`braintrace`允许用户通过继承`braintrace.ETraceOp`基类来轻松创建自定义算子。\n", "\n", "自定义一个算子,关键在于理解并实现两个核心方法:`xw_to_y`和`yw_to_w`。\n", "\n", @@ -949,7 +949,7 @@ }, "cell_type": "code", "source": [ - "class CustomizedMatMul(brainscale.ETraceOp):\n", + "class CustomizedMatMul(braintrace.ETraceOp):\n", " \"\"\"\n", " 一个自定义的矩阵乘法资格迹算子。\n", " 它实现了 y = x @ w['weight'] + w['bias'] 的计算。\n", @@ -993,7 +993,7 @@ "\n", "### 使用自定义算子\n", "\n", - "定义好之后,`CustomizedMatMul`可以像任何内置算子一样,与`ETraceParam`结合使用,并无缝集成到`brainscale`的计算图中。" + "定义好之后,`CustomizedMatMul`可以像任何内置算子一样,与`ETraceParam`结合使用,并无缝集成到`braintrace`的计算图中。" ], "id": "3dc38faec8488857" }, @@ -1010,7 +1010,7 @@ "my_op = CustomizedMatMul()\n", "\n", "# 2. 
使用ETraceParam将算子与具体参数关联\n", - "param = brainscale.ETraceParam(\n", + "param = braintrace.ETraceParam(\n", " {\n", " 'weight': brainstate.random.rand(4, 5), # D_in=4, D_out=5\n", " 'bias': brainstate.random.rand(5)\n", @@ -1022,7 +1022,7 @@ "# 创建一些模拟的输入数据\n", "dummy_input = brainstate.random.rand(1, 4) # Batch=1, D_in=4\n", "\n", - "# brainscale的运行器会自动调用 op.xw_to_y(dummy_input, param.value)\n", + "# braintrace的运行器会自动调用 op.xw_to_y(dummy_input, param.value)\n", "# 我们可以手动调用来验证\n", "output = my_op.xw_to_y(dummy_input, param.value)\n", "\n", diff --git a/docs/tutorial/etracestate-en.ipynb b/docs/tutorial/etracestate-en.ipynb index dd553ad..13c7cf8 100644 --- a/docs/tutorial/etracestate-en.ipynb +++ b/docs/tutorial/etracestate-en.ipynb @@ -6,7 +6,7 @@ "source": [ "# `ETraceState`: Online Learning State Management\n", "\n", - "In the `brainscale` framework, the `ETraceState` class family provides powerful state management functionality specifically designed for implementing **eligibility trace-based online learning mechanisms**. Eligibility traces are important concepts in reinforcement learning and neural network training, allowing systems to track and update the historical activity of neurons and synapses, thereby enabling more efficient learning algorithms.\n", + "In the `braintrace` framework, the `ETraceState` class family provides powerful state management functionality specifically designed for implementing **eligibility trace-based online learning mechanisms**. Eligibility traces are important concepts in reinforcement learning and neural network training, allowing systems to track and update the historical activity of neurons and synapses, thereby enabling more efficient learning algorithms.\n", "\n", "## Core Features\n", "\n", @@ -28,10 +28,11 @@ } }, "source": [ - "import brainscale\n", + "import braintrace\n", "import brainstate\n", "import brainunit as u\n", - "import jax.numpy as jnp" + "import jax.numpy as jnp\n", + "import brainpy" ], "outputs": [], "execution_count": 1 @@ -196,13 +197,13 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "## `brainscale.ETraceGroupState` Class: Group State Management\n", + "## `braintrace.ETraceGroupState` Class: Group State Management\n", "\n", - "The [`brainscale.ETraceGroupState`](../apis/generated/brainscale.ETraceGroupState.rst) class is specifically designed for defining multiple states of neuron or synapse populations. It is a subclass of the `brainstate.HiddenState` class, inheriting all its attributes and methods.\n", + "The [`braintrace.ETraceGroupState`](../apis/generated/braintrace.ETraceGroupState.rst) class is specifically designed for defining multiple states of neuron or synapse populations. It is a subclass of the `brainstate.HiddenState` class, inheriting all its attributes and methods.\n", "\n", - "In multi-compartment neuron models, each variable represents the state of multiple compartments, such as membrane potential. If each compartment's membrane potential were defined using a separate `brainstate.HiddenState` class, then multiple state variables would need to be defined in multi-compartment neuron models, leading to verbose and difficult-to-maintain code. However, using the `brainscale.ETraceGroupState` class allows multiple state variables to be combined together, simplifying code structure.\n", + "In multi-compartment neuron models, each variable represents the state of multiple compartments, such as membrane potential. 
If each compartment's membrane potential were defined using a separate `brainstate.HiddenState` class, then multiple state variables would need to be defined in multi-compartment neuron models, leading to verbose and difficult-to-maintain code. However, using the `braintrace.ETraceGroupState` class allows multiple state variables to be combined together, simplifying code structure.\n", "\n", - "In the following example, we will use the `brainscale.ETraceGroupState` class to define state variables for a three-compartment neuron model.\n", + "In the following example, we will use the `braintrace.ETraceGroupState` class to define state variables for a three-compartment neuron model.\n", "\n", "### Multi-Compartment Neuron Modeling" ], @@ -274,7 +275,7 @@ " super().__init__(pop_size, morphology=morphology)\n", "\n", " def init_state(self, *args, **kwargs):\n", - " self.V = brainscale.ETraceGroupState(jnp.zeros(self.varshape) * u.mV)" + " self.V = braintrace.ETraceGroupState(jnp.zeros(self.varshape) * u.mV)" ], "id": "8aa6a12a956ff868", "outputs": [], @@ -299,7 +300,7 @@ { "metadata": {}, "cell_type": "markdown", - "source": "Each `brainscale.ETraceGroupState` instance represents a state variable containing state information for multiple compartments. In this example, we only defined the membrane potential $V$ state variable, but actually more state variables can be defined, such as adaptation currents $I_j$, etc. Each state variable can contain state information for multiple compartments, making it usable in multi-compartment neuron models.", + "source": "Each `braintrace.ETraceGroupState` instance represents a state variable containing state information for multiple compartments. In this example, we only defined the membrane potential $V$ state variable, but actually more state variables can be defined, such as adaptation currents $I_j$, etc. Each state variable can contain state information for multiple compartments, making it usable in multi-compartment neuron models.", "id": "e1db746ac67a9c1c" }, { @@ -345,11 +346,11 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "## `brainscale.ETraceTreeState` Class: Tree Structure State\n", + "## `braintrace.ETraceTreeState` Class: Tree Structure State\n", "\n", - "[`brainscale.ETraceTreeState`](../apis/generated/brainscale.ETraceTreeState.rst) provides the most flexible state management solution, supporting **PyTree tree structures**, suitable for neural network models with complex hierarchical relationships. It is a subclass of the `brainstate.HiddenState` class, inheriting all its attributes and methods.\n", + "[`braintrace.ETraceTreeState`](../apis/generated/braintrace.ETraceTreeState.rst) provides the most flexible state management solution, supporting **PyTree tree structures**, suitable for neural network models with complex hierarchical relationships. 
It is a subclass of the `brainstate.HiddenState` class, inheriting all its attributes and methods.\n", "\n", - "The following uses the GIF model as an example to demonstrate how to use the `brainscale.ETraceTreeState` class to define tree-structured state variables.\n", + "The following uses the GIF model as an example to demonstrate how to use the `braintrace.ETraceTreeState` class to define tree-structured state variables.\n", "\n", "### Advanced Application Example" ], @@ -366,7 +367,7 @@ "source": [ "class GIF_tree(brainpy.state.Neuron):\n", " def init_state(self, *args, **kwargs):\n", - " self.state = brainscale.ETraceTreeState(\n", + " self.state = braintrace.ETraceTreeState(\n", " {\n", " 'I1': jnp.zeros(self.varshape) * u.mA,\n", " 'I2': jnp.zeros(self.varshape) * u.mA,\n", @@ -398,7 +399,7 @@ { "metadata": {}, "cell_type": "markdown", - "source": "Each `brainscale.ETraceTreeState` instance represents a tree-structured state variable containing multiple sub-state variables. In this example, we defined four state variables: $I_1$, $I_2$, $V$, and $V_{th}$, which are organized into a tree structure.", + "source": "Each `braintrace.ETraceTreeState` instance represents a tree-structured state variable containing multiple sub-state variables. In this example, we defined four state variables: $I_1$, $I_2$, $V$, and $V_{th}$, which are organized into a tree structure.", "id": "14a7e2628981ecfb" }, { @@ -441,7 +442,7 @@ "source": [ "## Summary\n", "\n", - "The `ETraceState` class family in `brainscale` provides powerful and flexible state management solutions for neural network modeling:\n", + "The `ETraceState` class family in `braintrace` provides powerful and flexible state management solutions for neural network modeling:\n", "\n", "| Type | Applicable Scenarios | Advantages | Typical Applications |\n", "|------|---------------------|------------|---------------------|\n", diff --git a/docs/tutorial/etracestate-zh.ipynb b/docs/tutorial/etracestate-zh.ipynb index 3ab24b8..35b6baf 100644 --- a/docs/tutorial/etracestate-zh.ipynb +++ b/docs/tutorial/etracestate-zh.ipynb @@ -7,7 +7,7 @@ "# `ETraceState`: 在线学习状态管理\n", "\n", "\n", - "在 `brainscale` 框架中,`ETraceState` 类系列提供了强大的状态管理功能,专门用于实现**资格迹(Eligibility Trace)在线学习机制**。资格迹是强化学习和神经网络训练中的重要概念,它允许系统追踪和更新神经元及突触的历史活动,从而实现更高效的学习算法。" + "在 `braintrace` 框架中,`ETraceState` 类系列提供了强大的状态管理功能,专门用于实现**资格迹(Eligibility Trace)在线学习机制**。资格迹是强化学习和神经网络训练中的重要概念,它允许系统追踪和更新神经元及突触的历史活动,从而实现更高效的学习算法。" ], "id": "6908d888efacaf39" }, @@ -38,8 +38,8 @@ "import brainstate\n", "import brainunit as u\n", "import jax.numpy as jnp\n", - "\n", - "import brainscale" + "import brainpy\n", + "import braintrace" ], "outputs": [], "execution_count": 21 @@ -206,13 +206,13 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "## `brainscale.ETraceGroupState` 类:群组状态管理\n", + "## `braintrace.ETraceGroupState` 类:群组状态管理\n", "\n", - "[`brainscale.ETraceGroupState`](../apis/generated/brainscale.ETraceGroupState.rst) 类专门用于定义神经元或突触群体的多个状态。它是`brainstate.HiddenState`类的一个子类,继承了其所有属性和方法。\n", + "[`braintrace.ETraceGroupState`](../apis/generated/braintrace.ETraceGroupState.rst) 类专门用于定义神经元或突触群体的多个状态。它是`brainstate.HiddenState`类的一个子类,继承了其所有属性和方法。\n", "\n", - "在多房室神经元模型(multi-compartment neuron model)中,每一个变量表示多个房室的状态,比如膜电位。如果将每个房室的膜电位使用一个`brainstate.HiddenState`类来定义,那么在多房室神经元模型中就需要定义多个状态变量,这样会导致代码冗长且难以维护。然而,使用`brainscale.ETraceGroupState`类可以将多个状态变量组合在一起,简化代码结构。\n", + "在多房室神经元模型(multi-compartment neuron 
model)中,每一个变量表示多个房室的状态,比如膜电位。如果将每个房室的膜电位使用一个`brainstate.HiddenState`类来定义,那么在多房室神经元模型中就需要定义多个状态变量,这样会导致代码冗长且难以维护。然而,使用`braintrace.ETraceGroupState`类可以将多个状态变量组合在一起,简化代码结构。\n", "\n", - "在以下示例中,我们将使用`brainscale.ETraceGroupState`类来定义一个三房室神经元模型的状态变量。\n", + "在以下示例中,我们将使用`braintrace.ETraceGroupState`类来定义一个三房室神经元模型的状态变量。\n", "\n", "### 多房室神经元建模" ], @@ -284,7 +284,7 @@ " super().__init__(pop_size, morphology=morphology)\n", "\n", " def init_state(self, *args, **kwargs):\n", - " self.V = brainscale.ETraceGroupState(jnp.zeros(self.varshape) * u.mV)" + " self.V = braintrace.ETraceGroupState(jnp.zeros(self.varshape) * u.mV)" ], "id": "d3e4d3e813a3d362", "outputs": [], @@ -309,7 +309,7 @@ { "metadata": {}, "cell_type": "markdown", - "source": "每个`brainscale.ETraceGroupState`实例都代表一个状态变量,包含了多个房室的状态信息。在这个例子中,我们只定义了膜电位$V$的状态变量,但实际上可以定义更多的状态变量,比如适应性电流$I_j$等。每个状态变量可以包含多个房室的状态信息,这样就可以在多房室神经元模型中使用。", + "source": "每个`braintrace.ETraceGroupState`实例都代表一个状态变量,包含了多个房室的状态信息。在这个例子中,我们只定义了膜电位$V$的状态变量,但实际上可以定义更多的状态变量,比如适应性电流$I_j$等。每个状态变量可以包含多个房室的状态信息,这样就可以在多房室神经元模型中使用。", "id": "4c0e08a7f685bf" }, { @@ -355,13 +355,13 @@ "metadata": {}, "cell_type": "markdown", "source": [ - "## `brainscale.ETraceTreeState` 类:树状结构状态\n", + "## `braintrace.ETraceTreeState` 类:树状结构状态\n", "\n", "\n", - "[`brainscale.ETraceTreeState`](../apis/generated/brainscale.ETraceTreeState.rst) 提供了最灵活的状态管理方案,支持 **PyTree 树状结构**,适用于具有复杂层次关系的神经网络模型。\n", + "[`braintrace.ETraceTreeState`](../apis/generated/braintrace.ETraceTreeState.rst) 提供了最灵活的状态管理方案,支持 **PyTree 树状结构**,适用于具有复杂层次关系的神经网络模型。\n", "它是`brainstate.HiddenState`类的一个子类,继承了其所有属性和方法。\n", "\n", - "以下以GIF模型为例,展示如何使用`brainscale.ETraceTreeState`类来定义一个树状结构的状态变量。\n", + "以下以GIF模型为例,展示如何使用`braintrace.ETraceTreeState`类来定义一个树状结构的状态变量。\n", "\n", "### 高级应用示例" ], @@ -378,7 +378,7 @@ "source": [ "class GIF_tree(brainpy.state.Neuron):\n", " def init_state(self, *args, **kwargs):\n", - " self.state = brainscale.ETraceTreeState(\n", + " self.state = braintrace.ETraceTreeState(\n", " {\n", " 'I1': jnp.zeros(self.varshape) * u.mA,\n", " 'I2': jnp.zeros(self.varshape) * u.mA,\n", @@ -410,7 +410,7 @@ { "metadata": {}, "cell_type": "markdown", - "source": "每个`brainscale.ETraceTreeState`实例都代表一个树状结构的状态变量,包含了多个子状态变量。在这个例子中,我们定义了$I_1$、$I_2$、$V$和$V_{th}$四个状态变量,它们被组织成一个树状结构。", + "source": "每个`braintrace.ETraceTreeState`实例都代表一个树状结构的状态变量,包含了多个子状态变量。在这个例子中,我们定义了$I_1$、$I_2$、$V$和$V_{th}$四个状态变量,它们被组织成一个树状结构。", "id": "570d6a5aedd18ec4" }, { @@ -453,7 +453,7 @@ "source": [ "## 总结\n", "\n", - "`brainscale` 的 `ETraceState` 类系列为神经网络建模提供了强大而灵活的状态管理解决方案:\n", + "`braintrace` 的 `ETraceState` 类系列为神经网络建模提供了强大而灵活的状态管理解决方案:\n", "\n", "| 类型 | 适用场景 | 优势 | 典型应用 |\n", "|------|----------|------|----------|\n", diff --git a/docs/tutorial/show_graph-en.ipynb b/docs/tutorial/show_graph-en.ipynb index b30212d..19e23fb 100644 --- a/docs/tutorial/show_graph-en.ipynb +++ b/docs/tutorial/show_graph-en.ipynb @@ -9,7 +9,7 @@ { "metadata": {}, "cell_type": "markdown", - "source": "`brainscale` uses intermediate representation (IR) analysis to extract the dependencies between neuron states, synaptic connections, and model parameters. By calling the `.show_graph()` method, users can visualize the compiled computation graph, providing deeper insights into the computational structure and interdependencies within the neural model.", + "source": "`braintrace` uses intermediate representation (IR) analysis to extract the dependencies between neuron states, synaptic connections, and model parameters. 
By calling the `.show_graph()` method, users can visualize the compiled computation graph, providing deeper insights into the computational structure and interdependencies within the neural model.", "id": "f653b562490f5dad" }, { @@ -23,8 +23,10 @@ "source": [ "import brainstate\n", "import brainunit as u\n", - "import brainscale\n", - "import jax" + "import braintrace\n", + "import jax\n", + "import brainpy\n", + "import braintools" ], "id": "d1b039ee6f7c9256", "outputs": [], @@ -68,14 +70,14 @@ " ff_init = braintools.init.KaimingNormal(ff_scale, unit=u.mV)\n", " w_init = u.math.concatenate([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0)\n", " self.syn = brainpy.state.DeltaProj(\n", - " comm=brainscale.nn.Linear(\n", + " comm=braintrace.nn.Linear(\n", " n_in + n_rec, n_rec,\n", " w_init=w_init,\n", " b_init=braintools.init.ZeroInit(unit=u.mV)\n", " ),\n", " post=self.neu\n", " )\n", - " self.out = brainscale.nn.LeakyRateReadout(\n", + " self.out = braintrace.nn.LeakyRateReadout(\n", " in_size=n_rec,\n", " out_size=n_out,\n", " tau=tau_o,\n", @@ -102,7 +104,7 @@ "with brainstate.environ.context(dt=0.1 * u.ms):\n", " net = LIF_Delta_Net(n_in=10, n_rec=20, n_out=5)\n", " brainstate.nn.init_all_states(net)\n", - " model = brainscale.D_RTRL(net)\n", + " model = braintrace.D_RTRL(net)\n", " model.compile_graph(brainstate.random.rand(10))\n", " model.show_graph()" ], @@ -231,7 +233,7 @@ " self.n_rec = n_rec\n", "\n", " # 模型层\n", - " self.ir2r = brainscale.nn.Linear(n_in + n_rec, n_rec, w_init=w, b_init=braintools.init.ZeroInit(unit=u.mA))\n", + " self.ir2r = braintrace.nn.Linear(n_in + n_rec, n_rec, w_init=w, b_init=braintools.init.ZeroInit(unit=u.mA))\n", " self.exp = brainpy.state.Expon(n_rec, tau=tau_syn, g_initializer=braintools.init.ZeroInit(unit=u.mA))\n", " self.r = GIF(\n", " n_rec,\n", @@ -262,7 +264,7 @@ " assert n > 0, \"n_rec should be a list of positive integers.\"\n", " self.layers.append(GifLayer(n_in, n))\n", " n_in = n\n", - " self.out = brainscale.nn.LeakyRateReadout(n_in, n_out, tau=tau_o, w_init=braintools.init.KaimingNormal())\n", + " self.out = braintrace.nn.LeakyRateReadout(n_in, n_out, tau=tau_o, w_init=braintools.init.KaimingNormal())\n", "\n", " def update(self, x):\n", " for layer in self.layers:\n", @@ -285,7 +287,7 @@ "with brainstate.environ.context(dt=0.1 * u.ms):\n", " net2 = GifNet(n_in=10, n_rec=[20, 20, 20], n_out=5)\n", " brainstate.nn.init_all_states(net2)\n", - " model = brainscale.D_RTRL(net2)\n", + " model = braintrace.D_RTRL(net2)\n", " model.compile_graph(brainstate.random.rand(10))\n", " model.show_graph()" ], @@ -325,7 +327,7 @@ "\n", "## Multi-Layer Convolutional Neural Network\n", "\n", - "A demonstration of a multi-layer convolutional architecture built using `brainscale` components. This example showcases how convolutional operations can be integrated with spiking neuron models in a hierarchical structure." + "A demonstration of a multi-layer convolutional architecture built using `braintrace` components. This example showcases how convolutional operations can be integrated with spiking neuron models in a hierarchical structure." 
], "id": "40aa51a47fb30b29" }, @@ -374,26 +376,26 @@ " )\n", "\n", " self.layer1 = brainstate.nn.Sequential(\n", - " brainscale.nn.Conv2d(in_size, n_channel, kernel_size=3, padding=1, **conv_inits),\n", - " brainscale.nn.LayerNorm.desc(),\n", - " brainscale.nn.IF.desc(**if_param),\n", + " braintrace.nn.Conv2d(in_size, n_channel, kernel_size=3, padding=1, **conv_inits),\n", + " braintrace.nn.LayerNorm.desc(),\n", + " braintrace.nn.IF.desc(**if_param),\n", " brainstate.nn.MaxPool2d.desc(kernel_size=2, stride=2) # 14 * 14\n", " )\n", "\n", " self.layer2 = brainstate.nn.Sequential(\n", - " brainscale.nn.Conv2d(self.layer1.out_size, n_channel, kernel_size=3, padding=1, **conv_inits),\n", - " brainscale.nn.LayerNorm.desc(),\n", - " brainscale.nn.IF.desc(**if_param),\n", + " braintrace.nn.Conv2d(self.layer1.out_size, n_channel, kernel_size=3, padding=1, **conv_inits),\n", + " braintrace.nn.LayerNorm.desc(),\n", + " braintrace.nn.IF.desc(**if_param),\n", " )\n", " self.layer3 = brainstate.nn.Sequential(\n", " brainstate.nn.MaxPool2d(kernel_size=2, stride=2, in_size=self.layer2.out_size), # 7 * 7\n", " brainstate.nn.Flatten.desc()\n", " )\n", " self.layer4 = brainstate.nn.Sequential(\n", - " brainscale.nn.Linear(self.layer3.out_size, n_channel * 4 * 4, **linear_inits),\n", - " brainscale.nn.IF.desc(**if_param),\n", + " braintrace.nn.Linear(self.layer3.out_size, n_channel * 4 * 4, **linear_inits),\n", + " braintrace.nn.IF.desc(**if_param),\n", " )\n", - " self.layer5 = brainscale.nn.LeakyRateReadout(self.layer4.out_size, out_sze, tau=tau_o)\n", + " self.layer5 = braintrace.nn.LeakyRateReadout(self.layer4.out_size, out_sze, tau=tau_o)\n", "\n", " def update(self, x):\n", " # x.shape = [B, H, W, C]\n", @@ -415,7 +417,7 @@ "with brainstate.environ.context(dt=0.1):\n", " net2 = ConvSNN((34, 34, 2), 10)\n", " brainstate.nn.init_all_states(net2)\n", - " model = brainscale.D_RTRL(net2)\n", + " model = braintrace.D_RTRL(net2)\n", " model.compile_graph(brainstate.random.random((34, 34, 2)))\n", " model.show_graph()" ], diff --git a/docs/tutorial/show_graph-zh.ipynb b/docs/tutorial/show_graph-zh.ipynb index b71b557..e01c5fe 100644 --- a/docs/tutorial/show_graph-zh.ipynb +++ b/docs/tutorial/show_graph-zh.ipynb @@ -9,7 +9,7 @@ { "metadata": {}, "cell_type": "markdown", - "source": "brainscale 利用中间表达(Intermediate Representation, IR)分析的方式来抽取神经元状态、连接、和参数之间的依赖关系。通过``.show_graph()``的函数调用,brainscale支持可视化编译图,帮助更好地理解神经元模型的计算过程和模型关系。", + "source": "braintrace 利用中间表达(Intermediate Representation, IR)分析的方式来抽取神经元状态、连接、和参数之间的依赖关系。通过``.show_graph()``的函数调用,braintrace支持可视化编译图,帮助更好地理解神经元模型的计算过程和模型关系。", "id": "6a218956c3ceda62" }, { @@ -23,8 +23,10 @@ "source": [ "import brainstate\n", "import brainunit as u\n", - "import brainscale\n", - "import jax" + "import braintrace\n", + "import jax\n", + "import brainpy\n", + "import braintools" ], "id": "815c5158c7221ab1", "outputs": [], @@ -67,14 +69,14 @@ " ff_init = braintools.init.KaimingNormal(ff_scale, unit=u.mV)\n", " w_init = u.math.concatenate([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0)\n", " self.syn = brainpy.state.DeltaProj(\n", - " comm=brainscale.nn.Linear(\n", + " comm=braintrace.nn.Linear(\n", " n_in + n_rec, n_rec,\n", " w_init=w_init,\n", " b_init=braintools.init.ZeroInit(unit=u.mV)\n", " ),\n", " post=self.neu\n", " )\n", - " self.out = brainscale.nn.LeakyRateReadout(\n", + " self.out = braintrace.nn.LeakyRateReadout(\n", " in_size=n_rec,\n", " out_size=n_out,\n", " tau=tau_o,\n", @@ -101,7 +103,7 @@ "with brainstate.environ.context(dt=0.1 * u.ms):\n", 
" net = LIF_Delta_Net(n_in=10, n_rec=20, n_out=5)\n", " brainstate.nn.init_all_states(net)\n", - " model = brainscale.D_RTRL(net)\n", + " model = braintrace.D_RTRL(net)\n", " model.compile_graph(brainstate.random.rand(10))\n", " model.show_graph()" ], @@ -231,7 +233,7 @@ " self.n_rec = n_rec\n", "\n", " # 模型层\n", - " self.ir2r = brainscale.nn.Linear(n_in + n_rec, n_rec, w_init=w, b_init=braintools.init.ZeroInit(unit=u.mA))\n", + " self.ir2r = braintrace.nn.Linear(n_in + n_rec, n_rec, w_init=w, b_init=braintools.init.ZeroInit(unit=u.mA))\n", " self.exp = brainpy.state.Expon(n_rec, tau=tau_syn, g_initializer=braintools.init.ZeroInit(unit=u.mA))\n", " self.r = GIF(\n", " n_rec,\n", @@ -262,7 +264,7 @@ " assert n > 0, \"n_rec should be a list of positive integers.\"\n", " self.layers.append(GifLayer(n_in, n))\n", " n_in = n\n", - " self.out = brainscale.nn.LeakyRateReadout(n_in, n_out, tau=tau_o, w_init=braintools.init.KaimingNormal())\n", + " self.out = braintrace.nn.LeakyRateReadout(n_in, n_out, tau=tau_o, w_init=braintools.init.KaimingNormal())\n", "\n", " def update(self, x):\n", " for layer in self.layers:\n", @@ -285,7 +287,7 @@ "with brainstate.environ.context(dt=0.1 * u.ms):\n", " net2 = GifNet(n_in=10, n_rec=[20, 20, 20], n_out=5)\n", " brainstate.nn.init_all_states(net2)\n", - " model = brainscale.D_RTRL(net2)\n", + " model = braintrace.D_RTRL(net2)\n", " model.compile_graph(brainstate.random.rand(10))\n", " model.show_graph()" ], @@ -369,26 +371,26 @@ " )\n", "\n", " self.layer1 = brainstate.nn.Sequential(\n", - " brainscale.nn.Conv2d(in_size, n_channel, kernel_size=3, padding=1, **conv_inits),\n", - " brainscale.nn.LayerNorm.desc(),\n", - " brainscale.nn.IF.desc(**if_param),\n", + " braintrace.nn.Conv2d(in_size, n_channel, kernel_size=3, padding=1, **conv_inits),\n", + " braintrace.nn.LayerNorm.desc(),\n", + " braintrace.nn.IF.desc(**if_param),\n", " brainstate.nn.MaxPool2d.desc(kernel_size=2, stride=2) # 14 * 14\n", " )\n", "\n", " self.layer2 = brainstate.nn.Sequential(\n", - " brainscale.nn.Conv2d(self.layer1.out_size, n_channel, kernel_size=3, padding=1, **conv_inits),\n", - " brainscale.nn.LayerNorm.desc(),\n", - " brainscale.nn.IF.desc(**if_param),\n", + " braintrace.nn.Conv2d(self.layer1.out_size, n_channel, kernel_size=3, padding=1, **conv_inits),\n", + " braintrace.nn.LayerNorm.desc(),\n", + " braintrace.nn.IF.desc(**if_param),\n", " )\n", " self.layer3 = brainstate.nn.Sequential(\n", " brainstate.nn.MaxPool2d(kernel_size=2, stride=2, in_size=self.layer2.out_size), # 7 * 7\n", " brainstate.nn.Flatten.desc()\n", " )\n", " self.layer4 = brainstate.nn.Sequential(\n", - " brainscale.nn.Linear(self.layer3.out_size, n_channel * 4 * 4, **linear_inits),\n", - " brainscale.nn.IF.desc(**if_param),\n", + " braintrace.nn.Linear(self.layer3.out_size, n_channel * 4 * 4, **linear_inits),\n", + " braintrace.nn.IF.desc(**if_param),\n", " )\n", - " self.layer5 = brainscale.nn.LeakyRateReadout(self.layer4.out_size, out_sze, tau=tau_o)\n", + " self.layer5 = braintrace.nn.LeakyRateReadout(self.layer4.out_size, out_sze, tau=tau_o)\n", "\n", " def update(self, x):\n", " # x.shape = [B, H, W, C]\n", @@ -410,7 +412,7 @@ "with brainstate.environ.context(dt=0.1):\n", " net2 = ConvSNN((34, 34, 2), 10)\n", " brainstate.nn.init_all_states(net2)\n", - " model = brainscale.D_RTRL(net2)\n", + " model = braintrace.D_RTRL(net2)\n", " model.compile_graph(brainstate.random.random((34, 34, 2)))\n", " model.show_graph()" ], diff --git a/examples/000-lif-snn-for-nmnist.py b/examples/000-lif-snn-for-nmnist.py 
index 7341413..58049a2 100644
--- a/examples/000-lif-snn-for-nmnist.py
+++ b/examples/000-lif-snn-for-nmnist.py
@@ -14,7 +14,7 @@
# ==============================================================================
-# see brainscale documentations for more details.
+# see braintrace documentation for more details.
import brainstate
import braintools
diff --git a/examples/001-gif-snn-for-dms.py b/examples/001-gif-snn-for-dms.py
index 3d3b235..8348e9c 100644
--- a/examples/001-gif-snn-for-dms.py
+++ b/examples/001-gif-snn-for-dms.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
-# see brainscale documentations for more details.
+# see braintrace documentation for more details.
import brainstate
import braintools
diff --git a/examples/002-coba-ei-rsnn.py b/examples/002-coba-ei-rsnn.py
index 7646173..0bb5a8b 100644
--- a/examples/002-coba-ei-rsnn.py
+++ b/examples/002-coba-ei-rsnn.py
@@ -25,7 +25,7 @@
import matplotlib.pyplot as plt
import numpy as np
-import brainscale
+import braintrace
class EvidenceAccumulation:
@@ -235,7 +235,7 @@ def __init__(
self.pop = GIF(n_rec, tau=tau_neu, tau_I2=tau_a, A2=beta)
# feedforward
self.ff2r = brainpy.state.AlignPostProj(
- comm=brainscale.nn.SignedWLinear(
+ comm=braintrace.nn.SignedWLinear(
n_in, n_rec,
w_init=braintools.init.KaimingNormal(ff_scale, unit=u.siemens)
@@ -250,7 +250,7 @@ def __init__(
)
# recurrent
inh_init = braintools.init.KaimingNormal(scale=rec_scale * w_ei_ratio, unit=u.siemens)
- inh2r_conn = brainscale.nn.SignedWLinear(
+ inh2r_conn = braintrace.nn.SignedWLinear(
self.n_inh, n_rec,
w_init=inh_init,
@@ -267,7 +267,7 @@ def __init__(
post=self.pop
)
exc_init = braintools.init.KaimingNormal(scale=rec_scale, unit=u.siemens)
- exc2r_conn = brainscale.nn.SignedWLinear(self.n_exc, n_rec, w_init=exc_init)
+ exc2r_conn = braintrace.nn.SignedWLinear(self.n_exc, n_rec, w_init=exc_init)
self.exc2r = brainpy.state.AlignPostProj(
comm=exc2r_conn,
syn=brainpy.state.Expon.desc(
@@ -279,7 +279,7 @@ def __init__(
post=self.pop
)
# output
- self.out = brainscale.nn.LeakyRateReadout(n_rec, n_out, tau=tau_out)
+ self.out = braintrace.nn.LeakyRateReadout(n_rec, n_out, tau=tau_out)
def update(self, spk):
e_sps, i_sps = jnp.split(self.pop.get_spike(), [self.n_exc], axis=-1)
@@ -430,11 +430,11 @@ def etrace_train(self, inputs, targets):
# initialize the online learning model
if self.method == 'expsm_diag':
- model = brainscale.IODimVjpAlgorithm(self.target, decay_or_rank=0.99)
+ model = braintrace.IODimVjpAlgorithm(self.target, decay_or_rank=0.99)
elif self.method == 'diag':
- model = brainscale.ParamDimVjpAlgorithm(self.target)
+ model = braintrace.ParamDimVjpAlgorithm(self.target)
elif self.method == 'hybrid':
- model = brainscale.HybridDimVjpAlgorithm(self.target, decay_or_rank=0.99)
+ model = braintrace.HybridDimVjpAlgorithm(self.target, decay_or_rank=0.99)
else:
raise ValueError(f'Unknown online learning methods: {self.method}.')
diff --git a/examples/003-snn-memory-and-speed-evaluation-all.py b/examples/003-snn-memory-and-speed-evaluation-all.py
index f663895..06ec25d 100644
--- a/examples/003-snn-memory-and-speed-evaluation-all.py
+++ b/examples/003-snn-memory-and-speed-evaluation-all.py
@@ -27,7 +27,7 @@
import jax.numpy as jnp
import numpy as np
-import brainscale
+import braintrace
default_setting = brainstate.util.DotDict(
method='bptt',
@@ -116,7 +116,7 @@ def __init__(
ff_init: Callable = braintools.init.KaimingNormal(ff_scale)
w_init = 
jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) self.syn = brainpy.state.DeltaProj( - comm=brainscale.nn.Linear(n_in + n_rec, n_rec, w_init=w_init), + comm=braintrace.nn.Linear(n_in + n_rec, n_rec, w_init=w_init), post=self.neu ) @@ -160,7 +160,7 @@ def __init__( ff_init: Callable = braintools.init.KaimingNormal(ff_scale) w_init = jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) self.syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear(n_in + n_rec, n_rec, w_init), + comm=braintrace.nn.Linear(n_in + n_rec, n_rec, w_init), syn=brainpy.state.Expon(n_rec, tau=tau_syn, g_initializer=braintools.init.ZeroInit()), out=brainpy.state.CUBA(scale=1.), post=self.neu @@ -238,7 +238,7 @@ def __init__( self.rec_layers.append(rec) # output layer - self.out = brainscale.nn.LeakyRateReadout( + self.out = braintrace.nn.LeakyRateReadout( in_size=n_rec, out_size=n_out, tau=args.tau_o, @@ -353,11 +353,11 @@ def _step(i, inp): def _compile_etrace_function(self, input_info): if self.args.method == 'expsm_diag': - model = brainscale.ES_D_RTRL(self.target, self.args.etrace_decay, mode=brainstate.mixin.Batching()) + model = braintrace.ES_D_RTRL(self.target, self.args.etrace_decay, mode=brainstate.mixin.Batching()) elif self.args.method == 'diag': - model = brainscale.D_RTRL(self.target, mode=brainstate.mixin.Batching()) + model = braintrace.D_RTRL(self.target, mode=brainstate.mixin.Batching()) elif self.args.method == 'hybrid': - model = brainscale.HybridDimVjpAlgorithm(self.target, self.args.etrace_decay, + model = braintrace.HybridDimVjpAlgorithm(self.target, self.args.etrace_decay, mode=brainstate.mixin.Batching()) else: raise ValueError(f'Unknown online learning methods: {self.args.method}.') diff --git a/examples/003-snn-memory-and-speed-evaluation-batched.py b/examples/003-snn-memory-and-speed-evaluation-batched.py index 04727d0..ce88cdd 100644 --- a/examples/003-snn-memory-and-speed-evaluation-batched.py +++ b/examples/003-snn-memory-and-speed-evaluation-batched.py @@ -84,7 +84,7 @@ global_args = parser.parse_args() -import brainscale +import braintrace import brainstate import braintools import brainunit as u @@ -172,7 +172,7 @@ def __init__( ff_init: Callable = braintools.init.KaimingNormal(ff_scale) w_init = jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) self.syn = brainpy.state.DeltaProj( - comm=brainscale.nn.Linear(n_in + n_rec, n_rec, w_init=w_init), + comm=braintrace.nn.Linear(n_in + n_rec, n_rec, w_init=w_init), post=self.neu ) @@ -216,7 +216,7 @@ def __init__( ff_init: Callable = braintools.init.KaimingNormal(ff_scale) w_init = jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) self.syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear(n_in + n_rec, n_rec, w_init), + comm=braintrace.nn.Linear(n_in + n_rec, n_rec, w_init), syn=brainpy.state.Expon(n_rec, tau=tau_syn, g_initializer=braintools.init.ZeroInit()), out=brainpy.state.CUBA(scale=1.), post=self.neu @@ -294,7 +294,7 @@ def __init__( self.rec_layers.append(rec) # output layer - self.out = brainscale.nn.LeakyRateReadout( + self.out = braintrace.nn.LeakyRateReadout( in_size=n_rec, out_size=n_out, tau=args.tau_o, @@ -474,11 +474,11 @@ def _step(i, inp): def _compile_etrace_function(self, input_info): if self.args.method == 'expsm_diag': - model = brainscale.ES_D_RTRL(self.target, self.args.etrace_decay, mode=brainstate.mixin.Batching()) + model = braintrace.ES_D_RTRL(self.target, self.args.etrace_decay, mode=brainstate.mixin.Batching()) elif 
self.args.method == 'diag': - model = brainscale.D_RTRL(self.target, mode=brainstate.mixin.Batching()) + model = braintrace.D_RTRL(self.target, mode=brainstate.mixin.Batching()) elif self.args.method == 'hybrid': - model = brainscale.HybridDimVjpAlgorithm(self.target, self.args.etrace_decay, + model = braintrace.HybridDimVjpAlgorithm(self.target, self.args.etrace_decay, mode=brainstate.mixin.Batching()) else: raise ValueError(f'Unknown online learning methods: {self.args.method}.') diff --git a/examples/003-snn-memory-and-speed-evaluation-vmap.py b/examples/003-snn-memory-and-speed-evaluation-vmap.py index cf1f427..ebe6a78 100644 --- a/examples/003-snn-memory-and-speed-evaluation-vmap.py +++ b/examples/003-snn-memory-and-speed-evaluation-vmap.py @@ -70,7 +70,7 @@ global_args = parser.parse_args() import brainpy -import brainscale +import braintrace import brainstate import braintools import brainunit as u @@ -154,7 +154,7 @@ def __init__( ff_init: Callable = braintools.init.KaimingNormal(ff_scale) w_init = jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) self.syn = brainpy.state.DeltaProj( - comm=brainscale.nn.Linear(n_in + n_rec, n_rec, w_init=w_init), + comm=braintrace.nn.Linear(n_in + n_rec, n_rec, w_init=w_init), post=self.neu ) @@ -198,7 +198,7 @@ def __init__( ff_init: Callable = braintools.init.KaimingNormal(ff_scale) w_init = jnp.concat([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) self.syn = brainpy.state.AlignPostProj( - comm=brainscale.nn.Linear(n_in + n_rec, n_rec, w_init), + comm=braintrace.nn.Linear(n_in + n_rec, n_rec, w_init), syn=brainpy.state.Expon(n_rec, tau=tau_syn, g_initializer=braintools.init.ZeroInit()), out=brainpy.state.CUBA(scale=1.), post=self.neu @@ -277,7 +277,7 @@ def __init__( self.rec_layers.append(rec) # output layer - self.out = brainscale.nn.LeakyRateReadout( + self.out = braintrace.nn.LeakyRateReadout( in_size=n_rec, out_size=n_out, tau=args.tau_o, @@ -457,11 +457,11 @@ def _step(i, inp): def _compile_etrace_function(self, input_info): if self.args.method == 'expsm_diag': - model = brainscale.ES_D_RTRL(self.target, self.args.etrace_decay, ) + model = braintrace.ES_D_RTRL(self.target, self.args.etrace_decay, ) elif self.args.method == 'diag': - model = brainscale.D_RTRL(self.target, ) + model = braintrace.D_RTRL(self.target, ) elif self.args.method == 'hybrid': - model = brainscale.HybridDimVjpAlgorithm(self.target, self.args.etrace_decay, ) + model = braintrace.HybridDimVjpAlgorithm(self.target, self.args.etrace_decay, ) else: raise ValueError(f'Unknown online learning methods: {self.args.method}.') diff --git a/examples/004-feedforward-conv-snn.py b/examples/004-feedforward-conv-snn.py index fc031ab..381985c 100644 --- a/examples/004-feedforward-conv-snn.py +++ b/examples/004-feedforward-conv-snn.py @@ -31,7 +31,7 @@ from torch.utils.data import DataLoader from tqdm import tqdm -import brainscale +import braintrace class ConvSNN(brainstate.nn.Module): @@ -70,15 +70,15 @@ def __init__( ) self.layer1 = brainstate.nn.Sequential( - brainscale.nn.Conv2d(in_size, n_channel, kernel_size=3, padding=1, **conv_inits), - brainscale.nn.LayerNorm.desc(), + braintrace.nn.Conv2d(in_size, n_channel, kernel_size=3, padding=1, **conv_inits), + braintrace.nn.LayerNorm.desc(), brainpy.state.IF.desc(**if_param), brainstate.nn.MaxPool2d.desc(kernel_size=2, stride=2) # 14 * 14 ) self.layer2 = brainstate.nn.Sequential( - brainscale.nn.Conv2d(self.layer1.out_size, n_channel, kernel_size=3, padding=1, **conv_inits), - 
brainscale.nn.LayerNorm.desc(),
+ braintrace.nn.Conv2d(self.layer1.out_size, n_channel, kernel_size=3, padding=1, **conv_inits),
+ braintrace.nn.LayerNorm.desc(),
brainpy.state.IF.desc(**if_param),
)
self.layer3 = brainstate.nn.Sequential(
@@ -86,10 +86,10 @@ def __init__(
brainstate.nn.Flatten.desc()
)
self.layer4 = brainstate.nn.Sequential(
- brainscale.nn.Linear(self.layer3.out_size, n_channel * 4 * 4, **linear_inits),
+ braintrace.nn.Linear(self.layer3.out_size, n_channel * 4 * 4, **linear_inits),
brainpy.state.IF.desc(**if_param),
)
- self.layer5 = brainscale.nn.LeakyRateReadout(self.layer4.out_size, out_sze, tau=tau_o)
+ self.layer5 = braintrace.nn.LeakyRateReadout(self.layer4.out_size, out_sze, tau=tau_o)
def update(self, x):
# x.shape = [B, H, W, C]
@@ -222,8 +222,8 @@ def batch_train(self, inputs, targets):
# inputs: [n_step, n_batch, ...]
# targets: [n_batch, n_out]
- # model = brainscale.ES_D_RTRL(self.target, self.decay_or_rank)
- model = brainscale.D_RTRL(self.target)
+ # model = braintrace.ES_D_RTRL(self.target, self.decay_or_rank)
+ model = braintrace.D_RTRL(self.target)
@brainstate.transform.vmap_new_states(
state_tag='new',
@@ -278,8 +278,8 @@ def batch_train(self, inputs, targets):
# inputs: [n_step, n_batch, ...]
# targets: [n_batch, n_out]
- # model = brainscale.ES_D_RTRL(self.target, self.decay_or_rank, model=brainstate.mixin.Batching())
- model = brainscale.D_RTRL(self.target, self.decay_or_rank, model=brainstate.mixin.Batching())
+ # model = braintrace.ES_D_RTRL(self.target, self.decay_or_rank, mode=brainstate.mixin.Batching())
+ model = braintrace.D_RTRL(self.target, mode=brainstate.mixin.Batching())
# initialize the online learning model
brainstate.nn.init_all_states(self.target, batch_size=inputs.shape[1])
diff --git a/examples/100-gru-on-copying-task.py b/examples/100-gru-on-copying-task.py
index 43ca3ff..410d7cc 100644
--- a/examples/100-gru-on-copying-task.py
+++ b/examples/100-gru-on-copying-task.py
@@ -13,7 +13,7 @@
# limitations under the License. 
# ==============================================================================
-# See brainscale documentation for more details:
+# See braintrace documentation for more details:
import brainstate
import braintools
@@ -22,7 +22,7 @@
import numpy as np
from tqdm import tqdm
-import brainscale
+import braintrace
class CopyDataset:
@@ -52,11 +52,11 @@ def __init__(self, n_in, n_rec, n_out, n_layer):
# 构建GRU多层网络
layers = []
for _ in range(n_layer):
- layers.append(brainscale.nn.GRUCell(n_in, n_rec))
+ layers.append(braintrace.nn.GRUCell(n_in, n_rec))
n_in = n_rec
self.layer = brainstate.nn.Sequential(*layers)
# 构建输出层
- self.readout = brainscale.nn.Linear(n_rec, n_out)
+ self.readout = braintrace.nn.Linear(n_rec, n_out)
def update(self, x):
return self.readout(self.layer(x))
@@ -119,7 +119,7 @@ def batch_train(self, inputs, target):
if self.batch_train_method == 'vmap':
# 初始化在线学习模型
# 此处,我们需要使用 mode 来指定使用数据集是具有 batch 维度的
- model = brainscale.ParamDimVjpAlgorithm(self.target, vjp_method=self.vjp_method)
+ model = braintrace.ParamDimVjpAlgorithm(self.target, vjp_method=self.vjp_method)
@brainstate.transform.vmap_new_states(state_tag='new', axis_size=inputs.shape[1])
def init():
@@ -132,7 +132,7 @@ def init():
model = brainstate.nn.Vmap(model, vmap_states='new')
elif self.batch_train_method == 'batch':
- model = brainscale.ParamDimVjpAlgorithm(
+ model = braintrace.ParamDimVjpAlgorithm(
self.target, vjp_method=self.vjp_method, mode=brainstate.mixin.Batching())
brainstate.nn.init_all_states(self.target, batch_size=inputs.shape[1])
model.compile_graph(inputs[0])
diff --git a/examples/README.md b/examples/README.md
index 3663871..c6801c5 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -1,9 +1,9 @@
-# Examples of ``brainscale``
+# Examples of ``braintrace``
-We provide the following kinds of examples to demonstrate the use of the ``brainscale`` package.
+We provide the following kinds of examples to demonstrate the use of the ``braintrace`` package.
-- The application fo the ``brainscale`` package to spiking neural networks, with filenames starting with ``0xx_``.
-- The application of the ``brainscale`` package to rate-based recurrent neural networks, with filenames starting with ``1xx_``.
+- The application of the ``braintrace`` package to spiking neural networks, with filenames starting with ``0xx_``.
+- The application of the ``braintrace`` package to rate-based recurrent neural networks, with filenames starting with ``1xx_``.
All the python files can be directly run in the terminal. For example, to run the file ``001-gif-snn-for-dms.py``, you can use the following command:
@@ -20,7 +20,7 @@ The examples require the following packages:
- numpy
- matplotlib
-brainscale
+braintrace
- brainunit
- brainstate
- tqdm
diff --git a/examples/snn_models.py b/examples/snn_models.py
index a969565..8ac5cd0 100644
--- a/examples/snn_models.py
+++ b/examples/snn_models.py
@@ -27,7 +27,7 @@
import numpy as np
from tqdm import tqdm
-import brainscale
+import braintrace
LOSS = float
ACCURACY = float
@@ -117,7 +117,7 @@ def __init__(
self.n_out = n_out
# 模型层
- self.ir2r = brainscale.nn.Linear(n_in + n_rec, n_rec, w_init=w, b_init=braintools.init.ZeroInit(unit=u.mA))
+ self.ir2r = braintrace.nn.Linear(n_in + n_rec, n_rec, w_init=w, b_init=braintools.init.ZeroInit(unit=u.mA))
self.exp = brainpy.state.Expon(n_rec, tau=tau_syn, g_initializer=braintools.init.ZeroInit(unit=u.mA))
self.r = GIF(
n_rec,
@@ -127,7 +127,7 @@ def __init__(
tau=tau_neu,
tau_I2=brainstate.random.uniform(100. 
* u.ms, tau_I2 * 1.5, n_rec), ) - self.out = brainscale.nn.LeakyRateReadout(n_rec, n_out, tau=tau_o, w_init=braintools.init.KaimingNormal()) + self.out = braintrace.nn.LeakyRateReadout(n_rec, n_out, tau=tau_o, w_init=braintools.init.KaimingNormal()) def update(self, spikes): cond = self.ir2r(u.math.concatenate([spikes, self.r.get_spike()], axis=-1)) @@ -404,7 +404,7 @@ def batch_train(self, inputs, targets): weights = self.target.states().subset(brainstate.ParamState) # initialize the online learning model - model = brainscale.IODimVjpAlgorithm(self.target, self.decay_or_rank) + model = braintrace.IODimVjpAlgorithm(self.target, self.decay_or_rank) # initialize the states @brainstate.transform.vmap_new_states(state_tag='new', axis_size=inputs.shape[1]) @@ -504,12 +504,12 @@ def __init__( ff_init: Callable = braintools.init.KaimingNormal(ff_scale, unit=u.mV) w_init = u.math.concatenate([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0) self.syn = brainpy.state.DeltaProj( - comm=brainscale.nn.Linear(n_in + n_rec, n_rec, + comm=braintrace.nn.Linear(n_in + n_rec, n_rec, w_init=w_init, b_init=braintools.init.ZeroInit(unit=u.mV)), post=self.neu ) - self.out = brainscale.nn.LeakyRateReadout(in_size=n_rec, + self.out = braintrace.nn.LeakyRateReadout(in_size=n_rec, out_size=n_out, tau=tau_o, w_init=braintools.init.KaimingNormal()) diff --git a/pyproject.toml b/pyproject.toml index b2c8f35..d523f67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,14 +13,14 @@ exclude = [ "build*", "dist*", "dev*", - "brainscale.egg-info*", - "brainscale/__pycache__*", - "brainscale/__init__.py", + "braintrace.egg-info*", + "braintrace/__pycache__*", + "braintrace/__init__.py", ] [tool.setuptools.dynamic] -version = { attr = "brainscale.__version__" } +version = { attr = "braintrace.__version__" } [tool.distutils.bdist_wheel] @@ -28,12 +28,12 @@ universal = true [project] -name = "brainscale" +name = "braintrace" description = "Enabling Scalable Online Learning for Brain Dynamics." readme = "README.md" license = { text = "Apache-2.0 license" } requires-python = ">=3.10" -authors = [{ name = "BrainScale Developers", email = "chao.brain@qq.com" }] +authors = [{ name = "BrainTrace Developers", email = "chao.brain@qq.com" }] keywords = [ "computational neuroscience", "brain-inspired computing", @@ -67,10 +67,10 @@ dynamic = ["version"] [project.urls] -"Homepage" = "https://github.com/chaobrain/brainscale" -"Bug Tracker" = "https://github.com/chaobrain/brainscale/issues" -"Documentation" = "https://brainscale.readthedocs.io/" -"Source Code" = "https://github.com/chaobrain/brainscale" +"Homepage" = "https://github.com/chaobrain/braintrace" +"Bug Tracker" = "https://github.com/chaobrain/braintrace/issues" +"Documentation" = "https://braintrace.readthedocs.io/" +"Source Code" = "https://github.com/chaobrain/braintrace" [project.optional-dependencies]
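After a whole-repository rename like the one above, it is worth checking that no stale `brainscale` references survive outside the hunks shown here. The following is a small, hypothetical helper script, not part of this diff or of the package itself; it assumes it is run from the repository root and only scans the text-like file types that appear in this patch.

```python
# Scan for leftover 'brainscale' references after the rename.
# Hypothetical helper, not part of the braintrace package; run from the repo root.
import pathlib

SUFFIXES = {'.py', '.md', '.ipynb', '.toml', '.yml', '.rst'}

stale = [
    str(path)
    for path in sorted(pathlib.Path('.').rglob('*'))
    if path.is_file()
    and path.suffix in SUFFIXES
    and '.git' not in path.parts
    and 'brainscale' in path.read_text(encoding='utf-8', errors='ignore')
]

print('\n'.join(stale) if stale else 'no stale brainscale references found')
```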
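A quick import smoke test can also confirm that the renamed package exposes the symbols this diff relies on. This is only a sketch: it assumes the `braintrace` 0.1.0 API is identical to the renamed `brainscale` API used in the notebooks above (`ETraceParam`, `ElemWiseOp`), and it is not an official test of the package.

```python
# Post-rename smoke test (a sketch; assumes the braintrace API matches
# the renamed brainscale symbols used throughout this diff).
import jax.numpy as jnp
import brainstate
import braintrace

# An element-wise eligibility-trace parameter, mirroring the tutorial cells above.
param = braintrace.ETraceParam(
    brainstate.random.rand(4),
    braintrace.ElemWiseOp(jnp.abs),  # y = |w|
)

# The class should now live under the new package name.
assert type(param).__module__.startswith('braintrace')
print(param)
```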