Skip to content

Commit d942c96

Browse files
committed
Adding general guidance ai model support
Signed-off-by: Jannik hartmann <[email protected]>
1 parent 6476b90 commit d942c96

File tree

5 files changed

+25
-4
lines changed

5 files changed

+25
-4
lines changed

pywhyllm/suggesters/identification_suggester.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@ class IdentificationSuggester(IdentifierProtocol):
1111

1212
def __init__(self, llm=None):
1313
if llm is not None:
14-
if (llm == 'gpt-4'):
14+
if llm == 'gpt-4':
1515
self.llm = guidance.models.OpenAI('gpt-4')
1616
self.model_suggester = ModelSuggester('gpt-4')
17+
elif isinstance(llm, guidance.models.Model):
18+
self.llm = llm
19+
self.model_suggester = ModelSuggester(llm)
20+
else:
21+
raise ValueError("llm must be either 'gpt-4' or a guidance model instance.")
1722

1823
# def suggest_estimand(
1924
# self,

pywhyllm/suggesters/model_suggester.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@ class ModelSuggester(ModelerProtocol):
1212

1313
def __init__(self, llm=None):
1414
if llm is not None:
15-
if (llm == 'gpt-4'):
15+
if llm == 'gpt-4':
1616
self.llm = guidance.models.OpenAI('gpt-4')
17+
elif isinstance(llm, guidance.models.Model):
18+
self.llm = llm
19+
else:
20+
raise ValueError("llm must be either 'gpt-4' or a guidance model instance.")
1721

1822
def suggest_domain_expertises(
1923
self,

pywhyllm/suggesters/simple_identification_suggester.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@ class SimpleIdentificationSuggester:
77

88
def __init__(self, llm=None):
99
if llm is not None:
10-
if (llm == 'gpt-4'):
10+
if llm == 'gpt-4':
1111
self.llm = guidance.models.OpenAI('gpt-4')
12+
elif isinstance(llm, guidance.models.Model):
13+
self.llm = llm
14+
else:
15+
raise ValueError("llm must be either 'gpt-4' or a guidance model instance.")
1216

1317
def suggest_iv(self, factors, treatment, outcome):
1418
lm = self.llm

pywhyllm/suggesters/simple_model_suggester.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,12 @@ class SimpleModelSuggester:
2222

2323
def __init__(self, llm=None):
2424
if llm is not None:
25-
if (llm == 'gpt-4'):
25+
if llm == 'gpt-4':
2626
self.llm = guidance.models.OpenAI('gpt-4')
27+
elif isinstance(llm, guidance.models.Model):
28+
self.llm = llm
29+
else:
30+
raise ValueError("llm must be either 'gpt-4' or a guidance model instance.")
2731

2832
# new ver
2933
def suggest_pairwise_relationship(self, variable1: str, variable2: str):

pywhyllm/suggesters/validation_suggester.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ def __init__(self, llm=None):
1818
if llm is not None:
1919
if llm == 'gpt-4':
2020
self.llm = guidance.models.OpenAI('gpt-4')
21+
elif isinstance(llm, guidance.models.Model):
22+
self.llm = llm
23+
else:
24+
raise ValueError("llm must be either 'gpt-4' or a guidance model instance.")
2125

2226
def suggest_negative_controls(
2327
self,

0 commit comments

Comments
 (0)