summaryrefslogtreecommitdiff
path: root/python_agent/to_ast.py
diff options
context:
space:
mode:
Diffstat (limited to 'python_agent/to_ast.py')
-rw-r--r--python_agent/to_ast.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/python_agent/to_ast.py b/python_agent/to_ast.py
new file mode 100644
index 0000000..3f5aadb
--- /dev/null
+++ b/python_agent/to_ast.py
@@ -0,0 +1,63 @@
+import hashlib
+from os.path import exists
+from pathlib import Path
+from ast import parse, unparse, FunctionDef, ClassDef
+
def rglob(directory, extension):
    """
    Recursively collect all regular files with the given extension.

    Args:
        directory (str or Path): The root directory to start the search.
        extension (str): The file extension (e.g., '.txt', '.py').
            The leading dot is optional; it is added when missing.

    Returns:
        list: A sorted list of Path objects for the matching files.
            Directories whose names happen to end with the extension
            are excluded.
    """
    # Normalise the extension so that both "py" and ".py" are accepted.
    if not extension.startswith("."):
        extension = "." + extension

    # Path.rglob matches directories as well as files; callers open every
    # result for reading, so keep regular files only.  Sorting makes the
    # order independent of the underlying filesystem's listing order.
    matches = Path(directory).rglob(f"*{extension}")
    return sorted(path for path in matches if path.is_file())
+
def _store_definition(src, out_dir="agent/ast"):
    """
    Write ``src`` to a content-addressed file below ``out_dir``.

    The file is named after the SHA-256 hex digest of the UTF-8 encoded
    text.  If a file with that name already exists and holds the same
    text, nothing is written.  If it exists with different text, a marker
    comment is appended to ``src`` (which changes the digest) and the
    process is retried under the new name.

    Args:
        src (str): The source text to store.
        out_dir (str): Directory receiving the file; created if missing.

    Returns:
        str: Path of the file that holds the (possibly amended) text.
    """
    # A fresh checkout may not have the output directory yet.
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    while True:
        digest = hashlib.sha256(src.encode("utf-8")).hexdigest()
        fout = f"{out_dir}/{digest}.py"
        if not exists(fout):
            # Encode exactly as hashed so the name matches the stored bytes.
            with open(fout, "w", encoding="utf-8") as fhandle:
                fhandle.write(src)
            return fout
        with open(fout, "r", encoding="utf-8") as fhandle:
            already_stored = fhandle.read() == src
        if already_stored:
            return fout
        # Same name but different content: perturb the text to get a new
        # digest and try again.
        src = src + "# Added due to hash collision\n"
        print("Collision detected!")


# Split every module under src/ into one file per top-level function/class.
for file in rglob("src/", "py"):
    # Python source is UTF-8 by default; do not depend on the locale.
    with open(file, "r", encoding="utf-8") as fhandle:
        source = fhandle.read()
    # Passing the filename makes a SyntaxError point at the offending file.
    tree = parse(source, filename=str(file))
    for node in tree.body:
        if not isinstance(node, (FunctionDef, ClassDef)):
            continue
        src = f"# Function derived from {file}\n" + unparse(node)
        if not src.endswith("\n"):
            src = src + "\n"
        _store_definition(src)
+