Skip to content

Commit

Permalink
Merge pull request #3387 from AlexsLemonade/ark/3361-tracking-data-requests
Browse files Browse the repository at this point in the history

Implement data request tracking
  • Loading branch information
arkid15r authored Sep 26, 2023
2 parents 40e8463 + cac1ead commit bc365b2
Show file tree
Hide file tree
Showing 2 changed files with 337 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import tempfile
from collections import Counter

from django.core.management.base import BaseCommand

import boto3
import botocore

from data_refinery_common.models import DownloaderJob, Experiment, ProcessorJob, SurveyJob
from data_refinery_common.utils import parse_s3_url


class Command(BaseCommand):
    """Reports processing progress for a data request.

    Given an S3-hosted file of experiment accession codes, prints totals and
    per-source breakdowns of experiments/samples attempted and available, plus
    the number of downloader/processor/surveyor jobs created.
    """

    def add_arguments(self, parser):
        parser.add_argument(
            "--file",
            type=str,
            help="A file listing accession codes. s3:// URLs only are accepted.",
        )

    def handle(self, *args, **options):
        """Creates a report based on the processing results."""
        experiments_attempted = 0
        experiments_attempted_by_source = Counter()
        experiments_available = 0
        experiments_available_by_source = Counter()
        jobs_created = 0
        jobs_created_by_source = Counter()
        samples_attempted = 0
        samples_attempted_by_source = Counter()
        samples_available = 0
        samples_available_by_source = Counter()

        accessions = self.get_accessions(options["file"])
        # `accessions` is None on input errors; treat that as "print nothing".
        for accession in accessions or ():
            experiment = Experiment.objects.filter(accession_code=accession).first()
            if not experiment:
                continue

            source = experiment.source_database

            # Experiments attempted, total and breakdown by source.
            experiments_attempted += 1
            experiments_attempted_by_source.update({source: 1})

            # Samples attempted, total and breakdown by source.
            samples_count = experiment.samples.count()
            samples_attempted += samples_count
            samples_attempted_by_source.update({source: samples_count})

            # Experiments available, total and breakdown by source.
            processed_samples_count = experiment.samples.filter(is_processed=True).count()
            if processed_samples_count > 0:
                experiments_available += 1
                experiments_available_by_source.update({source: 1})

            # Samples available, total and breakdown by source.
            samples_available += processed_samples_count
            samples_available_by_source.update({source: processed_samples_count})

            sample_accessions = experiment.samples.values_list("accession_code", flat=True)
            # Total number of jobs created, breakdown by source.
            downloader_jobs_created = self.get_downloader_jobs(sample_accessions).count()
            processor_jobs_created = self.get_processor_jobs(sample_accessions).count()
            survey_jobs_created = self.get_surveyor_jobs(experiment).count()
            total_jobs_created = (
                downloader_jobs_created + processor_jobs_created + survey_jobs_created
            )

            jobs_created += total_jobs_created
            # TODO(review): jobs_created_by_source is accumulated but never
            # reported; either print it below or drop it.
            jobs_created_by_source.update({source: total_jobs_created})

        # TODO(arkid15r): Calculate total run time as a difference between first created job
        # and the last finished job. Also indicate if there is something that still needs to
        # be processed.

        output = []
        if experiments_attempted:
            output += [
                f"Experiments attempted: {experiments_attempted}",
                f"{self.get_distribution_by_source(experiments_attempted_by_source)}",
                "",
            ]

            output.append(f"Samples attempted: {samples_attempted}")
            if samples_attempted:
                output.append(f"{self.get_distribution_by_source(samples_attempted_by_source)}")
            output.append("")

            output.append(f"Experiments available: {experiments_available}")
            if experiments_available:
                output.append(f"{self.get_distribution_by_source(experiments_available_by_source)}")
            output.append("")

            output.append(f"Samples available: {samples_available}")
            if samples_available:
                output.append(f"{self.get_distribution_by_source(samples_available_by_source)}")
            output.append("")

            # BUG FIX: previously printed `total_jobs_created`, the loop-local
            # count for the *last* experiment only; the report should show the
            # accumulated total across all experiments.
            output.append(f"Total jobs: {jobs_created}")
        elif accessions is not None:
            output.append("No experiments found")

        if output:
            print("\n".join(output))

    @staticmethod
    def get_accessions(s3_url):
        """Gets source experiment accessions.

        Downloads the file at `s3_url` and returns its non-blank lines as a
        list of `str` accession codes, or None on an invalid URL / missing file.
        """
        if not s3_url.startswith("s3://"):
            print("Please provide a valid S3 URL")
            return None

        with tempfile.TemporaryFile() as tmp_file:
            bucket, key = parse_s3_url(s3_url)
            try:
                boto3.resource("s3").meta.client.download_fileobj(bucket, key, tmp_file)
            except botocore.exceptions.ClientError as e:
                if e.response["Error"]["Code"] == "404":
                    print("The S3 file does not exist.")
                return None

            # BUG FIX: download_fileobj leaves the file position at EOF, so
            # reading without rewinding yields nothing. Also, TemporaryFile is
            # opened in binary mode, so decode each line to `str` — the
            # accessions are compared against `accession_code` CharFields.
            tmp_file.seek(0)
            return [line.decode().strip() for line in tmp_file if line.strip()]

    @staticmethod
    def get_downloader_jobs(sample_accessions):
        """Returns downloader jobs for sample accessions."""
        return DownloaderJob.objects.filter(
            original_files__samples__accession_code__in=sample_accessions
        ).distinct()

    @staticmethod
    def get_processor_jobs(sample_accessions):
        """Returns processor jobs for sample accessions."""
        return ProcessorJob.objects.filter(
            original_files__samples__accession_code__in=sample_accessions
        ).distinct()

    @staticmethod
    def get_surveyor_jobs(experiment):
        """Returns surveyor jobs for an experiment."""
        return SurveyJob.objects.filter(
            surveyjobkeyvalue__key="experiment_accession_code",
            surveyjobkeyvalue__value=experiment.accession_code,
        )

    @staticmethod
    def get_distribution_by_source(stats):
        """Returns a source based stats in a consistent manner.

        Sources are sorted so the output is deterministic,
        e.g. "ARRAY_EXPRESS: 1, GEO: 2".
        """
        return ", ".join(f"{source}: {stats[source]}" for source in sorted(stats))
187 changes: 187 additions & 0 deletions foreman/tests/foreman/management/commands/test_track_data_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from unittest import mock

from django.core.management import call_command
from django.test import TestCase

from data_refinery_common.models import (
DownloaderJob,
DownloaderJobOriginalFileAssociation,
Experiment,
ExperimentSampleAssociation,
OriginalFile,
OriginalFileSampleAssociation,
ProcessorJob,
ProcessorJobOriginalFileAssociation,
Sample,
SurveyJob,
SurveyJobKeyValue,
)


class TestTrackDataRequest(TestCase):
    """Tests for the `track_data_request` management command output."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # A single GEO experiment with two unprocessed microarray samples.
        experiment = Experiment.objects.create(
            accession_code="GSE12417", technology="MICROARRAY", source_database="GEO"
        )

        cls.sample1 = Sample.objects.create(
            accession_code="GSM311750",
            source_database="GEO",
            technology="MICROARRAY",
            platform_accession_code="hgu133a",
        )
        cls.sample2 = Sample.objects.create(
            accession_code="GSM311751",
            source_database="GEO",
            technology="MICROARRAY",
            platform_accession_code="hgu133plus2",
        )
        for sample in (cls.sample1, cls.sample2):
            ExperimentSampleAssociation.objects.create(experiment=experiment, sample=sample)

    def call_command(self, experiments=(), **kwargs):
        """Invokes the command with `get_accessions` stubbed to return `experiments`."""
        patch_target = (
            "data_refinery_foreman.foreman.management.commands."
            "track_data_request.Command.get_accessions"
        )
        with mock.patch(patch_target, return_value=experiments):
            call_command("track_data_request", (), file="s3://some/file.txt")

    @mock.patch("builtins.print")
    def test_invalid_file(self, mock_print):
        call_command("track_data_request", (), file="s4://some/file.txt")
        mock_print.assert_called_with("Please provide a valid S3 URL")

    @mock.patch("builtins.print")
    def test_no_experiments(self, mock_print):
        self.call_command(experiments=())
        mock_print.assert_called_once_with("No experiments found")

    @mock.patch("builtins.print")
    def test_no_experiments_available(self, mock_print):
        self.call_command(experiments=("GSE12417",))
        expected_lines = [
            "Experiments attempted: 1",
            "GEO: 1",
            "",
            "Samples attempted: 2",
            "GEO: 2",
            "",
            "Experiments available: 0",
            "",
            "Samples available: 0",
            "",
            "Total jobs: 0",
        ]
        mock_print.assert_called_once_with("\n".join(expected_lines))

    @mock.patch("builtins.print")
    def test_available(self, mock_print):
        for sample in (self.sample1, self.sample2):
            sample.is_processed = True
            sample.save()

        self.call_command(experiments=("GSE12417",))
        expected_lines = [
            "Experiments attempted: 1",
            "GEO: 1",
            "",
            "Samples attempted: 2",
            "GEO: 2",
            "",
            "Experiments available: 1",
            "GEO: 1",
            "",
            "Samples available: 2",
            "GEO: 2",
            "",
            "Total jobs: 0",
        ]
        mock_print.assert_called_once_with("\n".join(expected_lines))

    @mock.patch("builtins.print")
    def test_multiple(self, mock_print):
        for sample in (self.sample1, self.sample2):
            sample.is_processed = True
            sample.save()

        # Second experiment from a different source with one processed and
        # one unprocessed sample.
        experiment = Experiment.objects.create(
            accession_code="GSE12418",
            technology="MICROARRAY",
            source_database="ARRAY_EXPRESS",
        )
        sample3 = Sample.objects.create(
            accession_code="GSM311760",
            is_processed=True,
            technology="MICROARRAY",
        )
        sample4 = Sample.objects.create(
            accession_code="GSM311761",
        )
        for sample in (sample3, sample4):
            ExperimentSampleAssociation.objects.create(experiment=experiment, sample=sample)

        # Jobs: one surveyor, one downloader, one processor for GSE12418.
        survey_job = SurveyJob.objects.create(source_type="GEO")
        SurveyJobKeyValue.objects.create(
            survey_job=survey_job,
            key="experiment_accession_code",
            value=experiment.accession_code,
        )

        original_file = OriginalFile.objects.create()
        OriginalFileSampleAssociation.objects.create(
            sample=sample3, original_file=original_file
        )

        downloader_job = DownloaderJob.objects.create()
        DownloaderJobOriginalFileAssociation.objects.create(
            original_file=original_file, downloader_job=downloader_job
        )

        processor_job = ProcessorJob.objects.create(downloader_job=downloader_job)
        ProcessorJobOriginalFileAssociation.objects.create(
            original_file=original_file, processor_job=processor_job
        )

        self.call_command(experiments=("GSE12417", "GSE12418"))
        expected_lines = [
            "Experiments attempted: 2",
            "ARRAY_EXPRESS: 1, GEO: 1",
            "",
            "Samples attempted: 4",
            "ARRAY_EXPRESS: 2, GEO: 2",
            "",
            "Experiments available: 2",
            "ARRAY_EXPRESS: 1, GEO: 1",
            "",
            "Samples available: 3",
            "ARRAY_EXPRESS: 1, GEO: 2",
            "",
            "Total jobs: 3",
        ]
        mock_print.assert_called_once_with("\n".join(expected_lines))

0 comments on commit bc365b2

Please sign in to comment.