
automata's People

Contributors

codesee-maps[bot], emrgnt-cmplxty, maks-ivanov

Forkers

apollohuang1

automata's Issues

Update README

The overview section of the README is out of date.
Delete the existing overview section, then review the repository file by file and write a new README that summarizes the code in the project.

Add error handling for invalid user input in main.py

Currently, main.py asks the user for data via input() calls. If the user enters something unexpected, the whole script fails.
Wrap the input() calls in exception handling so that, when the input is invalid, the user is prompted to try again.
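A minimal sketch of the retry pattern this issue asks for. The helper name prompt_until_valid, its parse callback, and the optional max_attempts cap are hypothetical, since the actual prompts in main.py are not shown here. It assumes that invalid input surfaces as a ValueError raised by the parser, as int() does.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def prompt_until_valid(
    prompt: str,
    parse: Callable[[str], T],
    input_fn: Callable[[str], str] = input,
    max_attempts: Optional[int] = None,
) -> T:
    """Hypothetical helper: keep asking until parse() accepts the answer.

    parse is expected to raise ValueError on bad input; max_attempts=None
    means retry forever, otherwise the last ValueError is re-raised.
    """
    attempts = 0
    while True:
        raw = input_fn(prompt)
        try:
            return parse(raw)
        except ValueError as exc:
            attempts += 1
            print(f"Invalid input ({exc}). Please try again.")
            if max_attempts is not None and attempts >= max_attempts:
                raise  # give up and propagate the last parsing error
```

Passing input_fn makes the loop testable without a terminal; in main.py the default input would be used, for example prompt_until_valid("Number of iterations: ", int).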

Test coverage gap: core.agent.automata_agent - _parse_completion_message

Write a test to satisfy the following coverage gap:
Module: core.agent.automata_agent
Function: _parse_completion_message
Uncovered lines: [311, 312, 313, 314, 315, 316, 317, 319, 320, 323, 405, 406, 407, 408, 409, 410, 411, 412, 414, 415, 418]
Code:

```
66:     """
67:     AutomataAgent is an autonomous agent that performs the actual work of the Automata
68:     system. Automata are responsible for executing instructions and reporting
69:     the results back to the master.
70:     """
71: 
72:     CONTINUE_MESSAGE: Final = "Continue, and return a result JSON when finished."
73:     NUM_DEFAULT_MESSAGES: Final = 3  # Prompt + Assistant Initialization + User Task
74:     INITIALIZER_DUMMY: Final = "automata_initializer"
75:     ERROR_DUMMY_TOOL: Final = "error_reporter"
76: 
77:     def __init__(self, config: Optional[AutomataAgentConfig] = None):
78:         """
79:         Args:
80:             config (Optional[AutomataAgentConfig]): The agent config to use
81:         Methods:
82:             iter_task(instructions: List[Dict[str, str]]) -> Dict[str, str]: Iterates through the instructions and returns the next instruction.
83:             modify_last_instruction(new_instruction: str) -> None
84:             replay_messages() -> List[Dict[str, str]]: Replays agent messages buffer.
85:             run() -> str: Runs the agent.
86:             get_non_instruction_messages() -> List[Dict[str, str]]: Returns all messages that are not instructions.
87:         """
88:         if config is None:
89:             config = AutomataAgentConfig()
90:         self.instruction_payload = config.instruction_payload
91:         self.llm_toolkits = config.llm_toolkits
92:         self.instructions = config.instructions
93:         self.config_version = config.config_version
94:         self.system_instruction_template = config.system_instruction_template
95:         self.instruction_input_variables = config.instruction_input_variables
96:         self.model = config.model
97:         self.stream = config.stream
98:         self.verbose = config.verbose
99:         self.max_iters = config.max_iters
100:         self.temperature = config.temperature
101:         self.instruction_version = config.instruction_version
102:         self.completed = False
103:         self.eval_mode = False
104:         self.messages: List[OpenAIChatMessage] = []
105:         self.session_id = config.session_id
106:         self.conn: Optional[sqlite3.Connection] = None
107:         self.name: str = config.name
108: 
109:     def __del__(self):
110:         """Close the connection to the agent."""
111:         if self.conn:
112:             self.conn.close()
113: 
114:     def run(self) -> str:
115:         latest_responses = self.iter_task()
116:         while latest_responses is not None:
117:             # Each iteration adds two messages, one from the assistant and one from the user
118:             # If we have equal to or more than 2 * max_iters messages (less the default messages),
119:             # then we have exceeded the max_iters
120:             if len(self.messages) - AutomataAgent.NUM_DEFAULT_MESSAGES >= self.max_iters * 2:
121:                 return "Result was not found before iterations exceeded max limit."
122:             latest_responses = self.iter_task()
123:         return self.messages[-1].content
124: 
125:     def iter_task(self) -> Optional[Tuple[OpenAIChatMessage, OpenAIChatMessage]]:
126:         """Run the test and report the tool outputs back to the master."""
127:         if self.completed:
128:             raise ValueError("Cannot run an agent that has already completed.")
129:         response_summary = openai.ChatCompletion.create(
130:             model=self.model,
131:             messages=[ele.to_dict() for ele in self.messages],
132:             temperature=self.temperature,
133:             stream=self.stream,
134:         )
135:         response_text = (
136:             self._stream_message(response_summary)
137:             if self.stream
138:             else OpenAIChatCompletionResult(raw_data=response_summary).get_completion()
139:         )
140: 
141:         observations = self._generate_observations(response_text)
142: 
143:         completion_message = retrieve_completion_message(observations)
144:         if completion_message is not None:
145:             self.completed = True
146:             self._save_interaction(
147:                 "assistant",
148:                 self._parse_completion_message(completion_message)
149:                 if not self.eval_mode
150:                 else response_text,
151:             )
152:             return None
153: 
154:         assistant_message = self._save_interaction("assistant", response_text)
155:         user_message = self._save_interaction(
156:             "user",
157:             generate_user_observation_message(observations)
158:             if len(observations) > 0
159:             else AutomataAgent.CONTINUE_MESSAGE,
160:         )
161: 
162:         return (assistant_message, user_message)
163: 
164:     def replay_messages(self) -> str:
165:         """Replay the messages in the conversation."""
166:         if len(self.messages) == 0:
167:             logger.debug("No messages to replay.")
168:             return "No messages to replay."
169:         for message in self.messages[self.NUM_DEFAULT_MESSAGES :]:
170:             observations = self._generate_observations(message.content)
171:             completion_message = retrieve_completion_message(observations)
172:             if completion_message:
173:                 return completion_message
174:             logger.debug("Role:\n%s\n\nMessage:\n%s\n" % (message.role, message.content))
175:             logger.debug("Processing message content =  %s" % message.content)
176:             logger.debug("\nProcessed Outputs:\n%s\n" % observations)
177:             logger.debug("-" * 60)
178:         return "No completion message found."
179: 
180:     def modify_last_instruction(self, new_instruction: str) -> None:
181:         """Extend the last instructions with a new message."""
182:         previous_message = self.messages[-1]
183:         if previous_message.role != "user":
184:             raise ValueError("Cannot modify the last instruction if it was not a user message.")
185:         self.messages[-1] = OpenAIChatMessage(role=previous_message.role, content=new_instruction)
186: 
187:     def get_non_instruction_messages(self) -> List[OpenAIChatMessage]:
188:         """Get the non-instruction messages."""
189:         return self.messages[self.NUM_DEFAULT_MESSAGES :]
190: 
191:     def _setup(self):
192:         """Setup the agent."""
193:         openai.api_key = OPENAI_API_KEY
194:         if "tools" in self.instruction_input_variables:
195:             self.instruction_payload["tools"] = self._build_tool_message()
196:         system_instruction = format_prompt(
197:             self.instruction_payload, self.system_instruction_template
198:         )
199:         self._init_database()
200:         if self.session_id:
201:             self._load_previous_interactions()
202:         else:
203:             self.session_id = str(uuid.uuid4())
204:             self._save_interaction("system", system_instruction)
205:             initial_messages = self._build_initial_messages(
206:                 {"user_input_instructions": self.instructions}
207:             )
208:             for message in initial_messages:
209:                 self._save_interaction(message.role, message.content)
210:         logger.debug("Initializing with System Instruction:%s\n\n" % system_instruction)
211:         logger.debug("-" * 60)
212:         if set(self.instruction_input_variables) != set(list(self.instruction_payload.keys())):
213:             raise ValueError(f"Initial payload does not match instruction_input_variables.")
214:         logger.debug("Session ID: %s" % self.session_id)
215:         logger.debug("-" * 60)
216: 
217:     def _generate_observations(self, response_text: str) -> Dict[str, str]:
218:         """Process the messages in the conversation."""
219:         outputs = {}
220:         actions = ActionExtractor.extract_actions(response_text)
221:         for action in actions:
222:             if isinstance(action, ToolAction):
223:                 (tool_query, tool_name, tool_input) = (
224:                     action.tool_query,
225:                     action.tool_name,
226:                     action.tool_args,
227:                 )
228:                 # Skip the initializer dummy tool which exists only for providing context
229:                 if tool_name == AutomataAgent.INITIALIZER_DUMMY:
230:                     continue
231:                 if tool_name == AutomataAgent.ERROR_DUMMY_TOOL:
232:                     # Input becomes the output when an error is registered
233:                     outputs[tool_query.replace("query", "output")] = cast(str, tool_input)
234:                 else:
235:                     tool_output = self._execute_tool(tool_name, tool_input)
236:                     outputs[tool_query.replace("query", "output")] = tool_output
237:             elif isinstance(action, ResultAction):
238:                 (result_name, result_outputs) = (action.result_name, action.result_outputs)
239:                 # Skip the return result indicator which exists only for marking the return result
240:                 outputs[result_name] = "\n".join(result_outputs)
241: 
242:         return outputs
243: 
244:     def _execute_tool(self, tool_name: str, tool_input: List[str]) -> str:
245:         """Execute the tool with the given name and input."""
246:         tool_found = False
247:         tool_output = None
248: 
249:         for toolkit in self.llm_toolkits.values():
250:             for tool in toolkit.tools:
251:                 if tool.name == tool_name:
252:                     processed_tool_input = [ele if ele != "None" else None for ele in tool_input]
253:                     tool_output = tool.run(tuple(processed_tool_input), verbose=False)
254:                     tool_found = True
255:                     break
256:             if tool_found:
257:                 break
258: 
259:         if not tool_found:
260:             error_message = f"Error: Tool '{tool_name}' not found."
261:             return error_message
262: 
263:         return cast(str, tool_output)
264: 
265:     def _init_database(self):
266:         """Initialize the database connection."""
267:         self.conn = sqlite3.connect(CONVERSATION_DB_NAME)
268:         self.cursor = self.conn.cursor()
269:         self.cursor.execute(
270:             "\n            CREATE TABLE IF NOT EXISTS interactions (\n                session_id INTEGER,\n                interaction_id INTEGER,\n                role TEXT,\n                content TEXT,\n                PRIMARY KEY (session_id, interaction_id)\n            )\n            "
271:         )
272:         self.conn.commit()
273: 
274:     def _save_interaction(self, role: str, content: str) -> OpenAIChatMessage:
275:         """Save the interaction to the database."""
276:         assert self.session_id is not None, "Session ID is not set."
277:         assert self.conn is not None, "Database connection is not set."
278:         interaction = OpenAIChatMessage(role=role, content=content)
279:         interaction_id = len(self.messages)
280:         self.cursor.execute(
281:             "INSERT INTO interactions (session_id, interaction_id, role, content) VALUES (?, ?, ?, ?)",
282:             (self.session_id, interaction_id, role, content),
283:         )
284:         self.conn.commit()
285:         self.messages.append(interaction)
286:         return interaction
287: 
288:     def _load_previous_interactions(self):
289:         """Load the previous interactions from the database."""
290:         self.cursor.execute(
291:             "SELECT role, content FROM interactions WHERE session_id = ? ORDER BY interaction_id ASC",
292:             (self.session_id,),
293:         )
294:         self.messages = [
295:             OpenAIChatMessage(role=role, content=content)
296:             for (role, content) in self.cursor.fetchall()
297:         ]
298: 
299:     def _build_tool_message(self):
300:         """Builds a message containing all tools and their descriptions."""
301:         return "Tools:\n" + "".join(
302:             [
303:                 f"\n{tool.name}: {tool.description}\n"
304:                 for toolkit in self.llm_toolkits.values()
305:                 for tool in toolkit.tools
306:             ]
307:         )
308: 
309:     def _parse_completion_message(self, completion_message: str) -> str:
310:         """Parse the completion message and replace the tool outputs."""
*** 311:         outputs = {}
*** 312:         for message in self.messages:
*** 313:             pattern = r"-\s(tool_output_\d+)\s+-\s(.*?)(?=-\s(tool_output_\d+)|$)"
*** 314:             matches = re.finditer(pattern, message.content, re.DOTALL)
*** 315:             for match in matches:
*** 316:                 tool_name, tool_output = match.group(1), match.group(2).strip()
*** 317:                 outputs[tool_name] = tool_output
318: 
*** 319:         for output_name in outputs:
*** 320:             completion_message = completion_message.replace(
321:                 f"{{{output_name}}}", outputs[output_name]
322:             )
*** 323:         return completion_message
324: 
325:     def _build_initial_messages(self, formatters: Dict[str, str]) -> List[OpenAIChatMessage]:
326:         """Build the initial messages."""
327:         assert "user_input_instructions" in formatters
328:         formatters["initializer_dummy_tool"] = AutomataAgent.INITIALIZER_DUMMY
329: 
330:         messages_config = load_config(ConfigCategory.INSTRUCTION.value, self.instruction_version)
331:         initial_messages = messages_config["initial_messages"]
332: 
333:         input_messages = []
334:         for message in initial_messages:
335:             input_message = format_prompt(formatters, message["content"])
336:             input_messages.append(OpenAIChatMessage(role=message["role"], content=input_message))
337: 
338:         return input_messages
339: 
340:     def _stream_message(self, response_summary: Any):
341:         """Stream the response message."""
342:         print(colored(f"\n>>> {self.name} Agent:", "green"))
343:         latest_accumulation = ""
344:         stream_separator = " "
345:         response_text = ""
346:         for chunk in response_summary:
347:             if "content" in chunk["choices"][0]["delta"]:
348:                 chunk_content = chunk["choices"][0]["delta"]["content"]
349:                 chunk_content = chunk_content.replace("\\n", "\n")
350:                 latest_accumulation += chunk_content
351:                 response_text += chunk_content
352:             if stream_separator in latest_accumulation:
353:                 words = latest_accumulation.split(stream_separator)
354:                 for word in words[:-1]:
355:                     print(colored(str(word), "green"), end=" ", flush=True)
356:                 latest_accumulation = words[-1]
357:         print(colored(str(latest_accumulation), "green"))
358:         return response_text
359: 
360: 
361: ```

Remove unused variables from config and env files

Several config environment variables are not actually used in the code, for example the Pinecone ones, and there are certainly others. Remove them from the config file and the env example.
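One way to find candidates before deleting anything, as a sketch: list the keys defined in the env example and report those whose names never appear in the source. The function name find_unused_env_keys and the KEY=value line parsing are assumptions. A purely textual search misses keys that are looked up dynamically (for example, built with string formatting), so each reported key still needs a manual check.

```python
import re
from typing import Iterable, List


def find_unused_env_keys(env_example: str, sources: Iterable[str]) -> List[str]:
    """Hypothetical helper: keys defined in env_example that no source text mentions."""
    keys = []
    for line in env_example.splitlines():
        line = line.strip()
        # Skip blank lines, comments, and anything that is not KEY=value.
        if not line or line.startswith("#") or "=" not in line:
            continue
        keys.append(line.split("=", 1)[0].strip())

    corpus = "\n".join(sources)
    # \b keeps PINECONE_ENV from matching inside PINECONE_ENVIRONMENT, for example.
    return [key for key in keys if not re.search(rf"\b{re.escape(key)}\b", corpus)]
```

In practice, sources would be the contents of every .py file (for example, read via pathlib.Path(".").rglob("*.py")) and env_example the text of the example env file.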

Hardcode repository name in config

Currently, main.py prompts the user to choose a git repository. In practice this hasn't been useful, so let's remove that prompt and make the repository name a config variable. Things to do:

  1. Remove list_repositories from main and utils
  2. Add a config variable for the repository name and load it with dotenv like the others. Default it to maks-ivanov/improved-spork.
  3. In the example config, add an example repository name.
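A minimal sketch of step 2, assuming the existing config values are read with python-dotenv and os.getenv. The variable name REPOSITORY_NAME and the helper get_repository_name are hypothetical. The matching line in the example config would then be something like REPOSITORY_NAME=your-user/your-repo.

```python
import os

try:
    # python-dotenv; assumed to be what the other config values already use.
    from dotenv import find_dotenv, load_dotenv
except ImportError:  # keeps the sketch runnable without python-dotenv installed
    find_dotenv = load_dotenv = None

DEFAULT_REPOSITORY_NAME = "maks-ivanov/improved-spork"


def get_repository_name() -> str:
    """Hypothetical config accessor for the repository name."""
    if load_dotenv is not None:
        # Look for a .env file starting from the working directory; existing
        # environment variables are not overridden (override defaults to False).
        load_dotenv(find_dotenv(usecwd=True))
    return os.getenv("REPOSITORY_NAME", DEFAULT_REPOSITORY_NAME)
```

A module-level constant such as REPOSITORY_NAME = get_repository_name() would match a config file that exposes plain values.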

Test coverage gap: tools.python_tools.python_indexer - retrieve_parent_function_num_code_lines

Write a test to satisfy the following coverage gap:
Module: tools.python_tools.python_indexer
Function: retrieve_parent_function_num_code_lines
Uncovered lines: [210, 211, 213, 214, 215, 216, 217, 218, 219]
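A minimal pytest-style sketch for this gap, built with unittest.mock so that no RedBaron parsing is needed. So that it runs on its own, it uses a stand-in copy of the method body from the listing below (lines 210-219). In the repository you would call PythonIndexer.retrieve_parent_function_num_code_lines(fake_indexer, ...) with the same fake, since the method only touches self.module_dict and self._filter_code_lines. _filter_code_lines is not shown in the listing, so it is stubbed here; the module name "pkg.mod" and the node types are made up.

```python
from types import SimpleNamespace
from unittest.mock import MagicMock

NO_RESULT_FOUND_STR = "No Result Found."  # mirrors PythonIndexer.NO_RESULT_FOUND_STR


# Stand-in copy of the method body (listing lines 210-219) so the sketch runs without RedBaron.
def retrieve_parent_function_num_code_lines(self, module_path, line_number):
    if module_path not in self.module_dict:
        return NO_RESULT_FOUND_STR
    node = self.module_dict[module_path].at(line_number)
    if node.type != "def":
        node = node.parent_find("def")
    if not node:
        return NO_RESULT_FOUND_STR
    filtered_code_lines = self._filter_code_lines(node)
    return len(filtered_code_lines)


def make_fake_indexer(node_at_line):
    """Fake indexer with one module, 'pkg.mod', whose .at() returns node_at_line."""
    module = MagicMock()
    module.at.return_value = node_at_line
    # Stub for the unlisted _filter_code_lines: always reports three code lines.
    filter_lines = MagicMock(return_value=["a = 1", "b = 2", "return a + b"])
    return SimpleNamespace(module_dict={"pkg.mod": module}, _filter_code_lines=filter_lines)


def test_unknown_module_returns_no_result():  # exercises lines 210-211
    indexer = make_fake_indexer(MagicMock(type="def"))
    assert retrieve_parent_function_num_code_lines(indexer, "other.mod", 1) == NO_RESULT_FOUND_STR


def test_line_on_def_counts_filtered_lines():  # exercises lines 213-214, 216, 218-219
    def_node = MagicMock(type="def")
    indexer = make_fake_indexer(def_node)
    assert retrieve_parent_function_num_code_lines(indexer, "pkg.mod", 5) == 3
    indexer._filter_code_lines.assert_called_once_with(def_node)


def test_line_inside_def_walks_up_to_parent():  # exercises line 215
    def_node = MagicMock(type="def")
    inner_node = MagicMock(type="name")
    inner_node.parent_find.return_value = def_node
    indexer = make_fake_indexer(inner_node)
    assert retrieve_parent_function_num_code_lines(indexer, "pkg.mod", 6) == 3
    inner_node.parent_find.assert_called_once_with("def")


def test_module_level_line_returns_no_result():  # exercises line 217
    top_level_node = MagicMock(type="assignment")
    top_level_node.parent_find.return_value = None
    indexer = make_fake_indexer(top_level_node)
    assert retrieve_parent_function_num_code_lines(indexer, "pkg.mod", 1) == NO_RESULT_FOUND_STR
```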
Code:

49:     """
50:     A class to index Python source code files in a specified directory and retrieve code and docstrings.
51:     Attributes:
52:         abs_path (str): The absolute path to the root directory containing Python source code files to be indexed.
53:         module_dict (Dict[str, Module]): A dictionary with module paths as keys and AST Module objects as values.
54: 
55:     Methods:
56:         __init__(self, rel_path: str) -> None
57:         retrieve_code(self, module_path: str, object_path: Optional[str]) -> Optional[str]
58:         retrieve_docstring(self, module_path: str, object_path: Optional[str]) -> Optional[str]
59:     """
60: 
61:     NO_RESULT_FOUND_STR = "No Result Found."
62:     PATH_SEP = "."
63: 
64:     def __init__(self, rel_path: str) -> None:
65:         """
66:         Initializes the PythonIndexer with the specified root directory and builds the module dictionary.
67: 
68:         Args:
69:             rel_path (str): The root directory containing Python source code files to be indexed.
70:         """
71: 
72:         self.abs_path = os.path.join(root_path(), rel_path)
73: 
74:     @cached_property
75:     def module_dict(self) -> Dict[str, RedBaron]:
76:         # TODO: cache by module
77:         return self._build_module_dict()
78: 
79:     def retrieve_code_without_docstrings(
80:         self, module_path: str, object_path: Optional[str]
81:     ) -> str:
82:         """
83:         Retrieve code for a specified module, class, or function/method.
84: 
85:         Args:
86:             module_path (str): The path of the module in dot-separated format (e.g. 'package.module').
87:             object_path (Optional[str]): The path of the class, function, or method in dot-separated format
88:                 (e.g. 'ClassName.method_name'). If None, the entire module code will be returned.
89: 
90:         Returns:
91:             str: The code for the specified module, class, or function/method, or "No Result Found."
92:                 if not found.
93:         """
94: 
95:         if module_path not in self.module_dict:
96:             return PythonIndexer.NO_RESULT_FOUND_STR
97: 
98:         module = RedBaron(
99:             self.module_dict[module_path].dumps()
100:         )  # create a copy because we'll remove docstrings
101:         result = self.find_module_class_function_or_method(module, object_path)
102: 
103:         if result:
104:             PythonIndexer._remove_docstrings(result)
105:             return result.dumps()
106:         else:
107:             return PythonIndexer.NO_RESULT_FOUND_STR
108: 
109:     def retrieve_code(self, module_path: str, object_path: Optional[str] = None) -> str:
110:         """
111:         Retrieve code for a specified module, class, or function/method.
112: 
113:         Args:
114:             module_path (str): The path of the module in dot-separated format (e.g. 'package.module').
115:             object_path (Optional[str]): The path of the class, function, or method in dot-separated format
116:                 (e.g. 'ClassName.method_name'). If None, the entire module code will be returned.
117: 
118:         Returns:
119:             str: The code for the specified module, class, or function/method, or "No Result Found."
120:                 if not found.
121:         """
122: 
123:         if module_path not in self.module_dict:
124:             return PythonIndexer.NO_RESULT_FOUND_STR
125: 
126:         module = self.module_dict[module_path]
127:         result = self.find_module_class_function_or_method(module, object_path)
128: 
129:         if result:
130:             return result.dumps()
131:         else:
132:             return PythonIndexer.NO_RESULT_FOUND_STR
133: 
134:     def find_expression_context(
135:         self,
136:         expression: str,
137:         symmetric_width: int = 2,
138:     ) -> str:
139:         """
140:         Inspects the codebase for lines containing the expression and returns the line number and
141:         surrounding lines.
142: 
143:         Args:
144:             root_dir (str): The root directory to search.
145:             expression (str): The expression to search for.
146: 
147:         Returns:
148:             str: The context associated with the expression.
149:         """
150: 
151:         result = ""
152:         pattern = re.compile(expression)
153:         for module_path, module in self.module_dict.items():
154:             lines = module.dumps().splitlines()
155:             for i, line in enumerate(lines):
156:                 lineno = i + 1  # rebardon lines are 1 indexed, same as in an editor
157:                 if pattern.search(line):
158:                     lower_index = max(i - symmetric_width, 0)
159:                     upper_index = min(i + symmetric_width, len(lines))
160: 
161:                     raw_code = "\n".join(lines[lower_index : upper_index + 1])
162:                     result += f"{module_path}"
163: 
164:                     node = module.at(lineno)
165:                     if node.type not in ("def", "class"):
166:                         node = node.parent_find(lambda identifier: identifier in ("def", "class"))
167: 
168:                     if node:
169:                         result += f".{node.name}"
170: 
171:                     linespan_str = (
172:                         f"L{lineno}"
173:                         if not symmetric_width
174:                         else f"L{lower_index + 1}-{upper_index + 1}"
175:                     )
176:                     result += f"\n{linespan_str}\n```{raw_code}```\n\n"
177: 
178:         return result
179: 
180:     def retrieve_parent_function_name_by_line(self, module_path: str, line_number: int) -> str:
181:         """
182:         Retrieve code for a specified module, class, or function/method.
183: 
184:         Args:
185:             module_path (str): The path of the module in dot-separated format (e.g. 'package.module').
186:             line_number (int): The line number of the code to retrieve.
187: 
188:         Returns:
189:             str: The code for the specified module, class, or function/method, or "No Result Found."
190:                 if not found.
191:         """
192: 
193:         if module_path not in self.module_dict:
194:             return PythonIndexer.NO_RESULT_FOUND_STR
195: 
196:         node = self.module_dict[module_path].at(line_number)
197:         if node.type != "def":
198:             node = node.parent_find("def")
199:         if node:
200:             if node.parent[0].type == "class":
201:                 return f"{node.parent.name}.{node.name}"
202:             else:
203:                 return node.name
204:         else:
205:             return PythonIndexer.NO_RESULT_FOUND_STR
206: 
207:     def retrieve_parent_function_num_code_lines(
208:         self, module_path: str, line_number: int
209:     ) -> Union[int, str]:
*** 210:         if module_path not in self.module_dict:
*** 211:             return PythonIndexer.NO_RESULT_FOUND_STR
212: 
*** 213:         node = self.module_dict[module_path].at(line_number)
*** 214:         if node.type != "def":
*** 215:             node = node.parent_find("def")
*** 216:         if not node:
*** 217:             return PythonIndexer.NO_RESULT_FOUND_STR
*** 218:         filtered_code_lines = self._filter_code_lines(node)
*** 219:         return len(filtered_code_lines)
220: 
221:     def retrieve_parent_code_by_line(
222:         self, module_path: str, line_number: int, return_numbered=False
223:     ) -> str:
224:         """
225:         Retrieve code for a specified module, class, or function/method.
226: 
227:         Args:
228:             module_path (str): The path of the module in dot-separated format (e.g. 'package.module').
229:             line_number (int): The line number of the code to retrieve.
230:             return_numbered (bool): Whether to return the code with line numbers prepended.
231: 
232:         Returns:
233:             str: The code for the specified module, class, or function/method, or "No Result Found."
234:                 if not found.
235:         """
236: 
237:         if module_path not in self.module_dict:
238:             return PythonIndexer.NO_RESULT_FOUND_STR
239: 
240:         node = self.module_dict[module_path].at(line_number)
241:         while node.parent_find(lambda identifier: identifier in ("def", "class")):
242:             node = node.parent_find(lambda identifier: identifier in ("def", "class"))
243:         source_code = node.dumps()
244:         if return_numbered:
245:             start = node.absolute_bounding_box.top_left.line
246:             numbered_source_code = "\n".join(
247:                 [f"{i+start}: {line}" for i, line in enumerate(source_code.split("\n"))]
248:             )
249:             return numbered_source_code
250:         else:
251:             return source_code
252: 
253:     def retrieve_docstring(self, module_path: str, object_path: Optional[str]) -> str:
254:         """
255:         Retrieve the docstring for a specified module, class, or function/method.
256: 
257:         Args:
258:             module_path (str): The path of the module in dot-separated format (e.g. 'package.module').
259:             object_path (Optional[str]): The path of the class, function, or method in dot-separated format
260:                 (e.g. 'ClassName.method_name'). If None, the module-level docstring will be returned.
261: 
262:         Returns:
263:             str: The docstring for the specified module, class, or function/method, or "No Result Found."
264:                 if not found.
265:         """
266: 
267:         if module_path not in self.module_dict:
268:             return PythonIndexer.NO_RESULT_FOUND_STR
269: 
270:         module = self.module_dict[module_path]
271:         result = self.find_module_class_function_or_method(module, object_path)
272: 
273:         if result:
274:             return PythonIndexer._get_docstring(result) or PythonIndexer.NO_RESULT_FOUND_STR
275:         else:
276:             return PythonIndexer.NO_RESULT_FOUND_STR
277: 
278:     @staticmethod
279:     def _get_docstring(node) -> str:
280:         if isinstance(node, (ClassNode, DefNode, RedBaron)):
281:             filtered_nodes = node.filtered()  # get rid of extra whitespace
282:             if isinstance(filtered_nodes[0], StringNode):
283:                 return filtered_nodes[0].value.replace('"""', "").replace("'''", "")
284:         return ""
285: 
286:     def _build_module_dict(self) -> Dict[str, RedBaron]:
287:         """
288:         Builds the module dictionary by walking through the root directory and creating FST Module objects
289:         for each Python source file. The module paths are used as keys in the dictionary.
290: 
291:         Returns:
292:             Dict[str, RedBaron]: A dictionary with module paths as keys and RedBaron objects as values.
293:         """
294: 
295:         module_dict = {}
296: 
297:         for root, _, files in os.walk(self.abs_path):
298:             for file in files:
299:                 if file.endswith(".py"):
300:                     module_path = os.path.join(root, file)
301:                     module = self._load_module_from_path(module_path)
302:                     if module:
303:                         module_rel_path = PythonIndexer._relative_module_path(
304:                             self.abs_path, module_path
305:                         )
306:                         module_dict[module_rel_path] = module
307:         return module_dict
308: 
309:     @staticmethod
310:     def _relative_module_path(root_abs_path, module_path):
311:         module_rel_path = (os.path.relpath(module_path, root_abs_path).replace(os.path.sep, "."))[
312:             :-3
313:         ]
314:         return module_rel_path
315: 
316:     def get_module_path(self, module_obj: RedBaron) -> str:
317:         """
318:         Returns the module path for the specified module object.
319: 
320:         Args:
321:             module_obj (Module): The module object.
322: 
323:         Returns:
324:             str: The module path for the specified module object.
325:         """
326: 
327:         for module_path, module in self.module_dict.items():
328:             if module is module_obj:
329:                 return module_path
330:         return PythonIndexer.NO_RESULT_FOUND_STR
331: 
332:     @staticmethod
333:     def build_overview(path) -> str:
334:         """
335:         Loops over the directory python files and returns a string that provides an overview of the PythonParser's state.
336:         Returns:
337:             str: A string that provides an overview of the PythonParser's state.
338:         **NOTE: This method uses AST, not RedBaron, because RedBaron initialization is slow and unnecessary for this method.
339:         """
340:         result_lines = []
341: 
342:         for root, _, files in os.walk(path):
343:             for file in files:
344:                 if file.endswith(".py"):
345:                     module_path = os.path.join(root, file)
346:                     module = ast.parse(open(module_path).read())
347:                     relative_module_path = PythonIndexer._relative_module_path(path, module_path)
348:                     result_lines.append(relative_module_path)
349:                     PythonIndexer._overview_traverse_helper(module, result_lines)
350:         return "\n".join(result_lines)
351: 
352:     @staticmethod
353:     def _overview_traverse_helper(node, line_items, num_spaces=1):
354:         if isinstance(node, ClassDef):
355:             line_items.append("  " * num_spaces + " - cls " + node.name)
356:         elif isinstance(node, FunctionDef) or isinstance(node, AsyncFunctionDef):
357:             line_items.append("  " * num_spaces + " - func " + node.name)
358: 
359:         for child in ast.iter_child_nodes(node):
360:             PythonIndexer._overview_traverse_helper(child, line_items, num_spaces + 1)
361: 
362:     @staticmethod
363:     def _load_module_from_path(path) -> Optional[RedBaron]:
364:         """
365:         Loads and returns an FST object for the given file path.
366: 
367:         Args:
368:             path (str): The file path of the Python source code.
369: 
370:         Returns:
371:             Module: RedBaron FST object.
372:         """
373: 
374:         try:
375:             module = RedBaron(open(path).read())
376:             return module
377:         except Exception as e:
378:             logger.error(f"Failed to load module '{path}' due to: {e}")
379:             return None
380: 
381:     @staticmethod
382:     def find_module_class_function_or_method(
383:         code_obj: Union[RedBaron, ClassNode], object_path: Optional[str]
384:     ) -> Optional[Union[Node, RedBaron]]:
385:         """
386:         Find a module, or find a function, method, or class inside a module.
387: 
388:         Args:
389:             code_obj (Union[RedBaron, ClassNode]): The RedBaron FST object to search.
390:             object_path (Optional[str]): The dot-separated object path (e.g., 'ClassName.method_name'). If None,
391:                 code_obj itself is returned.
392: 
393:         Returns:
394:             Optional[Union[Node, RedBaron]]: The found def or class node, code_obj itself if object_path is None, or None if not found.
395:         """
396: 
397:         if not object_path:
398:             return code_obj
399: 
400:         obj_parts = object_path.split(PythonIndexer.PATH_SEP)
401: 
402:         node = code_obj
403:         while node and obj_parts:
404:             obj_name = obj_parts.pop(0)
405:             node = PythonIndexer._find_node(node, obj_name)
406:         return node
407: 
408:     @staticmethod
409:     def _find_node(code_obj: RedBaron, obj_name: str) -> Optional[Union[DefNode, ClassNode]]:
410:         """
411:         Find a DefNode or ClassNode with the specified name within the given
412:         FST code object.
413: 
414:         Args:
415:             code_obj (RedBaron): The FST code object (RedBaron or Node) to search.
416:             obj_name (str): The name of the object to find.
417: 
418:         Returns:
419:             Optional[Union[DefNode, ClassNode]]: The found node, or None.
420:         """
421:         return code_obj.find(lambda identifier: identifier in ("def", "class"), name=obj_name)
422: 
423:     @staticmethod
424:     def find_imports(module: RedBaron) -> Optional[NodeList]:
425:         """
426:         Find all imports in a module.
427: 
428:         Args:
429:             module (RedBaron): The module to search.
430: 
431:         Returns:
432:             Optional[NodeList]: A list of ImportNode and FromImportNode objects.
433:         """
434:         return module.find_all(lambda identifier: identifier in ("import", "from_import"))
435: 
436:     @staticmethod
437:     def find_import_by_name(
438:         module: RedBaron, import_name: str
439:     ) -> Optional[Union[ImportNode, FromImportNode]]:
440:         """
441:         Find an import by name.
442: 
443:         Args:
444:             module (RedBaron): The module to search.
445:             import_name (str): The name of the import to find.
446: 
447:         Returns:
448:             Optional[Union[ImportNode, FromImportNode]]: The found import, or None if not found.
449:         """
450:         return module.find(
451:             lambda identifier: identifier in ("import", "from_import"), name=import_name
452:         )
453: 
454:     @staticmethod
455:     def find_all_functions_and_classes(module: RedBaron) -> NodeList:
456:         """
457:         Find all functions and classes in a module.
458: 
459:         Args:
460:             module (RedBaron): The module to search.
461: 
462:         Returns:
463:             NodeList: A list of ClassNode and DefNode objects.
464:         """
465:         return module.find_all(lambda identifier: identifier in ("class", "def"))
466: 
467:     @staticmethod
468:     def _remove_docstrings(node: Union[Node, RedBaron]) -> None:
469:         """
470:         Remove docstrings from the specified node, recursively.
471: 
472:         Args:
473:             node: The FST node
474:                 to remove docstrings from.
475:         """
476: 
477:         if isinstance(node, (DefNode, ClassNode, RedBaron)):
478:             filtered_node = node.filtered()
479:             if isinstance(filtered_node[0], StringNode):
480:                 index = filtered_node[0].index_on_parent
481:                 node.pop(index)
482:             child_nodes = node.find_all(lambda identifier: identifier in ("def", "class"))
483:             for child_node in child_nodes:
484:                 if child_node is not node:
485:                     PythonIndexer._remove_docstrings(child_node)
486: 
487:     def _filter_code_lines(self, node: Union[Node, RedBaron]) -> List[str]:
488:         """
489:         Returns lines of code that are not empty, comments, decorators, or docstrings.
490: 
491:         Args:
492:             node: The FST node to filter.
493:         """
494:         body = node.value.copy()
495:         body = body.filter(lambda x: x.type != "string" or not x.value.startswith('"""'))
496:         source_code_lines = body.dumps().splitlines()
497:         predicate = (
498:             lambda l: not l.strip().startswith("#")
499:             and not l.strip().startswith("@")
500:             and not l.strip() == ""
501:         )
502:         source_code_lines = [line for line in source_code_lines if predicate(line)]
503:         return source_code_lines
504: ```

Implement graceful failures

The agent currently fails relatively often and leaves a mess of Git branches behind. We need functionality to detect failed chains and clean them up, essentially undoing the steps taken before the failure.
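One way to undo the steps taken before a failure is an undo log of compensating actions: every step that changes state, such as creating a Git branch, registers a callable that reverses it, and on failure those callables run in reverse order. The sketch below is hypothetical: `RollbackStack` is not an existing automata class, and how the agent would hook into it is an assumption.

```python
import logging
from types import TracebackType
from typing import Callable, List, Optional, Type

logger = logging.getLogger(__name__)


class RollbackStack:
    """Undo log for a multi-step run (hypothetical helper, not part of automata).

    Each step that changes state registers a callable that reverses it. If the
    ``with`` block raises, the callables run in reverse order (last step first).
    """

    def __init__(self) -> None:
        self._undo_actions: List[Callable[[], None]] = []

    def push(self, undo: Callable[[], None]) -> None:
        # Call right after a step succeeds, e.g. after creating a branch.
        self._undo_actions.append(undo)

    def rollback(self) -> None:
        # Undo completed steps newest-first; log and continue if one undo fails,
        # so a single error doesn't leave the remaining steps in place.
        while self._undo_actions:
            undo = self._undo_actions.pop()
            try:
                undo()
            except Exception:
                logger.exception("Undo action failed during rollback")

    def __enter__(self) -> "RollbackStack":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> bool:
        if exc_type is not None:
            self.rollback()
        else:
            self._undo_actions.clear()  # success: keep the results
        return False  # never suppress the original exception
```

For a branch, the registered undo might switch back to the original branch and then call `subprocess.run(["git", "branch", "-D", name], check=True)`, since Git refuses to delete the branch that is currently checked out.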

Fix the cosine similarity function

There's a function in the utils file that computes the cosine similarity between two vectors, but it doesn't normalize the result properly. Fix it.
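The issue doesn't show the utils function or its signature, so the following is only a sketch of a correctly normalized cosine similarity, cos θ = (a · b) / (‖a‖ ‖b‖), assuming a hypothetical `cosine_similarity(a, b)` that takes two equal-length sequences of floats. Raising on zero vectors, rather than returning 0, is a design choice of this sketch.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Return the cosine similarity of two equal-length vectors, in [-1, 1].

    The dot product is divided by the product of both vectors' Euclidean
    norms; that normalization is what bounds the result to [-1, 1].
    """
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        # The angle, and therefore the similarity, is undefined for a zero vector.
        raise ValueError("cosine similarity is undefined for zero vectors")
    # Clamp to guard against rounding pushing the ratio slightly outside [-1, 1].
    return max(-1.0, min(1.0, dot / (norm_a * norm_b)))
```

For example, orthogonal vectors such as `[1.0, 0.0]` and `[0.0, 1.0]` give `0.0`, and any vector compared with a positive multiple of itself gives (up to rounding) `1.0`.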
