Skip to content

Commit

Permalink
キャンセル時にHeadが戻るように変更
Browse files Browse the repository at this point in the history
  • Loading branch information
kznrluk committed Apr 28, 2024
1 parent 0867e84 commit d1307fe
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 0 deletions.
9 changes: 9 additions & 0 deletions chat/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chat

import (
"context"
"errors"
"fmt"
"github.com/kznrluk/aski/conv"
"github.com/kznrluk/go-anthropic"
Expand Down Expand Up @@ -47,6 +48,9 @@ func (a ap) rest(ctx context.Context, conv conv.Conversation) (string, error) {
)

if err != nil {
if errors.Is(err, context.Canceled) {
return "", ErrCancelled
}
return "", err
}
if len(rest.Content) == 0 {
Expand All @@ -70,6 +74,9 @@ func (a ap) stream(ctx context.Context, conv conv.Conversation) (string, error)
)

if err != nil {
if errors.Is(err, context.Canceled) {
return "", ErrCancelled
}
return "", err
}

Expand All @@ -79,6 +86,8 @@ func (a ap) stream(ctx context.Context, conv conv.Conversation) (string, error)
if err != nil {
if err == io.EOF {
break
} else if errors.Is(err, context.Canceled) {
return "", ErrCancelled
} else {
fmt.Printf("%s", err.Error())
return "", err
Expand Down
5 changes: 5 additions & 0 deletions chat/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chat

import (
"context"
"errors"
"github.com/kznrluk/aski/config"
"github.com/kznrluk/aski/conv"
"os"
Expand All @@ -18,6 +19,10 @@ type (
}
)

var (
ErrCancelled = errors.New("cancelled")
)

func ProvideChat(model string, cfg config.Config) Chat {
if strings.HasPrefix(model, "claude") {
return NewAnthropic(cfg.AnthropicAPIKey)
Expand Down
9 changes: 9 additions & 0 deletions chat/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chat

import (
"context"
"errors"
"fmt"
"github.com/kznrluk/aski/conv"
"github.com/sashabaranov/go-openai"
Expand Down Expand Up @@ -61,6 +62,9 @@ func (o oai) rest(ctx context.Context, conv conv.Conversation) (string, error) {
)

if err != nil {
if errors.Is(err, context.Canceled) {
return "", ErrCancelled
}
return "", err
}
fmt.Printf("%s", resp.Choices[0].Message.Content)
Expand Down Expand Up @@ -95,6 +99,9 @@ func (o oai) stream(ctx context.Context, conv conv.Conversation) (string, error)
)

if err != nil {
if errors.Is(err, context.Canceled) {
return "", ErrCancelled
}
return "", err
}

Expand All @@ -104,6 +111,8 @@ func (o oai) stream(ctx context.Context, conv conv.Conversation) (string, error)
if err != nil {
if err == io.EOF {
break
} else if errors.Is(err, context.Canceled) {
return "", ErrCancelled
} else {
return "", err
}
Expand Down
7 changes: 7 additions & 0 deletions conv/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ func (c *conv) ChangeHead(sha1Partial string) (Message, error) {
foundSha := false
foundMessageIndex := -1

if sha1Partial == "ROOT" {
for i := range c.Messages {
c.Messages[i].Head = false
}
return c.convertSystemToMessage(), nil
}

for i, message := range c.Messages {
if strings.HasPrefix(message.Sha1, sha1Partial) {
foundSha = true
Expand Down
5 changes: 5 additions & 0 deletions lib/dialog.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package lib

import (
"context"
"errors"
"fmt"
"github.com/fatih/color"
"github.com/kznrluk/aski/chat"
Expand Down Expand Up @@ -100,6 +101,10 @@ func StartDialog(cfg config.Config, cv conv.Conversation, isRestMode bool, resto
fmt.Printf("\n")
data, err := cli.Retrieve(cv, isRestMode)
if err != nil {
if errors.Is(err, chat.ErrCancelled) {
_, _ = cv.ChangeHead(last.ParentSha1)
continue
}
fmt.Printf("\n%s", err.Error())
continue
}
Expand Down

0 comments on commit d1307fe

Please sign in to comment.