| author | Furkan Sahin <furkan-dev@proton.me> | 2023-05-05 14:48:37 -0500 |
|---|---|---|
| committer | Furkan Sahin <furkan-dev@proton.me> | 2023-05-05 14:48:37 -0500 |
| commit | ca65e6384861b5446bec2422d9ebf0137da29d37 | |
| tree | 002b0742043022124ff7ed2f58dbe03fec4c7f71 | src/gpt_chat_cli/openai_wrappers.py |
| parent | a472f463e0d7e48f4b299cc9163bb0c05eaa8585 | |
Add interactive mode
Diffstat (limited to 'src/gpt_chat_cli/openai_wrappers.py')
| -rw-r--r-- | src/gpt_chat_cli/openai_wrappers.py | 44 |
1 file changed, 40 insertions, 4 deletions
```diff
diff --git a/src/gpt_chat_cli/openai_wrappers.py b/src/gpt_chat_cli/openai_wrappers.py
index 413ec24..6eeba4d 100644
--- a/src/gpt_chat_cli/openai_wrappers.py
+++ b/src/gpt_chat_cli/openai_wrappers.py
@@ -28,6 +28,24 @@ class Choice:
     finish_reason: Optional[FinishReason]
     index: int
 
+class Role(Enum):
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"
+
+@dataclass
+class ChatMessage:
+    role: Role
+    content: str
+
+    def to_json(self : "ChatMessage"):
+        return {
+            "role": self.role.value,
+            "content": self.content
+        }
+
+ChatHistory = List[ChatMessage]
+
 @dataclass
 class OpenAIChatResponse:
     choices: List[Choice]
@@ -61,13 +79,31 @@ class OpenAIChatResponse:
 
 OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None]
 
-def create_chat_completion(*args, **kwargs) \
-        -> OpenAIChatResponseStream:
+from .argparsing import CompletionArguments
+
+def create_chat_completion(hist : ChatHistory, args: CompletionArguments) \
+        -> OpenAIChatResponseStream:
+
+    messages = [ msg.to_json() for msg in hist ]
+
+    response = openai.ChatCompletion.create(
+        model=args.model,
+        messages=messages,
+        n=args.n_completions,
+        temperature=args.temperature,
+        presence_penalty=args.presence_penalty,
+        frequency_penalty=args.frequency_penalty,
+        max_tokens=args.max_tokens,
+        top_p=args.top_p,
+        stream=True
+    )
+
     return (
-        OpenAIChatResponse.from_json(update) \
-        for update in openai.ChatCompletion.create(*args, **kwargs)
+        OpenAIChatResponse.from_json( update ) \
+        for update in response
     )
 
+
 def list_models() -> List[str]:
 
     model_data = openai.Model.list()
```
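For context, a minimal usage sketch of the API this commit introduces (not part of the commit itself): a `ChatHistory` is just a list of `ChatMessage` values, and `CompletionArguments` is assumed here to be constructible with the fields the diff reads off it (`model`, `n_completions`, `temperature`, `presence_penalty`, `frequency_penalty`, `max_tokens`, `top_p`). The import paths and the constructor call are assumptions; only the type and attribute names come from the diff above.

```python
# Hypothetical usage sketch for the API introduced in this commit; the import
# paths and the CompletionArguments constructor are assumptions, only the
# type and attribute names are taken from the diff above.

from gpt_chat_cli.openai_wrappers import (
    ChatHistory,
    ChatMessage,
    Role,
    create_chat_completion,
)
from gpt_chat_cli.argparsing import CompletionArguments  # assumed import path

# Build a conversation out of the new ChatMessage / Role types.
hist: ChatHistory = [
    ChatMessage(Role.SYSTEM, "You are a helpful assistant."),
    ChatMessage(Role.USER, "Explain what this commit changes."),
]

# Assumed: CompletionArguments can be constructed with exactly the fields
# that create_chat_completion forwards to openai.ChatCompletion.create.
args = CompletionArguments(
    model="gpt-3.5-turbo",
    n_completions=1,
    temperature=0.7,
    presence_penalty=0.0,
    frequency_penalty=0.0,
    max_tokens=256,
    top_p=1.0,
)

# Because the request is made with stream=True, the call yields a stream of
# OpenAIChatResponse updates, each carrying a list of Choice objects.
for update in create_chat_completion(hist, args):
    for choice in update.choices:
        # The streamed text lives inside each choice; its exact field is not
        # shown in this hunk, so extraction is left as a placeholder here.
        ...
```

The move from `*args, **kwargs` to an explicit `(ChatHistory, CompletionArguments)` signature fits the interactive mode: the caller can keep appending `ChatMessage` entries to one history list across turns while reusing a single set of completion parameters.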
