feat: add Gemma 3 and better command and LLM structure #36

Merged · 6 commits · Mar 31, 2025

5 changes: 5 additions & 0 deletions .changeset/stale-buckets-roll.md
@@ -0,0 +1,5 @@
---
"@bashbuddy/cli": patch
---

Added Gemma 3 and a better command structure

2 changes: 1 addition & 1 deletion apps/cli/package.json
@@ -38,7 +38,7 @@
"@trpc/client": "catalog:",
"clipboardy": "^4.0.0",
"commander": "^13.1.0",
"node-llama-cpp": "^3.6.0",
"node-llama-cpp": "^3.7.0",
"superjson": "catalog:",
"yaml": "^2.7.0",
"zod": "catalog:"
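Per the PR title, this dependency bump to ^3.7.0 is what brings Gemma 3 support: node-llama-cpp 3.7.0 picks up llama.cpp's Gemma 3 support, so the existing LocalLLM wrapper can load Gemma 3 GGUF builds. A minimal sketch of that loading path with the node-llama-cpp 3.x API (the model path is a placeholder, not a file this repo ships):

import { getLlama, LlamaChatSession } from "node-llama-cpp";

const llama = await getLlama();
// Placeholder path; any Gemma 3 GGUF quantization loads the same way.
const model = await llama.loadModel({
  modelPath: "./models/gemma-3-4b-it-Q4_K_M.gguf",
});
const context = await model.createContext();
const session = new LlamaChatSession({
  contextSequence: context.getSequence(),
});

console.log(await session.prompt("Suggest a command to list open ports"));
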
240 changes: 182 additions & 58 deletions apps/cli/src/commands/ask.ts
@@ -4,8 +4,9 @@ import chalk from "chalk";
import clipboardy from "clipboardy";
import { Command } from "commander";

import type { LLMResponse } from "@bashbuddy/validators";
import { processPrompt } from "@bashbuddy/agent";
import type { LLMMessage } from "@bashbuddy/agent";
import type { LLMContext, LLMResponse } from "@bashbuddy/validators";
import { processPrompt, yamlPrompt } from "@bashbuddy/agent";
import { SITE_URLS } from "@bashbuddy/consts";

import { LocalLLM } from "../llms/localllm";
@@ -22,8 +23,14 @@ import { runCommandWithStream } from "../utils/runner";
export function createAskCommand(): Command {
const askCommand = new Command("ask")
.description("Ask a question to the AI")
.argument("<question...>", "The question to ask the AI")
.action((questionParts: string[]) => {
.argument("[question...]", "The question to ask the AI")
.action((questionParts: string[] = []) => {
// If no question parts, prompt the user
if (questionParts.length === 0) {
promptForQuestion().catch(console.error);
return;
}

// Join all parts of the question with spaces
const question = questionParts.join(" ");
execute(question).catch(console.error);
@@ -32,6 +39,34 @@ export function createAskCommand(): Command {
return askCommand;
}

/**
* Prompt the user for a question if none was provided
*/
async function promptForQuestion() {
p.intro("BashBuddy");

const question = await p.text({
message: "What would you like to ask?",
placeholder: "Ask for a command",
});

if (p.isCancel(question) || !question) {
p.cancel("Operation cancelled");
return;
}

await execute(question);
}

interface ConversationState {
messages: LLMMessage[];
context: LLMContext;
chatId: string;
llm?: LocalLLM;
isCloudMode: boolean;
revisionCount: number;
}

async function execute(question: string) {
p.intro("BashBuddy");

@@ -42,6 +77,7 @@ async function execute(question: string) {
]);

let commandToRun: string | undefined;
let conversationState: ConversationState;

switch (mode) {
case LOCAL_MODE: {
@@ -61,15 +97,26 @@
await llm.init();
modelSpinner.stop("Model loaded!");

const createNewOutputStream = (newUserInput: string) =>
Promise.resolve(processPrompt(llm, context, newUserInput, true));

commandToRun = await cliInfer(
await createNewOutputStream(question),
createNewOutputStream,
1,
false,
);
conversationState = {
messages: [
{
role: "system",
content: yamlPrompt(context),
},
{
role: "user",
content: question,
},
],
context,
chatId: "local",
llm,
isCloudMode: false,
revisionCount: 1,
};

const stream = processPrompt(llm, conversationState.messages);
commandToRun = await handleInference(stream, conversationState);

await llm.dispose();

@@ -79,20 +126,27 @@ async function execute(question: string) {
try {
const chatId = await trpc.chat.createChat.mutate();

const createNewOutputStream = (newUserInput: string) =>
trpc.chat.ask.mutate({
input: newUserInput,
context,
chatId,
useYaml: true,
});

commandToRun = await cliInfer(
await createNewOutputStream(question),
createNewOutputStream,
1,
true,
);
conversationState = {
messages: [
{
role: "user",
content: question,
},
],
context,
chatId,
isCloudMode: true,
revisionCount: 1,
};

const stream = await trpc.chat.ask.mutate({
input: question,
context,
chatId,
useYaml: true,
});

commandToRun = await handleInference(stream, conversationState);
} catch (err) {
if (err instanceof TRPCClientError) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
@@ -140,24 +194,32 @@
}
}

async function cliInfer(
/**
* Process LLM inference and return the parsed response
*/
async function processInference(
outputStream: AsyncIterable<string>,
createNewOutputStream: (
newUserInput: string,
) => Promise<AsyncIterable<string>>,
revisionCount = 1,
isCloudMode = false,
): Promise<string | undefined> {
state: ConversationState,
): Promise<LLMResponse | undefined> {
const llmSpinner = p.spinner();
llmSpinner.start("Processing...");

let finalResponse: LLMResponse;

try {
finalResponse = await parseYamlResponse(outputStream, (response) => {
if (response.command) {
llmSpinner.message(response.command);
}
const { parsed, raw } = await parseYamlResponse(
outputStream,
(response) => {
if (response.command) {
llmSpinner.message(response.command);
}
},
);

finalResponse = parsed;
state.messages.push({
role: "model",
content: raw,
});
} catch (err) {
if (err instanceof ResponseParseError) {
@@ -171,25 +233,70 @@
}

llmSpinner.stop(finalResponse.command);
return finalResponse;
}

if (finalResponse.wrong) {
/**
* Display command information to the user
*/
function displayCommandInfo(response: LLMResponse): void {
if (response.wrong) {
p.log.message(chalk.red("Please limit yourself to asking for commands."));

return;
}

if (finalResponse.explanation) {
p.log.message(chalk.dim(`Explanation: ${finalResponse.explanation}`));
if (response.explanation) {
p.log.message(chalk.dim(`Explanation: ${response.explanation}`));
}

if (finalResponse.dangerous) {
if (response.dangerous) {
p.log.message(
chalk.red(
`⚠️ Be careful, buddy has marked this command as dangerous. Make sure you know what it does.`,
),
);
}
}

/**
* Generate a new inference stream based on user suggestion
*/
async function generateNewStream(
suggestion: string,
state: ConversationState,
): Promise<AsyncIterable<string>> {
// Add the suggestion to the messages
state.messages.push({
role: "user",
content: suggestion,
});

// Increment revision count
state.revisionCount += 1;

// Generate a new stream based on mode
if (state.isCloudMode) {
return trpc.chat.ask.mutate({
input: suggestion,
context: state.context,
chatId: state.chatId,
useYaml: true,
});
} else {
if (!state.llm) {
throw new Error("LLM not initialized");
}
return processPrompt(state.llm, state.messages);
}
}

/**
* Handle user action on the command
*/
async function handleUserAction(
response: LLMResponse,
state: ConversationState,
): Promise<string | undefined> {
// Options for the select component
const options = [
{ value: "copyAndRun", label: "Copy & Run" },
@@ -198,9 +305,9 @@
];

// Only add the suggest option if we haven't reached the revision limit in cloud mode
if (!isCloudMode || revisionCount < 5) {
if (!state.isCloudMode || state.revisionCount < 5) {
options.push({ value: "suggest", label: "Suggest changes" });
} else if (revisionCount >= 5) {
} else if (state.revisionCount >= 5) {
p.log.message(
chalk.yellow("You've reached the maximum of 5 revisions in cloud mode."),
);
@@ -221,19 +328,19 @@

switch (action) {
case "run":
return finalResponse.command;
return response.command;
case "copy": {
// Copy the command to clipboard
try {
await clipboardy.write(finalResponse.command);
await clipboardy.write(response.command);
p.log.success("Command copied to clipboard");
} catch {
p.log.error("Failed to copy command to clipboard");
}

p.log.message(
chalk.dim(
`Feel free to paste the command into your terminal: ${finalResponse.command}`,
`Feel free to paste the command into your terminal: ${response.command}`,
),
);

@@ -242,18 +349,18 @@
case "copyAndRun": {
// Copy the command to clipboard and run it
try {
await clipboardy.write(finalResponse.command);
await clipboardy.write(response.command);
p.log.success("Command copied to clipboard");
} catch {
p.log.error(
`Failed to copy command to clipboard, but will still run. Feel free to copy it: ${finalResponse.command}`,
`Failed to copy command to clipboard, but will still run. Feel free to copy it: ${response.command}`,
);
}

return finalResponse.command;
return response.command;
}
case "suggest": {
// Allow user to suggest changes (original behavior when typing)
// Allow user to suggest changes
const suggestion = await p.text({
message: "What changes would you like to suggest?",
placeholder: "Type your suggestion here",
@@ -265,16 +372,33 @@
}

if (suggestion) {
return cliInfer(
await createNewOutputStream(suggestion),
createNewOutputStream,
revisionCount + 1,
isCloudMode,
);
const newStream = await generateNewStream(suggestion, state);
return handleInference(newStream, state);
}
return undefined;
}
default:
return undefined;
}
}

/**
* Handle the entire inference process
*/
async function handleInference(
outputStream: AsyncIterable<string>,
state: ConversationState,
): Promise<string | undefined> {
// Process the inference
const finalResponse = await processInference(outputStream, state);

if (!finalResponse) {
return undefined;
}

// Display command information
displayCommandInfo(finalResponse);

// Handle user action
return handleUserAction(finalResponse, state);
}
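
The net effect of the refactor: the old recursive cliInfer, which re-sent a single rewritten prompt per revision, is replaced by an explicit ConversationState whose messages array accumulates every turn, so revisions become real multi-turn context. A simplified sketch of that accumulation pattern, with the message shape inferred from this diff (the actual LLMMessage type is exported by @bashbuddy/agent):

// Role/message shapes as they appear in this diff; the real types
// live in @bashbuddy/agent.
type Role = "system" | "user" | "model";
interface LLMMessage {
  role: Role;
  content: string;
}

const messages: LLMMessage[] = [
  // Local mode seeds the history with yamlPrompt(context); cloud mode
  // sends the context separately and starts with the user turn.
  { role: "system", content: "<yaml system prompt from LLMContext>" },
  { role: "user", content: "compress the logs folder" },
];

// processInference pushes the raw model output back into the history:
messages.push({ role: "model", content: "command: tar -czf logs.tar.gz logs/" });

// generateNewStream appends each suggestion as another user turn and
// bumps revisionCount (capped at 5 in cloud mode):
messages.push({ role: "user", content: "exclude files older than a week" });

Because local mode re-sends the whole messages array to processPrompt on every revision, the model always sees the full history; cloud mode achieves the same server-side, keyed by chatId.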