fork download
  #include <cmath>
  #include <iostream>
  #include <stdexcept>
  #include <string>
  #include <utility>
  using namespace std;
  4.  
  // Kinds of node that can appear in an expression tree.
  enum NodeType {
    CONST,  // numeric literal; value kept in Node::val
    VAR,    // named variable; name kept in Node::varName
    ADD,    // left + right
    SUB,    // left - right
    MUL,    // left * right
    POW     // left ^ right
  };
  7.  
  8. // Basic Node structure
  9. struct Node {
  10. NodeType type;
  11. double val; // Used if type is CONST
  12. string varName; // Used if type is VAR
  13. Node *left, *right;
  14.  
  15. Node(NodeType t) : type(t), val(0), left(nullptr), right(nullptr) {}
  16. };
  17.  
  18. // --- HELPER FUNCTIONS ---
  19.  
  20. // Create a constant node
  21. Node* createConst(double v) {
  22. Node* n = new Node(CONST);
  23. n->val = v;
  24. return n;
  25. }
  26.  
  27. // Create a variable node
  28. Node* createVar(string name) {
  29. Node* n = new Node(VAR);
  30. n->varName = name;
  31. return n;
  32. }
  33.  
  34. // Create an operator node
  35. Node* createOp(NodeType t, Node* l, Node* r) {
  36. Node* n = new Node(t);
  37. n->left = l;
  38. n->right = r;
  39. return n;
  40. }
  41.  
  42. // DEEP COPY: Recursively clones a tree to prevent memory sharing
  43. Node* copyTree(Node* root) {
  44. if (!root) return nullptr;
  45. Node* newNode = new Node(root->type);
  46. newNode->val = root->val;
  47. newNode->varName = root->varName;
  48. newNode->left = copyTree(root->left);
  49. newNode->right = copyTree(root->right);
  50. return newNode;
  51. }
  52.  
  53. // Recursive function to free memory
  54. void deleteTree(Node* root) {
  55. if (!root) return;
  56. deleteTree(root->left);
  57. deleteTree(root->right);
  58. delete root;
  59. }
  60.  
  61. // Print the expression in human-readable format
  62. void printTree(Node* root) {
  63. if (!root) return;
  64. if (root->type == CONST) cout << root->val;
  65. else if (root->type == VAR) cout << root->varName;
  66. else {
  67. cout << "(";
  68. printTree(root->left);
  69. if (root->type == ADD) cout << " + ";
  70. else if (root->type == SUB) cout << " - ";
  71. else if (root->type == MUL) cout << " * ";
  72. else if (root->type == POW) cout << " ^ ";
  73. printTree(root->right);
  74. cout << ")";
  75. }
  76. }
  77.  
  78. // --- DIFFERENTIATION ALGORITHM ---
  79. Node* derive(Node* n, string var) {
  80. switch (n->type) {
  81. case CONST:
  82. return createConst(0); // (C)' = 0
  83. case VAR:
  84. return createConst(n->varName == var ? 1 : 0); // (x)' = 1, (y)' = 0
  85. case ADD:
  86. return createOp(ADD, derive(n->left, var), derive(n->right, var));
  87. case SUB:
  88. return createOp(SUB, derive(n->left, var), derive(n->right, var));
  89. case MUL: {
  90. // Product Rule: (u*v)' = u'v + uv'
  91. // We use copyTree to ensure the new tree has its own nodes
  92. Node* leftPart = createOp(MUL, derive(n->left, var), copyTree(n->right));
  93. Node* rightPart = createOp(MUL, copyTree(n->left), derive(n->right, var));
  94. return createOp(ADD, leftPart, rightPart);
  95. }
  96. case POW: {
  97. // Power Rule: (u^n)' = n * u^(n-1) * u'
  98. // Assumes n is a constant
  99. double nVal = n->right->val;
  100. Node* nMinus1 = createConst(nVal - 1);
  101. Node* newPow = createOp(POW, copyTree(n->left), nMinus1);
  102. Node* step1 = createOp(MUL, createConst(nVal), newPow);
  103. return createOp(MUL, step1, derive(n->left, var));
  104. }
  105. }
  106. return nullptr;
  107. }
  108.  
  109. // --- SIMPLIFICATION ALGORITHM ---
  110. Node* simplify(Node* n) {
  111. if (!n || n->type == CONST || n->type == VAR) return n;
  112. // Simplify children first (Post-order traversal)
  113. n->left = simplify(n->left);
  114. n->right = simplify(n->right);
  115. // Simplify ADD (+)
  116. if (n->type == ADD) {
  117. if (n->left->type == CONST && n->left->val == 0) return n->right; // 0 + x = x
  118. if (n->right->type == CONST && n->right->val == 0) return n->left; // x + 0 = x
  119. }
  120. // Simplify MUL (*)
  121. else if (n->type == MUL) {
  122. if (n->left->type == CONST && n->left->val == 0) return createConst(0); // 0 * x = 0
  123. if (n->right->type == CONST && n->right->val == 0) return createConst(0); // x * 0 = 0
  124. if (n->left->type == CONST && n->left->val == 1) return n->right; // 1 * x = x
  125. if (n->right->type == CONST && n->right->val == 1) return n->left; // x * 1 = x
  126. }
  127. // Simplify POW (^)
  128. else if (n->type == POW) {
  129. if (n->right->type == CONST && n->right->val == 1) return n->left; // x ^ 1 = x
  130. if (n->right->type == CONST && n->right->val == 0) return createConst(1); // x ^ 0 = 1
  131. }
  132. // Constant Folding: if both sides are numbers, calculate immediately
  133. if (n->left->type == CONST && n->right->type == CONST) {
  134. if (n->type == ADD) return createConst(n->left->val + n->right->val);
  135. if (n->type == SUB) return createConst(n->left->val - n->right->val);
  136. if (n->type == MUL) return createConst(n->left->val * n->right->val);
  137. }
  138. return n;
  139. }
  140.  
  141. // --- MAIN FUNCTION ---
  142. int main() {
  143. // f(x) = x^2 + 3x
  144. Node* x = createVar("x");
  145. Node* expr = createOp(ADD,
  146. createOp(POW, x, createConst(2)),
  147. createOp(MUL, createConst(3), copyTree(x)));
  148.  
  149. cout << "Original Expression: "; printTree(expr); cout << endl;
  150. // Calculate Derivative
  151. Node* d = derive(expr, "x");
  152. cout << "Raw Derivative: "; printTree(d); cout << endl;
  153. // Simplify the derivative
  154. // We run it twice to ensure nested simplifications (like 0 + (3 * 1)) are fully resolved
  155. d = simplify(d);
  156. d = simplify(d);
  157. cout << "Simplified Result: "; printTree(d); cout << endl;
  158. // Clean up memory
  159. deleteTree(expr);
  160. deleteTree(d);
  161.  
  162. return 0;
  163. }
Success #stdin #stdout 0.01s 5308KB
stdin
Standard input is empty
stdout
Original Expression: ((x ^ 2) + (3 * x))
Raw Derivative:      (((2 * (x ^ 1)) * 1) + ((0 * x) + (3 * 1)))
Simplified Result:   ((2 * x) + 3)