-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fine-tuning and post train chat eval POC
- Demonstrates an end to end knowledge submission, generate, train and post-train side-by-side model comparison for the user to validate their knowledge submission is included in the newly trained checkpoint. - For this to function for the frontend to make REST API calls to Instructlab this uses an api-server that frontends ilab. The code is here https://github.com/nerdalert/ilab-api-server - The demo was run on a 24GB GPU leveraging the simple pipeline. Will get an example acceslerated pipeline demo with some hardware soon. - Training and generation for the demo took around ~30-45m or so. - All functionality is decoupled from the system via REST making it serviceable out of the gate and enabling the UI functionality. Signed-off-by: Brent Salisbury <[email protected]>
- Loading branch information
Showing
19 changed files
with
3,998 additions
and
33 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// src/pages/api/fine-tune/data-sets.ts | ||
'use server'; | ||
|
||
import { NextRequest, NextResponse } from 'next/server'; | ||
|
||
export async function GET(req: NextRequest) { | ||
try { | ||
const API_SERVER = process.env.NEXT_PUBLIC_API_SERVER!; | ||
|
||
const response = await fetch(`${API_SERVER}/data`); | ||
const data = await response.json(); | ||
|
||
if (!response.ok) { | ||
return NextResponse.json({ error: 'Failed to fetch datasets' }, { status: response.status }); | ||
} | ||
|
||
return NextResponse.json(data, { status: 200 }); | ||
} catch (error) { | ||
console.error('Error fetching datasets:', error); | ||
return NextResponse.json({ error: 'Internal Server Error' }, { status: 500 }); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
'use server'; | ||
|
||
import { NextResponse } from 'next/server'; | ||
|
||
export async function POST(request: Request) { | ||
try { | ||
const API_SERVER = process.env.NEXT_PUBLIC_API_SERVER!; | ||
|
||
const response = await fetch(`${API_SERVER}/data/generate`, { | ||
method: 'POST' | ||
}); | ||
|
||
if (!response.ok) { | ||
console.error('Error response from API server:', response.status, response.statusText); | ||
return NextResponse.json({ error: 'Failed to generate data' }, { status: response.status }); | ||
} | ||
|
||
const responseData = await response.json(); | ||
|
||
// Return the response from the API server to the client | ||
return NextResponse.json(responseData, { status: 200 }); | ||
} catch (error) { | ||
console.error('Error generating data:', error); | ||
return NextResponse.json({ error: 'An error occurred while generating data' }, { status: 500 }); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// src/app/api/fine-tune/jobs/[job_id]/logs/route.ts | ||
import { NextResponse } from 'next/server'; | ||
|
||
const API_SERVER = process.env.NEXT_PUBLIC_API_SERVER!; | ||
|
||
export async function GET(request: Request, { params }: { params: { job_id: string } }) { | ||
const { job_id } = await Promise.resolve(params); | ||
|
||
try { | ||
const response = await fetch(`${API_SERVER}/jobs/${job_id}/logs`, { | ||
method: 'GET' | ||
}); | ||
|
||
if (!response.ok) { | ||
const errorText = await response.text(); | ||
console.error('Error from API server:', errorText); | ||
return NextResponse.json({ error: 'Error fetching logs' }, { status: 500 }); | ||
} | ||
|
||
const logs = await response.text(); | ||
return new NextResponse(logs, { | ||
status: 200, | ||
headers: { 'Content-Type': 'text/plain' } | ||
}); | ||
} catch (error) { | ||
console.error('Error fetching logs:', error); | ||
return NextResponse.json({ error: 'Error fetching logs' }, { status: 500 }); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// src/app/api/fine-tune/jobs/[job_id]/status/route.ts | ||
'use server'; | ||
|
||
import { NextResponse } from 'next/server'; | ||
|
||
export async function GET(request: Request, { params }: { params: { job_id: string } }) { | ||
const { job_id } = params; | ||
const API_SERVER = process.env.NEXT_PUBLIC_API_SERVER!; | ||
|
||
try { | ||
// Forward the request to the API server | ||
const response = await fetch(`${API_SERVER}/jobs/${job_id}/status`, { | ||
method: 'GET' | ||
}); | ||
|
||
if (!response.ok) { | ||
const errorText = await response.text(); | ||
console.error('Error from API server:', errorText); | ||
return NextResponse.json({ error: 'Error fetching job status' }, { status: 500 }); | ||
} | ||
|
||
const result = await response.json(); | ||
// Return the job status to the client | ||
return NextResponse.json(result, { status: 200 }); | ||
} catch (error) { | ||
console.error('Error fetching job status:', error); | ||
return NextResponse.json({ error: 'Error fetching job status' }, { status: 500 }); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// src/app/api/fine-tune/jobs/route.ts | ||
'use server'; | ||
|
||
import { NextResponse } from 'next/server'; | ||
|
||
export async function GET(request: Request) { | ||
const API_SERVER = process.env.NEXT_PUBLIC_API_SERVER!; | ||
|
||
try { | ||
const response = await fetch(`${API_SERVER}/jobs`); | ||
if (!response.ok) { | ||
const errorText = await response.text(); | ||
console.error('Error from API server:', errorText); | ||
return NextResponse.json({ error: 'Error fetching jobs' }, { status: 500 }); | ||
} | ||
const result = await response.json(); | ||
return NextResponse.json(result, { status: 200 }); | ||
} catch (error) { | ||
console.error('Error fetching jobs:', error); | ||
return NextResponse.json({ error: 'Error fetching jobs' }, { status: 500 }); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// src/app/api/model/serve-base/route.ts | ||
'use server'; | ||
|
||
import { NextResponse } from 'next/server'; | ||
|
||
export async function POST() { | ||
try { | ||
console.log('Received serve-base model request'); | ||
|
||
const API_SERVER = process.env.NEXT_PUBLIC_API_SERVER!; | ||
const endpoint = `${API_SERVER}/model/serve-base`; | ||
|
||
console.log(`Forwarding request to the API server: ${endpoint}`); | ||
|
||
// No request body needed for serving the base model | ||
const response = await fetch(endpoint, { | ||
method: 'POST', | ||
headers: { 'Content-Type': 'application/json' } | ||
}); | ||
|
||
console.log('Response from API server (serve-base):', { | ||
status: response.status, | ||
statusText: response.statusText | ||
}); | ||
|
||
if (!response.ok) { | ||
console.error('Error response from the API server:', response.status, response.statusText); | ||
return NextResponse.json({ error: 'Failed to serve the base model on the API server' }, { status: response.status }); | ||
} | ||
|
||
// Parse response safely | ||
let responseData; | ||
try { | ||
const text = await response.text(); | ||
responseData = text ? JSON.parse(text) : {}; | ||
console.log('Parsed response data (serve-base):', responseData); | ||
} catch (error) { | ||
console.error('Error parsing JSON response from API server:', error); | ||
return NextResponse.json({ error: 'Invalid JSON response from the API server' }, { status: 500 }); | ||
} | ||
|
||
if (!responseData.job_id) { | ||
console.error('Missing job_id in API server response for serve-base:', responseData); | ||
return NextResponse.json({ error: 'API server response does not contain job_id' }, { status: 500 }); | ||
} | ||
|
||
// Return the response from the API server to the client | ||
console.log('Returning success response with job_id (serve-base):', responseData.job_id); | ||
return NextResponse.json(responseData, { status: 200 }); | ||
} catch (error) { | ||
console.error('Unexpected error during serve-base:', error); | ||
return NextResponse.json({ error: 'An unexpected error occurred during serving the base model' }, { status: 500 }); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// src/app/api/model/serve-latest/route.ts | ||
'use server'; | ||
|
||
import { NextResponse } from 'next/server'; | ||
|
||
export async function POST() { | ||
try { | ||
console.log('Received serve-latest model request'); | ||
|
||
const API_SERVER = process.env.NEXT_PUBLIC_API_SERVER!; | ||
const endpoint = `${API_SERVER}/model/serve-latest`; | ||
|
||
console.log(`Forwarding request to API server: ${endpoint}`); | ||
|
||
// No request body needed for serving the latest model | ||
const response = await fetch(endpoint, { | ||
method: 'POST', | ||
headers: { 'Content-Type': 'application/json' } | ||
}); | ||
|
||
console.log('Response from API server:', { | ||
status: response.status, | ||
statusText: response.statusText | ||
}); | ||
|
||
if (!response.ok) { | ||
console.error('Error response from API server:', response.status, response.statusText); | ||
return NextResponse.json({ error: 'Failed to serve the latest model on the API server' }, { status: response.status }); | ||
} | ||
|
||
// Parse response safely | ||
let responseData; | ||
try { | ||
const text = await response.text(); | ||
responseData = text ? JSON.parse(text) : {}; | ||
console.log('Parsed response data (serve-latest):', responseData); | ||
} catch (error) { | ||
console.error('Error parsing JSON response from API server:', error); | ||
return NextResponse.json({ error: 'Invalid JSON response from the API server' }, { status: 500 }); | ||
} | ||
|
||
if (!responseData.job_id) { | ||
console.error('Missing job_id in API server response for serve-latest:', responseData); | ||
return NextResponse.json({ error: 'API server response does not contain job_id' }, { status: 500 }); | ||
} | ||
|
||
// Return the response from the API server to the client | ||
console.log('Returning success response with job_id (serve-latest):', responseData.job_id); | ||
return NextResponse.json(responseData, { status: 200 }); | ||
} catch (error) { | ||
console.error('Unexpected error during serve-latest:', error); | ||
return NextResponse.json({ error: 'An unexpected error occurred during serving the latest model' }, { status: 500 }); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// src/app/api/fine-tune/model/train | ||
'use server'; | ||
|
||
import { NextResponse } from 'next/server'; | ||
|
||
export async function POST(request: Request) { | ||
try { | ||
console.log('Received train job request'); | ||
|
||
// Parse the request body for required data | ||
const { modelName, branchName } = await request.json(); | ||
const API_SERVER = process.env.NEXT_PUBLIC_API_SERVER!; | ||
|
||
console.log('Request body:', { modelName, branchName }); | ||
|
||
if (!modelName || !branchName) { | ||
console.error('Missing required parameters: modelName and branchName'); | ||
return NextResponse.json({ error: 'Missing required parameters: modelName and branchName' }, { status: 400 }); | ||
} | ||
|
||
// Forward the request to the API server | ||
const endpoint = `${API_SERVER}/model/train`; | ||
|
||
console.log(`Forwarding request to API server: ${API_SERVER}`); | ||
|
||
const response = await fetch(endpoint, { | ||
method: 'POST', | ||
headers: { | ||
'Content-Type': 'application/json' | ||
}, | ||
body: JSON.stringify({ | ||
modelName, | ||
branchName | ||
}) | ||
}); | ||
|
||
console.log('Response from API server:', { | ||
status: response.status, | ||
statusText: response.statusText | ||
}); | ||
|
||
if (!response.ok) { | ||
console.error('Error response from API server:', response.status, response.statusText); | ||
return NextResponse.json({ error: 'Failed to train the model on the API server' }, { status: response.status }); | ||
} | ||
|
||
// Parse response safely | ||
let responseData; | ||
try { | ||
const text = await response.text(); | ||
responseData = text ? JSON.parse(text) : {}; | ||
console.log('Parsed response data:', responseData); | ||
} catch (error) { | ||
console.error('Error parsing JSON response from API server:', error); | ||
return NextResponse.json({ error: 'Invalid JSON response from the API server' }, { status: 500 }); | ||
} | ||
|
||
if (!responseData.job_id) { | ||
console.error('Missing job_id in API server response:', responseData); | ||
return NextResponse.json({ error: 'API server response does not contain job_id' }, { status: 500 }); | ||
} | ||
|
||
// Return the response from the API server to the client | ||
console.log('Returning success response with job_id:', responseData.job_id); | ||
return NextResponse.json(responseData, { status: 200 }); | ||
} catch (error) { | ||
console.error('Unexpected error during training:', error); | ||
return NextResponse.json({ error: 'An unexpected error occurred during training' }, { status: 500 }); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// src/pages/api/fine-tune/models/route.ts | ||
'use server'; | ||
|
||
import { NextRequest, NextResponse } from 'next/server'; | ||
|
||
export async function GET(req: NextRequest) { | ||
try { | ||
const API_SERVER = process.env.NEXT_PUBLIC_API_SERVER!; | ||
|
||
const response = await fetch(`${API_SERVER}/models`); | ||
const data = await response.json(); | ||
|
||
if (!response.ok) { | ||
return NextResponse.json({ error: 'Failed to fetch models' }, { status: response.status }); | ||
} | ||
|
||
return NextResponse.json(data, { status: 200 }); | ||
} catch (error) { | ||
console.error('Error fetching models:', error); | ||
return NextResponse.json({ error: 'Internal Server Error' }, { status: 500 }); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
'use server'; | ||
|
||
import { NextResponse } from 'next/server'; | ||
|
||
export async function POST(request: Request) { | ||
try { | ||
// Parse the request body for required data | ||
const { modelName, branchName } = await request.json(); | ||
const API_SERVER = process.env.NEXT_PUBLIC_API_SERVER!; | ||
|
||
if (!modelName || !branchName) { | ||
return NextResponse.json({ error: 'Missing required parameters: modelName and branchName' }, { status: 400 }); | ||
} | ||
|
||
// Forward the request to the API server's pipeline endpoint | ||
const endpoint = `${API_SERVER}/pipeline/generate-train`; | ||
|
||
const response = await fetch(endpoint, { | ||
method: 'POST', | ||
headers: { | ||
'Content-Type': 'application/json' | ||
}, | ||
body: JSON.stringify({ | ||
modelName, | ||
branchName | ||
}) | ||
}); | ||
|
||
if (!response.ok) { | ||
console.error('Error response from API server (pipeline):', response.status, response.statusText); | ||
return NextResponse.json({ error: 'Failed to run generate-train pipeline on the API server' }, { status: response.status }); | ||
} | ||
|
||
const responseData = await response.json(); | ||
return NextResponse.json(responseData, { status: 200 }); | ||
} catch (error) { | ||
console.error('Error during generate-train pipeline:', error); | ||
return NextResponse.json({ error: 'An error occurred during generate-train pipeline' }, { status: 500 }); | ||
} | ||
} |
Oops, something went wrong.