Skip to content

Commit 9143767

Browse files
authored
feat: visualization generation instruction and fewshots (#19)
1 parent 0e94669 commit 9143767

File tree

9 files changed

+164
-2
lines changed

9 files changed

+164
-2
lines changed

coml/core.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
FIX_INSTRUCTION,
1919
GENERATE_INSTRUCTION,
2020
GENERATE_INSTRUCTION_COT,
21+
GENERATE_INSTRUCTION_VIS_MATPLOTLIB,
22+
GENERATE_INSTRUCTION_VIS_SEABORN,
2123
SANITY_CHECK_INSTRUCTION,
2224
SUGGEST_INSTRUCTION,
2325
FixContext,
@@ -125,7 +127,9 @@ class CoMLAgent:
125127
def __init__(
126128
self,
127129
llm: BaseChatModel,
128-
prompt_version: Literal["v1", "v2", "kaggle", "leetcode"] = "v2",
130+
prompt_version: Literal[
131+
"v1", "v2", "kaggle", "leetcode", "matplotlib", "seaborn"
132+
] = "v2",
129133
prompt_validation: Callable[[list[BaseMessage]], bool] | None = None,
130134
num_examples: float | int = 1.0,
131135
message_style: Literal["chatgpt", "gemini"] = "chatgpt",
@@ -298,6 +302,12 @@ def generate_code(
298302
shot["answer"] = shot.pop("answer_wo_intact")
299303
if "rationale_wo_intact" in shot:
300304
shot["rationale"] = shot.pop("rationale_wo_intact")
305+
306+
if self.prompt_version == "matplotlib":
307+
generate_instruction = GENERATE_INSTRUCTION_VIS_MATPLOTLIB
308+
elif self.prompt_version == "seaborn":
309+
generate_instruction = GENERATE_INSTRUCTION_VIS_SEABORN
310+
301311
messages.append(SystemMessage(content=generate_instruction))
302312

303313
for shot in self._select_examples(request, fewshots):

coml/prompt_utils.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,18 @@ def render_sanity_check_context(
425425
- Think before you write. You should first understand the user's request, think about how to achieve it, and then write the code.
426426
"""
427427

428+
GENERATE_INSTRUCTION_VIS_MATPLOTLIB = f"""You're a helpful assistant proficient in writing Python code for data visualization. Upon receiving relevant context, such as available variables and any pre-executed code, your goal is to complete the Python code to generate a visualization that meets the user's request.
429+
430+
Instructions:
431+
432+
- Utilize the `matplotlib` library to create the visualization and make sure to include `plt.show()` to display the chart.
433+
- You must return the generated code wrapped by ``` before and after it, and do not add any explanation.
434+
"""
435+
436+
GENERATE_INSTRUCTION_VIS_SEABORN = GENERATE_INSTRUCTION_VIS_MATPLOTLIB.replace(
437+
"matplotlib", "seaborn"
438+
)
439+
428440
FIX_INSTRUCTION = f"""{GENERATE_INSTRUCTION.rstrip()}
429441
- If the user thinks the generated code is problematic, you should help fix it. The user will provide you with the exception message (if any), the output of the code (if any), and a hint (if any). You should provide a line-by-line explanation of the code, and point out what is wrong with the code. You should also provide the fixed code.
430442
- If you think the provided problematic code is actually correct, you should first explain the code, and write "THE CODE IS CORRECT." (in upper case) in the observation section. The fixed code can be omitted.
@@ -459,7 +471,20 @@ def cached_generate_fewshots(prompt_version: str) -> list[GenerateContext]:
459471
with open(
460472
Path(__file__).parent / f"prompts/generate_fewshots_{prompt_version}.json"
461473
) as f:
462-
return json.load(f)
474+
fewshots = json.load(f)
475+
for shot in fewshots:
476+
variables = {}
477+
if "datasets" in shot:
478+
for name in shot["datasets"]:
479+
dataset = pd.read_csv(
480+
str(Path(__file__).parent / f"prompts/dataset/{name}.csv")
481+
)
482+
# todo: dataframe_format
483+
variables[f"{name.split('/')[1]}_dataset"] = describe_variable(
484+
dataset
485+
)
486+
shot["variables"] = variables
487+
return fewshots
463488

464489

465490
def cached_fix_fewshots() -> list[FixContext]:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
Member_ID,Name,Membership_card,Age,Time_of_purchase,Level_of_membership,Address
2+
1,"Ashby, Lazale",Black,29,18,5,Hartford
3+
2,"Breton, Robert",White,67,41,4,Waterbury
4+
3,"Campbell, Jessie",Black,34,20,6,Hartford
5+
4,"Cobb, Sedrick",Black,51,27,2,Waterbury
6+
5,"Hayes, Steven",White,50,44,3,Cheshire
7+
6,"Komisarjevsky, Joshua",White,33,26,2,Cheshire
8+
7,"Peeler, Russell",Black,42,26,6,Bridgeport
9+
8,"Reynolds, Richard",Black,45,24,1,Waterbury
10+
9,"Rizzo, Todd",White,35,18,4,Waterbury
11+
10,"Webb, Daniel",Black,51,27,22,Hartford
+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
stadium_id,id,Season,Date,Home_team,Away_team,Score,Competition
2+
1,1,2007,18 May 2007,Quruvchi,Pakhtakor,1–1,League
3+
2,2,2007,22 September 2007,Pakhtakor,Quruvchi,0–0,League
4+
3,3,2007,9 December 2007,Pakhtakor,Quruvchi,0–0 (7:6),Cup
5+
4,4,2008,10 July 2008,Pakhtakor,Quruvchi,1–1,League
6+
5,5,2008,16 August 2008,Bunyodkor,Pakhtakor,1–1,League
7+
6,6,2008,31 October 2008,Bunyodkor,Pakhtakor,3–1,Cup
8+
7,7,2009,12 July 2009,Bunyodkor,Pakhtakor,2–1,League
9+
8,8,2009,14 October 2009,Pakhtakor,Bunyodkor,0–0,League
10+
9,9,2009,8 August 2009,Pakhtakor,Bunyodkor,1–0,Cup
11+
10,10,2010,14 March 2010,Bunyodkor,Pakhtakor,2–1,League
12+
10,11,2010,31 October 2010,Pakhtakor,Bunyodkor,0–0,League
13+
10,12,2011,7 July 2011,Pakhtakor,Bunyodkor,0–0,League
14+
1,13,2011,21 August 2011,Bunyodkor,Pakhtakor,2–1,League
15+
2,14,2012,11 March 2012,Bunyodkor,Pakhtakor,–,Supercup
16+
3,15,2012,26 June 2012,Bunyodkor,Pakhtakor,2–0,League
17+
4,16,2012,9 August 2012,Pakhtakor,Bunyodkor,1–1,League
18+
5,17,2012,22 August 2012,Bunyodkor,Pakhtakor,1–1,Cup
19+
11,18,2012,25 November 2012,Pakhtakor,Bunyodkor,1–3,Cup
20+
12,19,2013,30 June 2013,Pakhtakor,Bunyodkor,0–2,League
21+
7,20,2013,8 August 2013,Bunyodkor,Pakhtakor,1–2,League
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Pilot_ID,Pilot_name,Rank,Age,Nationality,Position,Join_Year,Team
2+
1,Patrick O'Bryant,13,33,United States,Center Team,2009,Bradley
3+
2,Jermaine O'Neal,6,40,United States,Forward-Center Team,2008,Eau Claire High School
4+
3,Dan O'Sullivan,45,37,United States,Center Team,1999,Fordham
5+
4,Charles Oakley,34,22,United Kingdom,Forward Team,2001,Virginia Union
6+
5,Hakeem Olajuwon,34,32,Nigeria,Center Team,2010,Houston
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
Scientist,Project
2+
123234877,AeH1
3+
152934485,AeH3
4+
222364883,Ast3
5+
326587417,Ast3
6+
332154719,Bte1
7+
546523478,Che1
8+
631231482,Ast3
9+
654873219,Che1
10+
745685214,AeH3
11+
845657245,Ast1
12+
845657246,Ast2
13+
332569843,AeH4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
Code,Name,Hours
2+
AeH1,Winds: Studying Bernoullis Principle,156
3+
AeH2,Aerodynamics and Bridge Design,189
4+
AeH3,Aerodynamics and Gas Mileage,256
5+
AeH4,Aerodynamics and Ice Hockey,789
6+
AeH5,Aerodynamics of a Football,98
7+
AeH6,Aerodynamics of Air Hockey,89
8+
Ast1,A Matter of Time,112
9+
Ast2,A Puzzling Parallax,299
10+
Ast3,Build Your Own Telescope,6546
11+
Bte1,Juicy: Extracting Apple Juice with Pectinase,321
12+
Bte2,A Magnetic Primer Designer,9684
13+
Bte3,Bacterial Transformation Efficiency,321
14+
Che1,A Silver-Cleaning Battery,545
15+
Che2,A Soluble Separation Solution,778
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
[
2+
{
3+
"datasets": [
4+
"scientist_1/Projects",
5+
"scientist_1/AssignedTo"
6+
],
7+
"codes": [
8+
"import pandas as pd\nimport matplotlib.pyplot as plt",
9+
"Projects_dataset = pd.read_csv('../dataset/scientist_1/Projects.csv')\nAssignedTo_dataset = pd.read_csv('../dataset/scientist_1/AssignedTo.csv')"
10+
],
11+
"request": "What are the names of projects that require more than 300 hours, and how many scientists are assigned to each? Plot the result in a bar chart.",
12+
"answer": "# Merge the two datasets on the project code\nmerged_data = pd.merge(Projects_dataset, AssignedTo_dataset, left_on='Code', right_on='Project')\n\n# Filter projects that require more than 300 hours\nfiltered_data = merged_data[merged_data['Hours'] > 300]\n\n# Count the number of scientists assigned to each project\nproject_counts = filtered_data['Name'].value_counts()\n\n# Plot a bar chart\nplt.figure(figsize=(10, 8))\nplt.bar(project_counts.index, project_counts.values)\nplt.xticks(rotation=5)\nplt.xlabel('Project Name')\nplt.ylabel('Number of Scientists')\nfrom matplotlib.ticker import MaxNLocator\nplt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))\nplt.title('Number of Scientists Assigned to Each Project')\nplt.show()"
13+
},
14+
{
15+
"datasets": [
16+
"coffee_shop/member"
17+
],
18+
"codes": [
19+
"import pandas as pd\nimport matplotlib.pyplot as plt",
20+
"member_dataset = pd.read_csv('../dataset/coffee_shop/member.csv')"
21+
],
22+
"request": "A scatter chart showing the correlation between the age of the customer and the time of purchase colored by membership level.",
23+
"answer": "# Group the dataset by Membership_card\ngroups = member_dataset.groupby('Membership_card')\n\n# Create a scatter chart for each Membership_card\nfor membership_card, group in groups:\n plt.scatter(group['Age'], group['Time_of_purchase'], label=membership_card)\n\n# Set the title and labels\nplt.title('Correlation between Age and Time of Purchase')\nplt.xlabel('Age')\nplt.ylabel('Time of Purchase')\nplt.legend(loc='upper left')\n\n# Show the plot\nplt.show()"
24+
},
25+
{
26+
"datasets": [
27+
"game_injury/game"
28+
],
29+
"codes": [
30+
"import pandas as pd\nimport matplotlib.pyplot as plt",
31+
"game_dataset = pd.read_csv('../dataset/game_injury/game.csv')"
32+
],
33+
"request": "Show the number of games in each season and group by away team in a group line chart. The x-axis is season.",
34+
"answer": "# group the dataset by season and away team\ngrouped = game_dataset.groupby(['Season', 'Away_team']).size().reset_index(name='counts')\n\n# create a pivot table with season as index and away team as columns\npivot_table = pd.pivot_table(grouped, values='counts', index=['Season'], columns=['Away_team'], fill_value=0)\n\n# create the line chart\npivot_table.plot(kind='line')\n\n# set the title and labels\nplt.title('Number of Games in Each Season by Away Team')\nplt.xlabel('Season')\nplt.ylabel('Number of Games')\nfrom matplotlib.ticker import MaxNLocator\nplt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))\n\n# show the plot\nplt.show()"
35+
},
36+
{
37+
"datasets": [
38+
"pilot_record/pilot"
39+
],
40+
"codes": [
41+
"import pandas as pd\nimport matplotlib.pyplot as plt",
42+
"pilot_dataset = pd.read_csv('../dataset/pilot_record/pilot.csv')"
43+
],
44+
"request": "What is the proportion of positions of pilots? Show the result in a pie chart.",
45+
"answer": "# Count the number of each position\nposition_counts = pilot_dataset['Position'].value_counts()\n\n# Create a pie chart\nplt.figure(figsize=(8,6))\nplt.pie(position_counts, labels = position_counts.index, autopct='%1.1f%%')\n\n# Set the title\nplt.title('Proportion of Positions')\n\n# Show the plot\nplt.show()"
46+
}
47+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[
2+
{
3+
"datasets": [
4+
"scientist_1/Projects",
5+
"scientist_1/AssignedTo"
6+
],
7+
"codes": [
8+
"import pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns",
9+
"Projects_dataset = pd.read_csv('../dataset/scientist_1/Projects.csv')\nAssignedTo_dataset = pd.read_csv('../dataset/scientist_1/AssignedTo.csv')"
10+
],
11+
"request": "What are the names of projects that require more than 300 hours, and how many scientists are assigned to each? Plot the result in a bar chart.",
12+
"answer": "# Merge the two datasets on the project code\nmerged_data = pd.merge(Projects_dataset, AssignedTo_dataset, left_on='Code', right_on='Project')\n\n# Filter projects that require more than 300 hours\nfiltered_data = merged_data[merged_data['Hours'] > 300]\n\n# Count the number of scientists assigned to each project\nproject_counts = filtered_data['Name'].value_counts()\n\n# Plot a bar chart\nplt.figure(figsize=(10, 8))\nsns.barplot(project_counts)\nplt.xticks(rotation=5)\nplt.xlabel('Project Name')\nplt.ylabel('Number of Scientists')\nfrom matplotlib.ticker import MaxNLocator\nplt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))\nplt.title('Number of Scientists Assigned to Each Project')\nplt.show()"
13+
}
14+
]

0 commit comments

Comments
 (0)