Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: multi doc support prompt studio #729

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions backend/file_management/file_management_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
from unstract.connectors.filesystems import connectors as fs_connectors
from unstract.connectors.filesystems.unstract_file_system import UnstractFileSystem

try:
from plugins.processor.file_converter.constants import (
ExtentedFileInformationKey as FileKey,
)
except ImportError:
from file_management.constants import FileInformationKey as FileKey
jagadeeswaran-zipstack marked this conversation as resolved.
Show resolved Hide resolved

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -138,20 +145,29 @@ def upload_file(
# adding filename with path
file_path += file_name
with fs.open(file_path, mode="wb") as remote_file:
remote_file.write(file.read())
if isinstance(file, bytes):
remote_file.write(file)
else:
remote_file.write(file.read())

@staticmethod
def fetch_file_contents(file_system: UnstractFileSystem, file_path: str) -> Any:
fs = file_system.get_fsspec_fs()

# Define allowed content types
allowed_content_types = FileKey.FILE_UPLOAD_ALLOWED_MIME

try:
file_info = fs.info(file_path)
except FileNotFoundError:
raise FileNotFound

file_content_type = file_info.get("ContentType")
file_type = file_info.get("type")

if file_type != "file":
raise InvalidFileType

try:
if not file_content_type:
file_content_type, _ = mimetypes.guess_type(file_path)
Expand All @@ -163,19 +179,26 @@ def fetch_file_contents(file_system: UnstractFileSystem, file_path: str) -> Any:
except ApiRequestError as exception:
logger.error(f"ApiRequestError from {file_info} {exception}")
raise ConnectorApiRequestError

data = ""
# Check if the file type is in the allowed list
if file_content_type not in allowed_content_types:
raise InvalidFileType(f"File type '{file_content_type}' is not allowed.")

# Handle allowed file types
if file_content_type == "application/pdf":
# Read the PDF as bytes and base64-encode them
with fs.open(file_path, "rb") as file:
encoded_string = base64.b64encode(file.read())
return encoded_string
data = base64.b64encode(file.read())

elif file_content_type == "text/plain":
with fs.open(file_path, "r") as file:
logger.info(f"Reading text file: {file_path}")
text_content = file.read()
return text_content
data = file.read()

else:
raise InvalidFileType
raise InvalidFileType(f"File type '{file_content_type}' is not handled.")
chandrasekharan-zipstack marked this conversation as resolved.
Show resolved Hide resolved

return {"data": data, "mime_type": file_content_type}

@staticmethod
def _delete_file(fs, file_path):
Expand Down
22 changes: 15 additions & 7 deletions backend/prompt_studio/prompt_studio_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,43 +393,51 @@ def fetch_contents_ide(self, request: HttpRequest, pk: Any = None) -> Response:
file_name: str = document.document_name
view_type: str = serializer.validated_data.get("view_type")

# Extract filename without extension
filename_without_extension = file_name.rsplit(".", 1)[0]

# Handle view_type logic, always converting to .txt for EXTRACT and SUMMARIZE
if view_type == FileViewTypes.EXTRACT:
file_name = (
f"{FileViewTypes.EXTRACT.lower()}/" f"{filename_without_extension}.txt"
)
if view_type == FileViewTypes.SUMMARIZE:
elif view_type == FileViewTypes.SUMMARIZE:
file_name = (
f"{FileViewTypes.SUMMARIZE.lower()}/"
f"{filename_without_extension}.txt"
)

file_path = file_path = FileManagerHelper.handle_sub_directory_for_tenants(
file_path = FileManagerHelper.handle_sub_directory_for_tenants(
UserSessionUtils.get_organization_id(request),
is_create=True,
user_id=custom_tool.created_by.user_id,
tool_id=str(custom_tool.tool_id),
)
file_system = LocalStorageFS(settings={"path": file_path})

# Ensure file path formatting
if not file_path.endswith("/"):
file_path += "/"
file_path += file_name
# Temporary Hack for frictionless onboarding as the user id will be empty
chandrasekharan-zipstack marked this conversation as resolved.
Show resolved Hide resolved

file_system = LocalStorageFS(settings={"path": file_path})

# Handle file content retrieval
try:
contents = FileManagerHelper.fetch_file_contents(file_system, file_path)
except FileNotFound:
file_path = file_path = FileManagerHelper.handle_sub_directory_for_tenants(
# Retry with empty user_id
file_path = FileManagerHelper.handle_sub_directory_for_tenants(
UserSessionUtils.get_organization_id(request),
is_create=True,
user_id="",
tool_id=str(custom_tool.tool_id),
)
if not file_path.endswith("/"):
file_path += "/"
file_path += file_name
file_path += file_name
contents = FileManagerHelper.fetch_file_contents(file_system, file_path)

return Response({"data": contents}, status=status.HTTP_200_OK)
return Response(contents, status=status.HTTP_200_OK)

@action(detail=True, methods=["post"])
def upload_for_ide(self, request: HttpRequest, pk: Any = None) -> Response:
Expand Down
14 changes: 10 additions & 4 deletions backend/prompt_studio/prompt_studio_core_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from account_v2.models import User
from account_v2.serializer import UserSerializer
from django.core.exceptions import ObjectDoesNotExist
from file_management.constants import FileInformationKey
from prompt_studio.prompt_profile_manager_v2.models import ProfileManager
from prompt_studio.prompt_studio_core_v2.constants import ToolStudioKeys as TSKeys
from prompt_studio.prompt_studio_core_v2.exceptions import DefaultProfileError
Expand All @@ -20,6 +19,13 @@

logger = logging.getLogger(__name__)

try:
from plugins.processor.file_converter.constants import (
ExtentedFileInformationKey as FileKey,
)
except ImportError:
from file_management.constants import FileInformationKey as FileKey


class CustomToolSerializer(IntegrityErrorMixin, AuditSerializer):
shared_users = serializers.PrimaryKeyRelatedField(
Expand Down Expand Up @@ -117,10 +123,10 @@ class FileUploadIdeSerializer(serializers.Serializer):
required=True,
validators=[
FileValidator(
allowed_extensions=FileInformationKey.FILE_UPLOAD_ALLOWED_EXT,
allowed_mimetypes=FileInformationKey.FILE_UPLOAD_ALLOWED_MIME,
allowed_extensions=FileKey.FILE_UPLOAD_ALLOWED_EXT,
allowed_mimetypes=FileKey.FILE_UPLOAD_ALLOWED_MIME,
min_size=0,
max_size=FileInformationKey.FILE_UPLOAD_MAX_SIZE,
max_size=FileKey.FILE_UPLOAD_MAX_SIZE,
)
],
)
16 changes: 15 additions & 1 deletion backend/prompt_studio/prompt_studio_core_v2/views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
import logging
import os
import uuid
from typing import Any, Optional

Expand Down Expand Up @@ -438,6 +440,10 @@ def upload_for_ide(self, request: HttpRequest, pk: Any = None) -> Response:
serializer = FileUploadIdeSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
uploaded_files: Any = serializer.validated_data.get("file")
file_converter = get_plugin_class_by_name(
name="file_converter",
plugins=self.processor_plugins,
)

file_path = FileManagerHelper.handle_sub_directory_for_tenants(
UserSessionUtils.get_organization_id(request),
Expand All @@ -450,6 +456,14 @@ def upload_for_ide(self, request: HttpRequest, pk: Any = None) -> Response:
documents = []
for uploaded_file in uploaded_files:
file_name = uploaded_file.name
file_data = uploaded_file
file_type = uploaded_file.content_type
# Convert non-PDF files
if file_converter and file_type != "application/pdf":
file_data_bytes = uploaded_file.read()
with io.BytesIO(file_data_bytes) as file_stream:
file_data = file_converter.convert_to_pdf(file_stream, file_name)
file_name = f"{os.path.splitext(file_name)[0]}.pdf"
chandrasekharan-zipstack marked this conversation as resolved.
Show resolved Hide resolved

# Create a record in the db for the file
document = PromptStudioDocumentHelper.create(
Expand All @@ -468,7 +482,7 @@ def upload_for_ide(self, request: HttpRequest, pk: Any = None) -> Response:
FileManagerHelper.upload_file(
file_system,
file_path,
uploaded_file,
file_data,
file_name,
)
documents.append(doc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { useEffect, useState } from "react";
import { useParams } from "react-router-dom";
import "./DocumentManager.css";

import { base64toBlob, docIndexStatus } from "../../../helpers/GetStaticData";
import { docIndexStatus } from "../../../helpers/GetStaticData";
import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate";
import { useCustomToolStore } from "../../../store/custom-tool-store";
import { useSessionStore } from "../../../store/session-store";
Expand All @@ -22,6 +22,7 @@ import { ManageDocsModal } from "../manage-docs-modal/ManageDocsModal";
import { PdfViewer } from "../pdf-viewer/PdfViewer";
import { TextViewerPre } from "../text-viewer-pre/TextViewerPre";
import usePostHogEvents from "../../../hooks/usePostHogEvents";
import { TextViewer } from "../text-viewer/TextViewer";

let items = [
{
Expand All @@ -39,6 +40,22 @@ const viewTypes = {
extract: "EXTRACT",
};

/**
 * Decode a base64 payload into a Blob of the given MIME type.
 *
 * @param {string|{data: string}} data Base64 text, or an object carrying the
 *   base64 text under its `data` key.
 * @param {string} mimeType MIME type assigned to the resulting Blob.
 * @return {Blob} Blob holding the decoded bytes.
 */
const base64toBlob = (data, mimeType) => {
  // Callers in this file pass the base64 string itself. Reading `data?.data`
  // on a string yields undefined, and atob(undefined) throws, so accept both
  // shapes and fall back to an empty payload when nothing usable is given.
  const base64 = typeof data === "string" ? data : (data?.data ?? "");
  const byteCharacters = atob(base64);
  const sliceSize = 512;
  const byteArrays = [];

  // Decode in fixed-size slices so no single typed array has to hold the
  // whole document at once.
  for (let offset = 0; offset < byteCharacters.length; offset += sliceSize) {
    const slice = byteCharacters.slice(offset, offset + sliceSize);
    const bytes = new Uint8Array(slice.length);
    for (let i = 0; i < slice.length; i++) {
      bytes[i] = slice.charCodeAt(i);
    }
    byteArrays.push(bytes);
  }
  return new Blob(byteArrays, { type: mimeType });
};

vishnuszipstack marked this conversation as resolved.
Show resolved Hide resolved
// Import components for the summarize feature
let SummarizeView = null;
try {
Expand Down Expand Up @@ -101,6 +118,20 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
const { setPostHogCustomEvent } = usePostHogEvents();
const { id } = useParams();

const [blobFileUrl, setBlobFileUrl] = useState("");
const [fileData, setFileData] = useState({});

useEffect(() => {
// Create an object URL for the fetched Blob
if (fileData.blob) {
const objectUrl = URL.createObjectURL(fileData.blob);
setBlobFileUrl(objectUrl);

// Revoke the URL when fileData changes or the component unmounts
return () => URL.revokeObjectURL(objectUrl);
}
}, [fileData]);

useEffect(() => {
if (isSimplePromptStudio) {
items = [
Expand Down Expand Up @@ -197,7 +228,8 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
getDocsFunc(details?.tool_id, selectedDoc?.document_id, viewType)
.then((res) => {
const data = res?.data?.data || "";
processGetDocsResponse(data, viewType);
const mimeType = res?.data?.mime_type || "";
processGetDocsResponse(data, viewType, mimeType);
})
.catch((err) => {
handleGetDocsError(err, viewType);
Expand All @@ -224,11 +256,19 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
});
};

const processGetDocsResponse = (data, viewType) => {
const processGetDocsResponse = (data, viewType, mimeType) => {
if (viewType === viewTypes.original) {
const base64String = data || "";
const blob = base64toBlob(base64String);
setFileUrl(URL.createObjectURL(blob));
const blob = base64toBlob(base64String, mimeType);
setFileData({ blob, mimeType });
const reader = new FileReader();
reader.readAsDataURL(blob);
reader.onload = () => {
setFileUrl(reader.result);
};
reader.onerror = () => {
throw new Error("Fail to load the file");
};
} else if (viewType === viewTypes.extract) {
setExtractTxt(data);
}
Expand Down Expand Up @@ -315,6 +355,20 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
}
};

const renderDoc = (docName, fileUrl) => {
const fileType = docName?.split(".").pop().toLowerCase(); // Get the file extension
console.log(docName, fileData);
vishnuszipstack marked this conversation as resolved.
Show resolved Hide resolved
switch (fileType) {
case "pdf":
return <PdfViewer fileUrl={fileUrl} />;
case "txt":
case "md":
return <TextViewer fileUrl={fileUrl} />;
default:
return <div>Unsupported file type: {fileType}</div>;
}
};

return (
<div className="doc-manager-layout">
<div className="doc-manager-header">
Expand Down Expand Up @@ -386,7 +440,8 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
setOpenManageDocsModal={setOpenManageDocsModal}
errMsg={fileErrMsg}
>
<PdfViewer fileUrl={fileUrl} />
{console.log(fileData.blob)}
vishnuszipstack marked this conversation as resolved.
Show resolved Hide resolved
{renderDoc(selectedDoc?.document_name, blobFileUrl)}
</DocumentViewer>
)}
{activeKey === "2" && (
Expand Down
Loading