-
Notifications
You must be signed in to change notification settings - Fork 2
/
inference_claude.py
95 lines (73 loc) · 3.48 KB
/
inference_claude.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import re
import os
import sys
import csv
import time
import json
import argparse
import pandas as pd
import anthropic
# Dataset whose prompts are sent to the model.
# Alternative (currently disabled): dataset = "macro_indicator"
dataset = "firm_news"
data_path = f"./data/prompt/{dataset}.json"
# Read the API key from the environment instead of hard-coding it in source.
# The previous placeholder string ("your_api_key") was passed explicitly and
# therefore overrode the SDK's own ANTHROPIC_API_KEY lookup, so requests could
# not authenticate unless this file was edited to contain a real secret.
client = anthropic.Anthropic(
    api_key=os.environ.get("ANTHROPIC_API_KEY"),
)
# Function to load JSON data
def load_prompts_from_json(file_path):
    """Load and return the parsed contents of the JSON file at *file_path*.

    In this script the result is iterated with ``.items()`` as
    ``date -> prompt`` pairs; the function itself returns whatever
    ``json.load`` produces.

    The file is decoded as UTF-8 explicitly. JSON interchange text is UTF-8,
    and relying on the platform default encoding (e.g. cp1252 on Windows)
    can fail on, or silently mangle, non-ASCII characters in the prompts.
    """
    with open(file_path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)
# Function to extract prediction
def extract_prediction(text):
    """Map the first sentiment label found in *text* to an integer score.

    Scores run from +3 ("Strongly Bullish") down to -3 ("Strongly Bearish");
    "Flat" and "Fluctuating" both score 0. Matching is case-insensitive, so
    "bullish" or "BULLISH" scores the same as "Bullish".

    Returns:
        The score of the leftmost label occurring in *text*, or None if no
        label occurs.
    """
    # Keys are casefolded so the lookup agrees with the case-insensitive regex.
    prediction_mapping = {
        'strongly bullish': 3,
        'bullish': 2,
        'slightly bullish': 1,
        'flat': 0,
        'fluctuating': 0,
        'slightly bearish': -1,
        'bearish': -2,
        'strongly bearish': -3,
    }
    # re.search scans left to right, so for "Slightly Bullish" a match starting
    # at "Slightly" is found before the bare "Bullish" position is tried;
    # alternation order therefore does not let the shorter label win.
    match = re.search(r'\b(Strongly Bullish|Bullish|Slightly Bullish|Flat|Fluctuating|Slightly Bearish|Bearish|Strongly Bearish)\b', text, re.IGNORECASE)
    if match is None:
        return None
    # casefold() rather than lower(): Unicode IGNORECASE can match characters
    # such as the long s, which only casefold() normalises back to ASCII.
    return prediction_mapping.get(match.group(1).casefold())
# Exact option list embedded in every prompt; each shuffle variant replaces it
# with a permuted ordering to probe sensitivity to option order.
_ORIGINAL_OPTIONS = "Only output: Strongly Bullish, Bullish, Slightly Bullish, Flat, Fluctuating, Slightly Bearish, Bearish, Strongly Bearish"
_SHUFFLED_OPTIONS = {
    "shuffle1": "Only output: Flat, Slightly Bearish, Bullish, Bearish, Fluctuating, Slightly Bullish, Strongly Bullish, Strongly Bearish",
    "shuffle2": "Only output: Strongly Bearish, Flat, Slightly Bearish, Slightly Bullish, Fluctuating, Bullish, Bearish, Strongly Bullish",
    "shuffle3": "Only output: Slightly Bearish, Strongly Bullish, Bearish, Bullish, Slightly Bullish, Fluctuating, Flat, Strongly Bearish",
    "shuffle4": "Only output: Bearish, Strongly Bullish, Fluctuating, Slightly Bearish, Slightly Bullish, Strongly Bearish, Bullish, Flat",
}
_MODEL = "claude-3-5-sonnet-20240620"


def _load_done_dates(csv_path):
    """Return the set of ``Date`` values already recorded in *csv_path*.

    Returns an empty set when the file does not exist or is empty. Used so a
    re-run resumes an interrupted run instead of appending duplicate rows.
    """
    if not os.path.isfile(csv_path) or os.path.getsize(csv_path) == 0:
        return set()
    with open(csv_path, newline='', encoding='utf-8') as existing:
        return {row["Date"] for row in csv.DictReader(existing) if row.get("Date")}


def main(shuffle="shuffle1"):
    """Query the model for every prompt in ``data_path`` and append results to a CSV.

    Each prompt's option list is replaced by the ordering named *shuffle*
    (a key of ``_SHUFFLED_OPTIONS``; the default reproduces the original
    "shuffle1" run). Rows (Date, Response, Prediction) are appended to
    ``./data/output/<shuffle>/<dataset>_<model>_greedy.csv``; dates already
    present in that file are skipped, so an interrupted run can be resumed.

    Raises:
        KeyError: if *shuffle* is not a known variant.
    """
    shuffled_options = _SHUFFLED_OPTIONS[shuffle]
    output_csv_file = f"./data/output/{shuffle}/{dataset}_{_MODEL}_greedy.csv"
    # Create the output directory instead of assuming it already exists.
    os.makedirs(os.path.dirname(output_csv_file), exist_ok=True)

    prompts_dict = load_prompts_from_json(data_path)
    done_dates = _load_done_dates(output_csv_file)

    fieldnames = ['Date', 'Response', 'Prediction']
    # Explicit UTF-8: model responses may contain non-ASCII text that the
    # platform default encoding (e.g. cp1252 on Windows) cannot encode.
    with open(output_csv_file, 'a', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if csvfile.tell() == 0:  # Empty file: write the header once.
            writer.writeheader()
        for date, prompt in prompts_dict.items():
            if date in done_dates:
                continue
            if _ORIGINAL_OPTIONS not in prompt:
                # str.replace would silently do nothing and the prompt would be
                # sent with the unshuffled option order; make that visible.
                print(f"warning: option list not found in prompt for {date}; sending it unshuffled", file=sys.stderr)
            prompt = prompt.replace(_ORIGINAL_OPTIONS, shuffled_options)
            message = client.messages.create(
                model=_MODEL,
                max_tokens=2048,
                temperature=0.0,
                system="",
                messages=[
                    {"role": "user", "content": prompt},
                ],
            )
            # Join all text blocks instead of assuming content[0] exists and is text.
            response = "".join(block.text for block in message.content if block.type == "text")
            prediction = extract_prediction(response)
            writer.writerow({'Date': date, 'Response': response, 'Prediction': prediction})
            csvfile.flush()  # Persist each row immediately so progress survives a crash.


if __name__ == "__main__":
    _parser = argparse.ArgumentParser(description="Run Claude inference on prompts with a shuffled option order.")
    _parser.add_argument("--shuffle", default="shuffle1", choices=sorted(_SHUFFLED_OPTIONS))
    main(_parser.parse_args().shuffle)