diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml new file mode 100644 index 00000000..383e65cd --- /dev/null +++ b/.github/workflows/pylint.yml @@ -0,0 +1,23 @@ +name: Pylint + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pylint + - name: Analysing the code with pylint + run: | + pylint $(git ls-files '*.py') diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 00000000..70a0ec10 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,456 @@ +# This Pylint rcfile contains a best-effort configuration to uphold the +# best-practices and style described in the Google Python style guide: +# https://google.github.io/styleguide/pyguide.html +# +# Its canonical open-source location is: +# https://google.github.io/styleguide/pylintrc + +[MASTER] + +# Files or directories to be skipped. They should be base names, not paths. +ignore=docs,parallel_utils + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=no + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=4 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. 
Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=abstract-method, + apply-builtin, + arguments-differ, + attribute-defined-outside-init, + backtick, + bad-option-value, + basestring-builtin, + buffer-builtin, + c-extension-no-member, + consider-using-enumerate, + cmp-builtin, + cmp-method, + coerce-builtin, + coerce-method, + delslice-method, + div-method, + duplicate-code, + eq-without-hash, + execfile-builtin, + file-builtin, + filter-builtin-not-iterating, + fixme, + getslice-method, + global-statement, + hex-method, + idiv-method, + implicit-str-concat-in-sequence, + import-error, + import-self, + import-star-module-level, + inconsistent-return-statements, + input-builtin, + intern-builtin, + invalid-str-codec, + locally-disabled, + logging-fstring-interpolation, # added by vLLM + logging-not-lazy, # added by vLLM + long-builtin, + long-suffix, + map-builtin-not-iterating, + misplaced-comparison-constant, + missing-class-docstring, # TODO (vLLM): enable + 
+        missing-function-docstring,
+        missing-module-docstring, # TODO (vLLM): enable
+        metaclass-assignment,
+        next-method-called,
+        next-method-defined,
+        no-absolute-import,
+        no-else-break,
+        no-else-continue,
+        no-else-raise,
+        no-else-return,
+        no-init, # added
+        no-member,
+        no-name-in-module,
+        no-self-use,
+        nonzero-method,
+        oct-method,
+        old-division,
+        old-ne-operator,
+        old-octal-literal,
+        old-raise-syntax,
+        parameter-unpacking,
+        print-statement,
+        raising-string,
+        range-builtin-not-iterating,
+        raw_input-builtin,
+        rdiv-method,
+        reduce-builtin,
+        relative-import,
+        reload-builtin,
+        round-builtin,
+        setslice-method,
+        signature-differs,
+        standarderror-builtin,
+        suppressed-message,
+        sys-max-int,
+        too-few-public-methods,
+        too-many-ancestors,
+        too-many-arguments,
+        too-many-boolean-expressions,
+        too-many-branches,
+        too-many-instance-attributes,
+        too-many-locals,
+        too-many-nested-blocks,
+        too-many-public-methods,
+        too-many-return-statements,
+        too-many-statements,
+        trailing-newlines,
+        unichr-builtin,
+        unicode-builtin,
+        unnecessary-pass,
+        unpacking-in-except,
+        unspecified-encoding,
+        useless-else-on-loop,
+        useless-object-inheritance,
+        useless-suppression,
+        using-cmp-argument,
+        wrong-import-order,
+        xrange-builtin,
+        zip-builtin-not-iterating,
+        protected-access,
+        unsubscriptable-object,
+        invalid-all-object,
+        redefined-builtin,
+        consider-using-in,
+        inconsistent-quotes,
+        no-self-argument,
+        broad-except,
+        unused-variable,
+        unreachable,
+        unused-argument,
+        function-redefined,
+        use-a-generator,
+        unidiomatic-typecheck,
+        invalid-name,
+        singleton-comparison,
+        access-member-before-definition,
+        use-dict-literal,
+        consider-using-generator,
+        unnecessary-dunder-call,
+        import-outside-toplevel
+
+
+
+[REPORTS]
+
+# Set the output format. Available formats are text, parseable, colorized, msvs
+# (visual studio) and html. You can also give a reporter class, eg
+# mypackage.mymodule.MyReporterClass.
+output-format=text + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. 
+property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl + +# Regular expression matching correct function names +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct constant names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression matching correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression matching correct module names +module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ + +# Regular expression matching correct method names +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. 
+contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=180 + +# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt +# lines made too long by directives to pytype. + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x)( + ^\s*(\#\ )??$| + ^\s*(from\s+\S+\s+)?import\s+.+$) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=yes + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. The internal Google style guide mandates 2 +# spaces. Google's externaly-published style guide says 4, consistent with +# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google +# projects (like TensorFlow). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. 
empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=TODO + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=yes + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging,absl.logging,tensorflow.io.logging + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. 
+spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub, + TERMIOS, + Bastion, + rexec, + sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant, absl + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls, + class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. 
Defaults to +# "Exception" +overgeneral-exceptions=StandardError, + Exception, + BaseException diff --git a/benchmark/all_gather_benchmark.py b/benchmark/all_gather_benchmark.py index 01b18dae..1c9a8f22 100644 --- a/benchmark/all_gather_benchmark.py +++ b/benchmark/all_gather_benchmark.py @@ -5,7 +5,7 @@ from tqdm import tqdm import ray -import torch.backends.cudnn as cudnn +from torch.backends import cudnn import pandas as pd from benchmark.all_gather_wrapper import AllGatherWrapper @@ -99,7 +99,7 @@ def run_benchmark(): df = pd.DataFrame(all_results) # the time_stats column is a dict, so we need to expand it into columns recursively and add prefix - + df = pd.json_normalize(df["time_stats"]).add_prefix("time_stats.").join(df.drop(columns=["time_stats"])) # write results to a csv file diff --git a/benchmark/all_gather_wrapper.py b/benchmark/all_gather_wrapper.py index 9f556898..59bc7c84 100644 --- a/benchmark/all_gather_wrapper.py +++ b/benchmark/all_gather_wrapper.py @@ -89,7 +89,7 @@ def profile(self): self._run_all_reduce() TimerStatsStore.clear_stats() - + for _ in range(ACTIVE_STEPS): self._run_all_reduce() diff --git a/benchmark/all_reduce_benchmark.py b/benchmark/all_reduce_benchmark.py index 06a08f19..3668a22d 100644 --- a/benchmark/all_reduce_benchmark.py +++ b/benchmark/all_reduce_benchmark.py @@ -7,7 +7,7 @@ from tqdm import tqdm import ray -import torch.backends.cudnn as cudnn +from torch.backends import cudnn import pandas as pd from benchmark.all_reduce_wrapper import AllReduceWrapper @@ -38,7 +38,7 @@ NUM_TOKENS = \ list(range(0, 128, 8)) + \ list(range(128, 1536, 8)) + \ - list(range(1536, 98 * 1024, 256)) + list(range(1536, 98 * 1024, 256)) # + \ # list(range(98 * 1024, 196 * 1024, 512)) # NUM_TOKENS = ( @@ -94,7 +94,7 @@ def run_benchmark(): for n_workers in NUM_TENSOR_PARALLEL_WORKERS: params = itertools.product(EMBEDDING_DIMS, NUM_TOKENS) - + del runner_pool gc.collect() @@ -140,7 +140,7 @@ def run_benchmark(): df = pd.DataFrame(all_results) # 
the time_stats column is a dict, so we need to expand it into columns recursively and add prefix - + df = pd.json_normalize(df["time_stats"]).add_prefix("time_stats.").join(df.drop(columns=["time_stats"])) # write results to a csv file diff --git a/benchmark/all_reduce_wrapper.py b/benchmark/all_reduce_wrapper.py index f5414606..c18efd47 100644 --- a/benchmark/all_reduce_wrapper.py +++ b/benchmark/all_reduce_wrapper.py @@ -1,11 +1,12 @@ import torch +import os import numpy as np from benchmark.cuda_timer import CudaTimer from benchmark.timer_stats_store import TimerStatsStore -from vllm.all_reduce_ops import init_nccl, all_reduce +#from vllm.all_reduce_ops import init_nccl, all_reduce WARMUP_STEPS = 5 @@ -85,7 +86,6 @@ def _init_communication(self, comm_id): if torch.distributed.is_initialized(): return - import os print(f"Rank: {self._rank}, num_workers: {self._num_workers}, comm_id: {comm_id}") print("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"]) os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" diff --git a/benchmark/mixed_attention_benchmark.py b/benchmark/mixed_attention_benchmark.py index 03ba8060..8bf20c2d 100644 --- a/benchmark/mixed_attention_benchmark.py +++ b/benchmark/mixed_attention_benchmark.py @@ -3,7 +3,7 @@ import os import ray -import torch.backends.cudnn as cudnn +from torch.backends import cudnn import pandas as pd from tqdm import tqdm diff --git a/benchmark/mixed_attention_wrapper.py b/benchmark/mixed_attention_wrapper.py index 18dcfd18..7168c1ab 100644 --- a/benchmark/mixed_attention_wrapper.py +++ b/benchmark/mixed_attention_wrapper.py @@ -3,7 +3,7 @@ import torch -from benchmark.cuda_timer import CudaTimer +#from benchmark.cuda_timer import CudaTimer from benchmark.timer_stats_store import TimerStatsStore from benchmark.vllm_attention import PagedAttentionWithRoPE, InputMetadata @@ -57,7 +57,7 @@ def __init__( self._attn_base = 10000 # default from vllm self._max_position = 8192 # default from vllm - + self.attn = 
PagedAttentionWithRoPE( self._n_worker_q_heads, self._head_dim, @@ -68,7 +68,7 @@ def __init__( num_kv_heads=self._n_worker_kv_heads, ).to(dtype=torch.float16).cuda().eval() # .to(dtype=torch.float16).cuda() - + self._blocks_per_sequence = ceil(max_context_len / block_size) self._total_num_blocks = max(10000, 1 + batch_size * self._blocks_per_sequence) self._k_cache_split_factor = 16 // torch.tensor([], dtype=torch.float16).element_size() diff --git a/benchmark/mlp_benchmark.py b/benchmark/mlp_benchmark.py index 31c1ebda..ee52ee62 100644 --- a/benchmark/mlp_benchmark.py +++ b/benchmark/mlp_benchmark.py @@ -4,7 +4,7 @@ import ray import torch -import torch.backends.cudnn as cudnn +from torch.backends import cudnn import pandas as pd from tqdm import tqdm @@ -150,7 +150,7 @@ def run_benchmark(): NUM_TOKENS, NUM_TENSOR_PARALLEL_WORKERS, ) - + for ( num_tokens, num_tensor_parallel_workers, diff --git a/benchmark/nanogpt.py b/benchmark/nanogpt.py index ebfc52c5..00093468 100644 --- a/benchmark/nanogpt.py +++ b/benchmark/nanogpt.py @@ -7,13 +7,13 @@ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py """ -import math -import inspect +#import math +#import inspect from dataclasses import dataclass from math import ceil import torch -import torch.nn as nn +from torch import nn from torch.nn import functional as F from benchmark.cuda_timer import CudaTimer @@ -86,7 +86,7 @@ def __init__(self, config): self.n_head = config.n_head self.n_embd = config.n_embd - + self.attn_pre_proj_timer = CudaTimer("attn_pre_proj") self.attn_post_proj_timer = CudaTimer("attn_post_proj") @@ -125,11 +125,11 @@ def __init__(self, config): self.mlp_up_proj_timer = CudaTimer("mlp_up_proj") self.mlp_act_timer = CudaTimer("mlp_act") self.mlp_down_proj_timer = CudaTimer("mlp_down_proj") - + def forward(self, x): with self.mlp_up_proj_timer: x = self.c_fc(x) - + with self.mlp_act_timer: x = self.act(x) @@ -153,9 +153,9 @@ def __init__(self, config): 
self.emb_timer = CudaTimer("emb") self.layer_norm_timer = CudaTimer("layer_norm") self.rms_norm_timer = CudaTimer("rms_norm") - self.deemb_timer = CudaTimer("deemb") + self.deemb_timer = CudaTimer("deemb") self.add_norm_timer = CudaTimer("add_norm") - + def forward(self, x): with self.emb_timer: x = self.emb(x) diff --git a/benchmark/offload_benchmark.py b/benchmark/offload_benchmark.py index b0277ff8..0ae470ac 100644 --- a/benchmark/offload_benchmark.py +++ b/benchmark/offload_benchmark.py @@ -4,7 +4,7 @@ import random import ray -import torch.backends.cudnn as cudnn +from torch.backends import cudnn import pandas as pd from benchmark.offload_wrapper import OffloadWrapper @@ -105,7 +105,7 @@ def run_benchmark(): df = pd.DataFrame(all_results) # the time_stats column is a dict, so we need to expand it into columns recursively and add prefix - + df = pd.json_normalize(df["time_stats"]).add_prefix("time_stats.").join(df.drop(columns=["time_stats"])) # write results to a csv file diff --git a/benchmark/p2p_benchmark.py b/benchmark/p2p_benchmark.py index b5ece0b5..26a63b1e 100644 --- a/benchmark/p2p_benchmark.py +++ b/benchmark/p2p_benchmark.py @@ -5,7 +5,7 @@ from tqdm import tqdm import ray -import torch.backends.cudnn as cudnn +from torch.backends import cudnn import pandas as pd from benchmark.p2p_wrapper import P2PWrapper @@ -36,7 +36,7 @@ NUM_TOKENS = \ list(range(0, 128, 1)) + \ list(range(128, 1536, 4)) + \ - slist(range(1536, 98 * 1024, 256)) + list(range(1536, 98 * 1024, 256)) # + \ # list(range(98 * 1024, 196 * 1024, 512)) NUM_TENSOR_PARALLEL_WORKERS = [1, 2, 4] @@ -103,11 +103,11 @@ def run_benchmark(): if result and rank == 0: all_results.append(result) if not result: - runner_pool[rank] = ModelRunner.remote() + runner_pool[rank] = ModelRunner.remote() df = pd.DataFrame(all_results) # the time_stats column is a dict, so we need to expand it into columns recursively and add prefix - + df = 
pd.json_normalize(df["time_stats"]).add_prefix("time_stats.").join(df.drop(columns=["time_stats"])) # write results to a csv file diff --git a/benchmark/p2p_wrapper.py b/benchmark/p2p_wrapper.py index ba1ed394..1e74818d 100644 --- a/benchmark/p2p_wrapper.py +++ b/benchmark/p2p_wrapper.py @@ -1,4 +1,5 @@ import torch +import numpy as np from benchmark.cuda_timer import CudaTimer from benchmark.timer_stats_store import TimerStatsStore @@ -78,7 +79,6 @@ def __init__(self, rank, comm_id, n_embd, num_tokens, num_tensor_parallel_worker self._cuda_timer = CudaTimer("send_recv", aggregation_fn=np.median, filter_str="ncclKernel") - def _init_communication(self, rank, comm_id): # skip if already initialized if torch.distributed.is_initialized(): @@ -92,7 +92,7 @@ def _init_communication(self, rank, comm_id): # init_method=f"tcp://node-0:{comm_id}", init_method=f"file:///tmp/sing_comm_{comm_id}", ) - print(f"Initialized process group.") + print("Initialized process group.") def _run_send_recv(self): torch.cuda.synchronize() diff --git a/benchmark/vllm_attention.py b/benchmark/vllm_attention.py index 243f9bef..a0ae1436 100644 --- a/benchmark/vllm_attention.py +++ b/benchmark/vllm_attention.py @@ -1,8 +1,8 @@ """Multi-head attention.""" -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch -import torch.nn as nn +from torch import nn from xformers import ops as xops from xformers.ops.fmha.attn_bias import ( BlockDiagonalCausalFromBottomRightMask, diff --git a/simulator/config/config.py b/simulator/config/config.py index e454ac46..0ccbc2b3 100644 --- a/simulator/config/config.py +++ b/simulator/config/config.py @@ -1,6 +1,6 @@ import argparse import datetime -import hashlib +#import hashlib import os import yaml diff --git a/simulator/entities/batch.py b/simulator/entities/batch.py index 7f35ba15..7169fad9 100644 --- a/simulator/entities/batch.py +++ b/simulator/entities/batch.py @@ -99,7 +99,7 @@ def request_ids(self) -> List[int]: return 
[request.id for request in self._requests] @property - def completed(self) -> bool: + def allcompleted(self) -> bool: return all([request.completed for request in self._requests]) def on_schedule( diff --git a/simulator/entities/cluster.py b/simulator/entities/cluster.py index 3bed0e66..eee70abb 100644 --- a/simulator/entities/cluster.py +++ b/simulator/entities/cluster.py @@ -38,4 +38,5 @@ def _write_cluster_info_to_file(self) -> None: cluster_info = {"replicas": replica_dicts} cluster_file = f"{self._config.output_dir}/cluster.json" - json.dump(cluster_info, open(cluster_file, "w")) + with open(cluster_file, "w") as fd: + json.dump(cluster_info, fd) diff --git a/simulator/events/batch_end_event.py b/simulator/events/batch_end_event.py index 3e416efa..83115f4b 100644 --- a/simulator/events/batch_end_event.py +++ b/simulator/events/batch_end_event.py @@ -6,6 +6,7 @@ from simulator.plotting import MetricsStore from simulator.scheduler import BaseGlobalScheduler from simulator.types import EventType +from simulator.events.replica_schedule_event import ReplicaScheduleEvent logger = logging.getLogger(__name__) @@ -23,7 +24,6 @@ def event_type(self): def handle_event( self, scheduler: BaseGlobalScheduler, metrics_store: MetricsStore ) -> List[BaseEvent]: - from simulator.events.replica_schedule_event import ReplicaScheduleEvent self._batch.on_batch_end(self.time) replica_scheduler = scheduler.get_replica_scheduler(self._replica_id) diff --git a/simulator/events/batch_stage_arrival_event.py b/simulator/events/batch_stage_arrival_event.py index 3c9b4053..a78ee1d3 100644 --- a/simulator/events/batch_stage_arrival_event.py +++ b/simulator/events/batch_stage_arrival_event.py @@ -6,6 +6,9 @@ from simulator.plotting import MetricsStore from simulator.scheduler import BaseGlobalScheduler from simulator.types import EventType +from simulator.events.replica_stage_schedule_event import ( + ReplicaStageScheduleEvent, +) logger = logging.getLogger(__name__) @@ -25,9 +28,6 @@ def 
event_type(self): def handle_event( self, scheduler: BaseGlobalScheduler, metrics_store: MetricsStore ) -> List[BaseEvent]: - from simulator.events.replica_stage_schedule_event import ( - ReplicaStageScheduleEvent, - ) scheduler.get_replica_stage_scheduler( self._replica_id, self._stage_id diff --git a/simulator/events/batch_stage_end_event.py b/simulator/events/batch_stage_end_event.py index d1323c43..c8eb83a5 100644 --- a/simulator/events/batch_stage_end_event.py +++ b/simulator/events/batch_stage_end_event.py @@ -42,7 +42,6 @@ def handle_event( from simulator.events.replica_stage_schedule_event import ( ReplicaStageScheduleEvent, ) - scheduler.get_replica_stage_scheduler( self._replica_id, self._stage_id ).on_stage_end() diff --git a/simulator/events/global_schedule_event.py b/simulator/events/global_schedule_event.py index 07b66c70..b9cf8fc7 100644 --- a/simulator/events/global_schedule_event.py +++ b/simulator/events/global_schedule_event.py @@ -5,6 +5,7 @@ from simulator.plotting import MetricsStore from simulator.scheduler import BaseGlobalScheduler from simulator.types import EventType +from simulator.events.replica_schedule_event import ReplicaScheduleEvent logger = logging.getLogger(__name__) @@ -22,7 +23,6 @@ def event_type(self): def handle_event( self, scheduler: BaseGlobalScheduler, metrics_store: MetricsStore ) -> List[BaseEvent]: - from simulator.events.replica_schedule_event import ReplicaScheduleEvent self._replica_set = set() self._request_mapping = scheduler.schedule() diff --git a/simulator/events/replica_schedule_event.py b/simulator/events/replica_schedule_event.py index 2fb940b4..d30efd10 100644 --- a/simulator/events/replica_schedule_event.py +++ b/simulator/events/replica_schedule_event.py @@ -5,6 +5,7 @@ from simulator.plotting import MetricsStore from simulator.scheduler import BaseGlobalScheduler from simulator.types import EventType +from simulator.events.batch_stage_arrival_event import BatchStageArrivalEvent logger = 
logging.getLogger(__name__) @@ -24,7 +25,6 @@ def event_type(self): def handle_event( self, scheduler: BaseGlobalScheduler, metrics_store: MetricsStore ) -> List[BaseEvent]: - from simulator.events.batch_stage_arrival_event import BatchStageArrivalEvent replica_scheduler = scheduler.get_replica_scheduler(self._replica_id) self._batches = replica_scheduler.on_schedule() diff --git a/simulator/events/replica_stage_schedule_event.py b/simulator/events/replica_stage_schedule_event.py index 4f797cf9..e7d7e809 100644 --- a/simulator/events/replica_stage_schedule_event.py +++ b/simulator/events/replica_stage_schedule_event.py @@ -5,6 +5,7 @@ from simulator.plotting import MetricsStore from simulator.scheduler import BaseGlobalScheduler from simulator.types import EventType +from simulator.events.batch_stage_end_event import BatchStageEndEvent logger = logging.getLogger(__name__) @@ -27,7 +28,6 @@ def event_type(self): def handle_event( self, scheduler: BaseGlobalScheduler, metrics_store: MetricsStore ) -> List[BaseEvent]: - from simulator.events.batch_stage_end_event import BatchStageEndEvent stage_scheduler = scheduler.get_replica_stage_scheduler( self._replica_id, self._stage_id diff --git a/simulator/events/request_arrival_event.py b/simulator/events/request_arrival_event.py index 855529fc..68135980 100644 --- a/simulator/events/request_arrival_event.py +++ b/simulator/events/request_arrival_event.py @@ -6,6 +6,7 @@ from simulator.plotting import MetricsStore from simulator.scheduler import BaseGlobalScheduler from simulator.types import EventType +from simulator.events.global_schedule_event import GlobalScheduleEvent logger = logging.getLogger(__name__) @@ -22,7 +23,6 @@ def event_type(self) -> EventType: def handle_event( self, scheduler: BaseGlobalScheduler, metrics_store: MetricsStore ) -> List[BaseEvent]: - from simulator.events.global_schedule_event import GlobalScheduleEvent logger.debug(f"Request: {self._request.id} arrived at {self.time}") 
scheduler.add_request(self._request) diff --git a/simulator/execution_time_predictor/base_execution_time_predictor.py b/simulator/execution_time_predictor/base_execution_time_predictor.py index 4bd1dcb1..28f094aa 100644 --- a/simulator/execution_time_predictor/base_execution_time_predictor.py +++ b/simulator/execution_time_predictor/base_execution_time_predictor.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List +#from typing import List from simulator.config import Config from simulator.entities import Batch diff --git a/simulator/execution_time_predictor/sklearn_execution_time_predictor.py b/simulator/execution_time_predictor/sklearn_execution_time_predictor.py index b8759a1a..a94ed9fc 100644 --- a/simulator/execution_time_predictor/sklearn_execution_time_predictor.py +++ b/simulator/execution_time_predictor/sklearn_execution_time_predictor.py @@ -183,7 +183,8 @@ def _load_model_from_cache(self, model_name: str) -> BaseEstimator: return logger.info(f"Found model {model_name} in cache") - model = pickle.load(open(cache_file, "rb")) + with open(cache_file, "rb") as fd: + model = pickle.load(fd) return model def _store_model_in_cache(self, model_name: str, model: BaseEstimator) -> None: @@ -191,7 +192,8 @@ def _store_model_in_cache(self, model_name: str, model: BaseEstimator) -> None: # store model in cache cache_file = f"{self._cache_dir}/{model_name_hash}.pkl" - pickle.dump(model, open(cache_file, "wb")) + with open(cache_file, "wb") as fd: + pickle.dump(model, fd) def _store_training_prediction_data( self, @@ -264,10 +266,12 @@ def _store_model_predication_cache( model_name_hash = self._get_model_name_hash(model_name) cache_file = f"{self._cache_dir}/{model_name_hash}_predictions.pkl" json_file = f"{self._cache_dir}/{model_name}_{model_name_hash}_predictions.json" - pickle.dump(predictions, open(cache_file, "wb")) + with open(cache_file, "wb") as fd: + pickle.dump(predictions, fd) # convert keys from tuple to string 
json_serializable_predictions = {str(x): y for x, y in predictions.items()} - json.dump(json_serializable_predictions, open(json_file, "w")) + with open(json_file, "w") as fd: + json.dump(json_serializable_predictions, fd) def _load_model_predication_cache(self, model_name: str) -> Dict[Tuple, float]: if self._no_cache: @@ -279,7 +283,8 @@ def _load_model_predication_cache(self, model_name: str) -> Dict[Tuple, float]: if not os.path.exists(cache_file): return - predictions = pickle.load(open(cache_file, "rb")) + with open(cache_file, "rb") as fd: + predictions = pickle.load(fd) return predictions def _get_model_prediction( diff --git a/simulator/plotting/data_series.py b/simulator/plotting/data_series.py index 8a055bff..f1f8522c 100644 --- a/simulator/plotting/data_series.py +++ b/simulator/plotting/data_series.py @@ -18,7 +18,7 @@ def __init__( save_table_to_wandb: bool = True, ) -> None: # metrics are a data series of two-dimensional (x, y) datapoints - self._data_series = list() + self._data_series = [] # column names of x, y datatpoints for data collection self._x_name = x_name self._y_name = y_name diff --git a/simulator/plotting/metrics_store.py b/simulator/plotting/metrics_store.py index 939f4a87..2b14c69f 100644 --- a/simulator/plotting/metrics_store.py +++ b/simulator/plotting/metrics_store.py @@ -77,7 +77,7 @@ def __init__(self, config: Config): metric_name.value, self._subsamples, self._save_table_to_wandb, - ) + ) self._req_metrics_histogram: Dict[RequestMetricsHistogram, DataSeries] = {} for metric_name in RequestMetricsHistogram: diff --git a/simulator/request_generator/base_request_generator.py b/simulator/request_generator/base_request_generator.py index 4c3e1dff..b5e6a8f2 100644 --- a/simulator/request_generator/base_request_generator.py +++ b/simulator/request_generator/base_request_generator.py @@ -14,7 +14,8 @@ def __init__(self, config: Config): def _write_requests_to_file(self, requests: List[Request]) -> None: request_dicts = 
[request.to_dict() for request in requests] request_file = f"{self._config.output_dir}/requests.json" - json.dump(request_dicts, open(request_file, "w")) + with open(request_file, "w") as fd: + json.dump(request_dicts, fd) @abstractmethod def generate_requests(self) -> List[Request]: diff --git a/simulator/request_generator/gamma_request_interval_generator.py b/simulator/request_generator/gamma_request_interval_generator.py index e4161651..d8b8334a 100644 --- a/simulator/request_generator/gamma_request_interval_generator.py +++ b/simulator/request_generator/gamma_request_interval_generator.py @@ -1,4 +1,4 @@ -import random +#import random from scipy.stats import gamma diff --git a/simulator/request_generator/zipf_request_length_generator.py b/simulator/request_generator/zipf_request_length_generator.py index 45285c23..1dff8d8e 100644 --- a/simulator/request_generator/zipf_request_length_generator.py +++ b/simulator/request_generator/zipf_request_length_generator.py @@ -1,4 +1,4 @@ -import random +#import random from typing import Tuple from simulator.request_generator.base_request_length_generator import ( diff --git a/simulator/scheduler/replica_scheduler/faster_transformer_replica_scheduler.py b/simulator/scheduler/replica_scheduler/faster_transformer_replica_scheduler.py index 8adcf468..d29fa4f1 100644 --- a/simulator/scheduler/replica_scheduler/faster_transformer_replica_scheduler.py +++ b/simulator/scheduler/replica_scheduler/faster_transformer_replica_scheduler.py @@ -1,6 +1,7 @@ -from typing import List +#from typing import List -from simulator.entities.batch import Batch, Request +#from simulator.entities.batch import Batch, Request +from simulator.entities.batch import Batch from simulator.scheduler.replica_scheduler.base_replica_scheduler import ( BaseReplicaScheduler, ) diff --git a/simulator/simulator.py b/simulator/simulator.py index 2e5a3efa..6324b0ca 100644 --- a/simulator/simulator.py +++ b/simulator/simulator.py @@ -105,11 +105,13 @@ def 
_set_time(self, time: float) -> None: def _write_event_trace(self) -> None: trace_file = f"{self._config.output_dir}/event_trace.json" - json.dump(self._event_trace, open(trace_file, "w")) + with open(trace_file, "w") as fd: + json.dump(self._event_trace, fd) def _write_chrome_trace(self) -> None: trace_file = f"{self._config.output_dir}/chrome_trace.json" chrome_trace = {"traceEvents": self._event_chrome_trace} - json.dump(chrome_trace, open(trace_file, "w")) + with open(trace_file, "w") as fd: + json.dump(chrome_trace, fd)