Description
- Minimal code to reproduce the error:
```python
import torch
import trak
from trak.projectors import ProjectionType, CudaProjector

# Check that trak and its fast_jl CUDA extension are installed
print("trak.test_install:", trak.test_install(use_fast_jl=True))

grad_dim = int(1e6)  # 1M-dimensional gradients
projector = CudaProjector(
    grad_dim=grad_dim,
    proj_dim=32768,
    seed=42,
    proj_type=ProjectionType.normal,
    device='cuda:0',
    max_batch_size=8,
)

# Project a batch of 8 random "gradients" down to 32768 dimensions
grad = torch.randn(8, grad_dim, device='cuda:0')
proj = projector.project(grad, model_id=0)
print(proj)
```
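For a sense of scale, the sketch below tallies the memory footprint of the tensors in the reproduction. The helper `dense_matrix_bytes` is hypothetical, and float32 storage is assumed because `torch.randn` defaults to it. A dense `grad_dim × proj_dim` projection matrix would not fit on most single GPUs, which is why this setup relies on the fast_jl kernel. As we understand it, `CudaProjector` generates the random projection on the fly from `seed` rather than storing it.

```python
def dense_matrix_bytes(rows: int, cols: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to store a dense rows x cols matrix (float32 by default)."""
    return rows * cols * bytes_per_elem


grad_dim, proj_dim, batch = int(1e6), 32768, 8

# Input gradients: shape (8, 1e6), float32
grad_bytes = dense_matrix_bytes(batch, grad_dim)
# Output projections: shape (8, 32768), float32
proj_bytes = dense_matrix_bytes(batch, proj_dim)
# A hypothetical materialized (1e6, 32768) Gaussian projection matrix, float32
matrix_bytes = dense_matrix_bytes(grad_dim, proj_dim)

for name, n in [("grad", grad_bytes), ("proj", proj_bytes), ("dense P", matrix_bytes)]:
    print(f"{name}: {n / 1e9:.3f} GB")
```

The input batch is about 32 MB and the output about 1 MB. The dense matrix alone would take roughly 131 GB, so a dense matrix multiply is not a practical reference at this size.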
- Environment installation commands:
```shell
pip install scikit-learn matplotlib einops ipykernel
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
conda install cuda=12.1 -c nvidia
conda install cuda-nvcc=12.1 -c nvidia -y
conda install cuda-toolkit=12.1 -c nvidia -y
export CUDA_HOME=$CONDA_PREFIX
export PYTHONPATH=$CONDA_PREFIX/lib/python3.x/site-packages:$PYTHONPATH
pip install traker[fast]
```
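When reporting a build problem with the `fast` extra, it can help to include the toolchain versions that the commands above actually produced. Below is a minimal sketch that assumes only the standard library plus an optional PyTorch install. `collect_env_info` is a hypothetical helper, not part of trak.

```python
import importlib.util
import os
import shutil
import subprocess


def collect_env_info() -> dict:
    """Gather environment details relevant to building fast_jl (hypothetical helper)."""
    info = {
        "CUDA_HOME": os.environ.get("CUDA_HOME"),
        "nvcc_path": shutil.which("nvcc"),
        "nvcc_version": None,
        "torch_version": None,
        "torch_cuda_version": None,
        "cuda_available": None,
    }
    # nvcc --version prints a line like "Cuda compilation tools, release 12.1, ..."
    if info["nvcc_path"]:
        out = subprocess.run([info["nvcc_path"], "--version"],
                             capture_output=True, text=True)
        lines = out.stdout.splitlines()
        info["nvcc_version"] = next((l.strip() for l in lines if "release" in l), None)
    # Only import torch if it is installed, so the check works in a bare env too
    if importlib.util.find_spec("torch") is not None:
        import torch
        info["torch_version"] = torch.__version__
        info["torch_cuda_version"] = torch.version.cuda  # CUDA version torch was built with
        info["cuda_available"] = torch.cuda.is_available()
    return info


if __name__ == "__main__":
    for key, value in collect_env_info().items():
        print(f"{key}: {value}")
```

A mismatch between `torch_cuda_version` and the `nvcc` release is one thing worth checking here. PyTorch's own `python -m torch.utils.collect_env` reports a broader superset of this information.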