Error in calculate_contribution_scores_regions with pytorch #45

Open
lldelisle opened this issue Oct 16, 2024 · 8 comments
Labels
bug Something isn't working


@lldelisle
Contributor


Hi,
I got the following error when trying to use calculate_contribution_scores_regions with pytorch:

2024-10-16T15:33:14.181525+0200 INFO Calculating contribution scores for 4 class(es) and 2 region(s).
Region:   0%|          | 0/2 [00:00<?, ?it/s]
/scratch/izar/ldelisle/CREsted_test1/venv9/lib/python3.10/site-packages/crested/tl/_explainer_torch.py:255: FutureWarning: The input object of type 'Tensor' is an array-like implementing one of the corresponding protocols (`__array__`, `__array_interface__` or `__array_struct__`); but not a sequence (or 0-D). In the future, this object will be coerced as if it was first converted using `np.array(obj)`. To retain the old behaviour, you have to either modify the type 'Tensor', or assign to an empty array created with `np.empty(correct_shape, dtype=object)`.
  return np.array(x_shuffle)
/scratch/izar/ldelisle/CREsted_test1/venv9/lib/python3.10/site-packages/crested/tl/_explainer_torch.py:255: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  return np.array(x_shuffle)
Region:   0%|          | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/ldelisle/scripts/scitas_sbatchhistory/2024/20241002_tryCREest/20241016_plots_with_model.py", line 121, in <module>
    scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_regions(
  File "/scratch/izar/ldelisle/CREsted_test1/venv9/lib/python3.10/site-packages/crested/tl/_crested.py", line 994, in calculate_contribution_scores_regions
    return self.calculate_contribution_scores_sequence(
  File "/scratch/izar/ldelisle/CREsted_test1/venv9/lib/python3.10/site-packages/crested/tl/_crested.py", line 1085, in calculate_contribution_scores_sequence
    scores[:, i, :, :] = explainer.expected_integrated_grad(
  File "/scratch/izar/ldelisle/CREsted_test1/venv9/lib/python3.10/site-packages/crested/tl/_explainer_torch.py", line 71, in expected_integrated_grad
    baselines = torch.tensor(
TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint64, uint32, uint16, uint8, and bool.
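The frames above show the chain of failure: the warnings at `_explainer_torch.py:255` indicate that `np.array(x_shuffle)` could not unpack the list of tensors and fell back to an object-dtype array, which `torch.tensor` (called at line 71 of `expected_integrated_grad`) then rejected. A hypothetical guard along these lines, not part of CREsted, would surface the problem where the array is built, with a clearer message. It needs only NumPy; the name `ensure_numeric` is illustrative.

```python
import numpy as np


def ensure_numeric(arr: np.ndarray) -> np.ndarray:
    """Hypothetical guard: reject object-dtype arrays before they reach torch.tensor.

    torch.tensor only accepts numeric and bool NumPy dtypes, so an object array
    (e.g. one built from a ragged list, or from elements NumPy could not unpack)
    fails later with a TypeError that does not say where the array came from.
    """
    if arr.dtype == object:
        # Collect the shapes of the individual elements to hint at raggedness.
        shapes = {np.shape(item) for item in arr.ravel()}
        raise TypeError(
            f"expected a numeric array, got dtype=object "
            f"(element shapes: {sorted(shapes)}); "
            "convert and stack the elements explicitly instead of calling np.array on a list"
        )
    return arr
```

A numeric array passes through unchanged, while something like `np.array([np.zeros(2), np.zeros(3)], dtype=object)` raises the `TypeError` immediately.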

If you need more info, let me know and I can try to make a minimal example.

Version information

I don't have session_info, but here is the output of pip list:

Package                           Version                                                                                                                   
--------------------------------- ------------                                                                                                              
absl-py                           1.2.0                                                                                                                     
anndata                           0.10.5.post1                                                                                                              
appdirs                           1.4.4                                                                                                                     
array_api_compat                  1.9                                                                                                                       
astunparse                        1.6.3                                                                                                                     
backports.entry-points-selectable 1.1.1                                                                                                                     
certifi                           2021.10.8                                                                                                                 
charset-normalizer                2.0.12                                                                                                                    
click                             8.1.7                                                                                                                     
crested                           1.1.0                                                                                                                     
cycler                            0.11.0                                                                                                                    
Cython                            0.29.30                                                                                                                   
distlib                           0.3.4  
docker-pycreds                    0.4.0                                                                                                                     
exceptiongroup                    1.2.2                                                                                                                     
filelock                          3.5.0                                                                                                                     
fonttools                         4.54.1                                                                                                                    
fsspec                            2024.9.0                                                                                                                  
gast                              0.5.3                                                                                                                     
gitdb                             4.0.11                                                                                                                    
GitPython                         3.1.43                                                                                                                    
google-pasta                      0.2.0                                                                                                                     
h5py                              3.12.1                                                                                                                    
idna                              3.3                                                                                                                       
Jinja2                            3.1.4                                                                                                                     
joblib                            1.4.2                                                                                                                     
keras                             3.6.0                                                                                                                     
Keras-Preprocessing               1.1.2                                                                                                                     
kiwisolver                        1.3.2                                                                                                                     
logomaker                         0.8                                                                                                                       
loguru                            0.7.2                                                                                                                     
markdown-it-py                    3.0.0                                                                                                                     
MarkupSafe                        3.0.1                                                                                                                     
matplotlib                        3.5.2                                                                                                                     
mdurl                             0.1.2                                                                                                                     
ml_dtypes                         0.5.0                                                                                                                     
mpmath                            1.2.1                                                                                                                     
namex                             0.0.8                                                                                                                     
natsort                           8.4.0                                                                                                                     
networkx                          3.3                                                                                                                       
numpy                             1.22.4  
nvidia-cuda-cupti-cu12            12.1.105                                                                                                        
nvidia-cuda-nvrtc-cu12            12.1.105                                                                                                                  
nvidia-cuda-runtime-cu12          12.1.105                                                                                                                  
nvidia-cudnn-cu12                 9.1.0.70                                                                                                                  
nvidia-cufft-cu12                 11.0.2.54                                                                                                                 
nvidia-curand-cu12                10.3.2.106                                                                                                                
nvidia-cusolver-cu12              11.4.5.107                                                                                                                
nvidia-cusparse-cu12              12.1.0.106                                                                                                                
nvidia-nccl-cu12                  2.20.5                                                                                                                    
nvidia-nvjitlink-cu12             12.6.77                                                                                                                   
nvidia-nvtx-cu12                  12.1.105                                                                                                                  
opt-einsum                        3.3.0                                                                                                                     
optree                            0.13.0                                                                                                                    
packaging                         21.3                                                                                                                      
pandas                            1.4.2                                                                                                                     
Pillow                            9.0.0                                                                                                                     
pip                               24.2                                                                                                                      
platformdirs                      2.4.0                                                                                                                     
ply                               3.11                                                                                                                      
pooch                             1.6.0                                                                                                                     
protobuf                          3.20.0                                                                                                                    
psutil                            6.0.0                                                                                                                     
pybigtools                        0.2.1                                                                                                                     
pyBigWig                          0.3.23                                                                                                                    
Pygments                          2.18.0  
pyparsing                         3.0.6
pysam                             0.22.1
python-dateutil                   2.8.2
pytz                              2021.3
PyYAML                            6.0.2
requests                          2.26.0
rich                              13.9.2
scikit-learn                      1.5.2
scipy                             1.8.1
seaborn                           0.13.2
semver                            2.8.1
sentry-sdk                        1.9.0
setproctitle                      1.3.3
setuptools                        58.3.0
six                               1.16.0
smmap                             5.0.1
sphire                            1.4.1
sympy                             1.8
termcolor                         1.1.0
threadpoolctl                     3.5.0
torch                             2.4.1
tqdm                              4.66.5
triton                            3.0.0
typing_extensions                 4.12.2
urllib3                           1.26.6
virtualenv                        20.10.0
wandb                             0.18.3
wheel                             0.44.0
wrapt                             1.13.3
xarray                            2022.3.0
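When `session_info` is not available, a shorter alternative to the full `pip list` is to print only the packages relevant to this bug. The sketch below uses the standard-library `importlib.metadata` module (Python 3.8+); the helper name `report_versions` and the package list are assumptions chosen for this issue.

```python
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages):
    """Return {distribution name: installed version, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Packages most likely to matter for this bug (an assumption, adjust as needed).
    for name, ver in report_versions(["crested", "numpy", "torch", "keras", "tensorflow"]).items():
        print(f"{name:12s} {ver or 'not installed'}")
```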

@lldelisle lldelisle added the bug Something isn't working label Oct 16, 2024
@lldelisle
Contributor Author

(I don't think this is the origin of the bug, but just so you know: I trained the model with version 1.0.0 and upgraded to 1.1.0 to make the plots.)

@LukasMahieu
Collaborator

I haven't seen this before; we normally have unit tests for this. I'll try to reproduce it ASAP.

@lldelisle
Contributor Author

I have numpy 1.22.4; I think this is the issue...

@LukasMahieu
Collaborator

Actually, I have seen this before, and I remember that upgrading to numpy 2.x fixed it. How did you create your environment? When I create a standard crested environment with Python 3.10, numpy 2.1.2 gets installed and I don't run into any issues.
For Python 3.10, NumPy also recommends version 1.23+.
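To fail early with a clear message instead of the `TypeError` from the traceback, a script could check the NumPy version before computing contribution scores. This is a minimal sketch; the helper names are hypothetical, and the `(1, 23)` threshold only mirrors the recommendation above rather than a documented CREsted requirement. `packaging.version.Version` would be the more complete choice where `packaging` is installed; this version avoids the dependency.

```python
import re

import numpy as np


def numpy_version_tuple(version_string=None):
    """Parse the leading numeric parts of a version string, e.g. '1.22.4' -> (1, 22, 4)."""
    version_string = version_string or np.__version__
    parts = []
    for piece in version_string.split("."):
        match = re.match(r"\d+", piece)
        if match is None:  # e.g. the 'dev0' in '2.0.0.dev0'
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):  # stop at suffixes such as '0rc1'
            break
    return tuple(parts)


def numpy_at_least(minimum, version_string=None):
    """True if the installed (or given) NumPy version is >= minimum, e.g. minimum=(1, 23)."""
    return numpy_version_tuple(version_string) >= tuple(minimum)
```

For example, `numpy_at_least((1, 23), "1.22.4")` is `False` (the version from this report), while `numpy_at_least((1, 23), "2.1.2")` is `True`.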

@lldelisle
Contributor Author

I am new to GPUs, so I created the virtual environment on my HPC using module load to make sure PyTorch would recognize the GPU, and then tried to install the other dependencies without overwriting the existing installations.
Here are the commands:

module purge
module load gcc/11.3.0 python/3.10.4 openmpi/4.1.3-cuda
# With openmpi comes cuda/11.8.0
# Only the first time:
virtualenv --system-site-packages venv9
# Activate
source venv9/bin/activate
pip install --no-cache-dir torch
pip install --no-cache-dir keras
pip install --no-cache-dir crested urllib3==1.26.6 numpy=='1.22.4' platformdirs=='2.4.0' sentry_sdk==1.9.0

I've just created a new virtualenv with:

# Try pytorch with pip
module purge
module load gcc/11.3.0 python/3.10.4
# Only the first time:
virtualenv venv7
# Activate
source venv7/bin/activate

pip install --no-cache-dir torch
pip install --no-cache-dir crested

This solved the issue (but I ran into another bug; I will open a separate issue for it).

Then I would say that you should set numpy>2 in dependencies no?

@LukasMahieu
Collaborator

Then I would say that you should set numpy>2 in dependencies no?

Yes, I'll do just this. Thanks for bringing this up.

@lldelisle
Contributor Author

You are welcome. I like it when people report bugs in my projects, because it prevents other people from hitting the same bug, being too shy to report it, and giving up on the package... I hope you feel the same way, because it seems I have run into other bugs...

@LukasMahieu
Collaborator

Absolutely, feedback is greatly appreciated!
I just realized the solution is not as simple as that after all, since torch would require >2 but tensorflow only works with <2 🙃.
I'll see if I can fix this in the code itself early next week.
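One way to make the code independent of the NumPy version is to avoid calling `np.array` on a list of tensors and instead convert each element explicitly before stacking. The sketch below is a hypothetical helper (`stack_shuffles` is an illustrative name), not the fix that was eventually applied in CREsted. It duck-types PyTorch tensors through their `detach`, `cpu`, and `numpy` methods so that it runs without PyTorch installed; inside PyTorch-only code, `torch.stack(x_shuffle)` would avoid the NumPy round trip entirely.

```python
import numpy as np


def stack_shuffles(x_shuffle, dtype=np.float32):
    """Stack a list of same-shaped arrays or tensors into one numeric ndarray.

    Each element is converted to a NumPy array on its own first, so NumPy never
    has to guess how to unpack a list of tensor objects (the step that appears
    to have produced an object-dtype array with NumPy 1.22).
    """
    arrays = []
    for x in x_shuffle:
        if hasattr(x, "detach"):  # assumed to be a torch.Tensor: drop autograd, move to CPU
            x = x.detach().cpu().numpy()
        arrays.append(np.asarray(x, dtype=dtype))
    # np.stack raises a ValueError on mismatched shapes instead of
    # silently building an object array.
    return np.stack(arrays)
```

Converting element by element keeps the result's dtype explicit (`float32` here, an assumption), which is what `torch.tensor` expects downstream.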
