Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/fb-dia-1018' into fb-dia-953/str…
Browse files Browse the repository at this point in the history
…eam-results
  • Loading branch information
matt-bernstein committed Apr 9, 2024
2 parents 1985725 + 03db524 commit 5b4e508
Show file tree
Hide file tree
Showing 37 changed files with 976 additions and 1,415 deletions.
28 changes: 26 additions & 2 deletions .github/workflows/cicd-pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on:
push:
branches:
- master
pull_request_target:
pull_request:
types:
- opened
- synchronize
Expand All @@ -19,6 +19,29 @@ concurrency:

jobs:

details:
name: "Details"
runs-on: ubuntu-latest
outputs:
membership: ${{ steps.membership.outputs.membership }}
steps:
- name: Check user's membership
uses: actions/github-script@v7
id: membership
env:
ACTOR: ${{ github.actor }}
with:
github-token: ${{ secrets.GIT_PAT }}
script: |
const { owner } = context.repo;
const actor = process.env.ACTOR;
github.rest.orgs.getMembershipForUser({
org: owner,
username: actor,
})
.then(response => core.setOutput("membership", response.data))
.catch(reason => core.setOutput("membership", false));
build:
name: "Build"
uses: ./.github/workflows/docker-build.yml
Expand All @@ -32,9 +55,10 @@ jobs:

deploy:
name: "Deploy"
if: github.event_name == 'pull_request_target'
if: github.event_name == 'pull_request' && needs.details.outputs.membership
uses: ./.github/workflows/argocd-create-app.yml
needs:
- details
- build
with:
docker_image_version: ${{ needs.build.outputs.image_version }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docker-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
echo "image_version=$version" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3.2.0
uses: docker/setup-buildx-action@v3.3.0

- name: Login to DockerHub
uses: docker/[email protected]
Expand Down
88 changes: 88 additions & 0 deletions .github/workflows/draft-pr-converter.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
name: "Convert stale PR's to drafts"

on:
schedule:
- cron: '0 */1 * * *' # At minute 0 past every hour.
workflow_dispatch:

env:
DRAFT_PROTECT_LABEL: "draft-protect"

jobs:
convert_stale_prs:
runs-on: ubuntu-latest
timeout-minutes: 2
steps:
- uses: hmarr/[email protected]

- id: get_timestamp
name: Get timestamp
shell: bash
run: echo "ts=$(date -d '10 hours ago' +"%Y-%m-%dT%H:%M:%S")" >> $GITHUB_OUTPUT

- uses: octokit/[email protected]
name: Get PR older than 10 hours
id: get_stale_prs
env:
GITHUB_TOKEN: ${{ secrets.GIT_PAT }}
with:
query: |
{
search(query: "repo:${{ github.repository }} is:pr is:open draft:false -label:${{ env.DRAFT_PROTECT_LABEL }} updated:<=${{ steps.get_timestamp.outputs.ts }}", type: ISSUE, first: 100) {
issueCount
edges {
node {
... on PullRequest {
number
url
id
updatedAt
}
}
}
}
}
- name: Stale PRs data
run: "echo '${{ steps.get_stale_prs.outputs.data }}'"

- name: Convert to draft
id: mutation_step
shell: bash
env:
GIT_PAT: ${{ secrets.GIT_PAT }}
run: |
set -eux
echo '${{ steps.get_stale_prs.outputs.data }}' > /tmp/stale_pr.json
_pr_list=$(jq -r '.search.edges | map(.node.url) | join("\\n")' < /tmp/stale_pr.json)
if [ -n "$_pr_list" ]; then
echo "pr_list=$_pr_list" >> $GITHUB_OUTPUT
echo "exec=true" >> $GITHUB_OUTPUT
echo "$GIT_PAT" | gh auth login --with-token
for pr_id in $(jq -r '.search.edges[].node.id' < /tmp/stale_pr.json); do
gh api graphql -F id="${pr_id}" -f query='
mutation($id: ID!) {
convertPullRequestToDraft(input: { pullRequestId: $id }) {
pullRequest {
id
number
isDraft
}
}
}
'
done
fi
- name: Post to a Slack channel
id: slack
if: ${{ steps.mutation_step.outputs.exec == 'true' }}
uses: slackapi/[email protected]
with:
channel-id: 'C02LMULF4NA'
payload: '{ "type": "mrkdwn", "text": "${{ env.SLACK_MESSAGE }}" }'
env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_LSE_BOT_TOKEN }}
SLACK_MESSAGE: >-
*Drafted PR's*\n
${{ steps.mutation_step.outputs.pr_list }}
9 changes: 4 additions & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ on:
- 'tests/**'
- '.github/workflows/tests.yml'
- '**/requirements**'
- 'codecov.yml'
tags-ignore:
- '**'
pull_request:
Expand Down Expand Up @@ -86,19 +87,17 @@ jobs:
if: ${{ matrix.os != 'ubuntu-latest' || matrix.python-version != '3.11' }}
run: |
source $VENV
cd tests/
poetry run pytest --junitxml report.xml --cov=. -vv
poetry run pytest tests/ -vv
- name: Run tests with coverage
if: ${{ matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' }}
run: |
source $VENV
cd tests/
poetry run pytest --junitxml report.xml --cov=. -vv
poetry run pytest tests/ --cov=. --cov-report=xml -vv
- name: Upload to Codecov
if: ${{ matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' }}
uses: codecov/codecov-action@v4.1.0
uses: codecov/codecov-action@v4.2.0
with:
files: tests/coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,7 @@ pyrightconfig.json
# actions-hub
.github/actions-hub

.idea/
.idea/

# server config
server/.env
2 changes: 1 addition & 1 deletion Dockerfile.app
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
FROM python:3.11-slim

# Install git
RUN apt-get update && apt-get install -y git
RUN apt-get update && apt-get install -y git gcc

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE 1
Expand Down
50 changes: 26 additions & 24 deletions adala/agents/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import logging
from pydantic import BaseModel, Field, SkipValidation, field_validator, model_validator, SerializeAsAny
from pydantic import (
BaseModel,
Field,
SkipValidation,
field_validator,
model_validator,
SerializeAsAny,
)
from abc import ABC, abstractmethod
from typing import Any, Optional, List, Dict, Union, Tuple
from rich import print
Expand All @@ -9,7 +16,6 @@
from adala.environments.static_env import StaticEnvironment
from adala.runtimes.base import Runtime, AsyncRuntime
from adala.runtimes._openai import OpenAIChatRuntime
from adala.runtimes import GuidanceRuntime
from adala.skills._base import Skill
from adala.memories.base import Memory
from adala.skills.skillset import SkillSet, LinearSkillSet
Expand Down Expand Up @@ -53,24 +59,12 @@ class Agent(BaseModel, ABC):

memory: Memory = Field(default=None)
runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field(
default_factory=lambda: {
"default": GuidanceRuntime()
# 'openai': OpenAIChatRuntime(model='gpt-3.5-turbo'),
# 'llama2': LLMRuntime(
# llm_runtime_type=LLMRuntimeModelType.Transformers,
# llm_params={
# 'model': 'meta-llama/Llama-2-7b',
# 'device': 'cuda:0',
# }
# )
}
default_factory=lambda: {"default": OpenAIChatRuntime(model="gpt-3.5-turbo")}
)
default_runtime: str = "default"
teacher_runtimes: Dict[str, SerializeAsAny[Runtime]] = Field(
default_factory=lambda: {
"default": None
}
default_factory=lambda: {"default": None}
)
default_runtime: str = "default"
default_teacher_runtime: str = "default"

class Config:
Expand Down Expand Up @@ -121,9 +115,11 @@ def skills_validator(cls, v) -> SkillSet:
elif isinstance(v, list):
return LinearSkillSet(skills=v)
else:
raise ValueError(f"skills must be of type SkillSet or Skill, but received type {type(v)}")
raise ValueError(
f"skills must be of type SkillSet or Skill, but received type {type(v)}"
)

@field_validator('runtimes', mode='before')
@field_validator("runtimes", mode="before")
def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
"""
Validates and creates runtimes
Expand All @@ -136,7 +132,9 @@ def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
f"Runtime {runtime_name} must have a 'type' field to specify the runtime type."
)
type_name = runtime_value.pop("type")
runtime_value = Runtime.create_from_registry(type=type_name, **runtime_value)
runtime_value = Runtime.create_from_registry(
type=type_name, **runtime_value
)
out[runtime_name] = runtime_value
return out

Expand Down Expand Up @@ -209,9 +207,11 @@ def get_teacher_runtime(self, runtime: Optional[str] = None) -> Runtime:
raise ValueError(f'Teacher Runtime "{runtime}" not found.')
runtime = self.teacher_runtimes[runtime]
if not runtime:
raise ValueError(f"Teacher Runtime is requested, but it was not set."
f"Please provide a teacher runtime in the agent's constructor explicitly:"
f"agent = Agent(..., teacher_runtimes={{'default': OpenAIChatRuntime(model='gpt-4')}})")
raise ValueError(
f"Teacher Runtime is requested, but it was not set."
f"Please provide a teacher runtime in the agent's constructor explicitly:"
f"agent = Agent(..., teacher_runtimes={{'default': OpenAIChatRuntime(model='gpt-4')}})"
)
return runtime

def run(
Expand Down Expand Up @@ -269,7 +269,9 @@ async def arun(
# run on the environment until it is exhausted
while True:
try:
data_batch = await self.environment.get_data_batch(batch_size=runtime.batch_size)
data_batch = await self.environment.get_data_batch(
batch_size=runtime.batch_size
)
if data_batch.empty:
print_text("No more data in the environment. Exiting.")
break
Expand Down
1 change: 0 additions & 1 deletion adala/environments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ class Config:


class AsyncEnvironment(Environment, ABC):

@abstractmethod
async def initialize(self):
"""
Expand Down
Loading

0 comments on commit 5b4e508

Please sign in to comment.